test_factory.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import os
  2. import pytest
  3. import embedchain
  4. import embedchain.embedder.gpt4all
  5. import embedchain.embedder.huggingface
  6. import embedchain.embedder.openai
  7. import embedchain.embedder.vertexai
  8. import embedchain.llm.anthropic
  9. import embedchain.llm.openai
  10. import embedchain.vectordb.chroma
  11. import embedchain.vectordb.elasticsearch
  12. import embedchain.vectordb.opensearch
  13. from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
  14. class TestFactories:
  15. @pytest.mark.parametrize(
  16. "provider_name, config_data, expected_class",
  17. [
  18. ("openai", {}, embedchain.llm.openai.OpenAILlm),
  19. ("anthropic", {}, embedchain.llm.anthropic.AnthropicLlm),
  20. ],
  21. )
  22. def test_llm_factory_create(self, provider_name, config_data, expected_class):
  23. os.environ["ANTHROPIC_API_KEY"] = "test_api_key"
  24. os.environ["OPENAI_API_KEY"] = "test_api_key"
  25. llm_instance = LlmFactory.create(provider_name, config_data)
  26. assert isinstance(llm_instance, expected_class)
  27. @pytest.mark.parametrize(
  28. "provider_name, config_data, expected_class",
  29. [
  30. ("gpt4all", {}, embedchain.embedder.gpt4all.GPT4AllEmbedder),
  31. (
  32. "huggingface",
  33. {"model": "sentence-transformers/all-mpnet-base-v2", "vector_dimension": 768},
  34. embedchain.embedder.huggingface.HuggingFaceEmbedder,
  35. ),
  36. ("vertexai", {"model": "textembedding-gecko"}, embedchain.embedder.vertexai.VertexAIEmbedder),
  37. ("openai", {}, embedchain.embedder.openai.OpenAIEmbedder),
  38. ],
  39. )
  40. def test_embedder_factory_create(self, mocker, provider_name, config_data, expected_class):
  41. mocker.patch("embedchain.embedder.vertexai.VertexAIEmbedder", autospec=True)
  42. embedder_instance = EmbedderFactory.create(provider_name, config_data)
  43. assert isinstance(embedder_instance, expected_class)
  44. @pytest.mark.parametrize(
  45. "provider_name, config_data, expected_class",
  46. [
  47. ("chroma", {}, embedchain.vectordb.chroma.ChromaDB),
  48. (
  49. "opensearch",
  50. {"opensearch_url": "http://localhost:9200", "http_auth": ("admin", "admin")},
  51. embedchain.vectordb.opensearch.OpenSearchDB,
  52. ),
  53. ("elasticsearch", {"es_url": "http://localhost:9200"}, embedchain.vectordb.elasticsearch.ElasticsearchDB),
  54. ],
  55. )
  56. def test_vectordb_factory_create(self, mocker, provider_name, config_data, expected_class):
  57. mocker.patch("embedchain.vectordb.opensearch.OpenSearchDB", autospec=True)
  58. vectordb_instance = VectorDBFactory.create(provider_name, config_data)
  59. assert isinstance(vectordb_instance, expected_class)