# test_factory.py (2.6 KB)

import pytest

import embedchain
import embedchain.embedder.gpt4all
import embedchain.embedder.huggingface
import embedchain.embedder.openai
import embedchain.embedder.vertexai
import embedchain.llm.anthropic
import embedchain.llm.openai
import embedchain.vectordb.chroma
import embedchain.vectordb.elasticsearch
import embedchain.vectordb.opensearch
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory


  13. class TestFactories:
  14. @pytest.mark.parametrize(
  15. "provider_name, config_data, expected_class",
  16. [
  17. ("openai", {}, embedchain.llm.openai.OpenAILlm),
  18. ("anthropic", {}, embedchain.llm.anthropic.AnthropicLlm),
  19. ],
  20. )
  21. def test_llm_factory_create(self, provider_name, config_data, expected_class):
  22. llm_instance = LlmFactory.create(provider_name, config_data)
  23. assert isinstance(llm_instance, expected_class)
  24. @pytest.mark.parametrize(
  25. "provider_name, config_data, expected_class",
  26. [
  27. ("gpt4all", {}, embedchain.embedder.gpt4all.GPT4AllEmbedder),
  28. (
  29. "huggingface",
  30. {"model": "sentence-transformers/all-mpnet-base-v2"},
  31. embedchain.embedder.huggingface.HuggingFaceEmbedder,
  32. ),
  33. ("vertexai", {"model": "textembedding-gecko"}, embedchain.embedder.vertexai.VertexAIEmbedder),
  34. ("openai", {}, embedchain.embedder.openai.OpenAIEmbedder),
  35. ],
  36. )
  37. def test_embedder_factory_create(self, mocker, provider_name, config_data, expected_class):
  38. mocker.patch("embedchain.embedder.vertexai.VertexAIEmbedder", autospec=True)
  39. embedder_instance = EmbedderFactory.create(provider_name, config_data)
  40. assert isinstance(embedder_instance, expected_class)
  41. @pytest.mark.parametrize(
  42. "provider_name, config_data, expected_class",
  43. [
  44. ("chroma", {}, embedchain.vectordb.chroma.ChromaDB),
  45. (
  46. "opensearch",
  47. {"opensearch_url": "http://localhost:9200", "http_auth": ("admin", "admin")},
  48. embedchain.vectordb.opensearch.OpenSearchDB,
  49. ),
  50. ("elasticsearch", {"es_url": "http://localhost:9200"}, embedchain.vectordb.elasticsearch.ElasticsearchDB),
  51. ],
  52. )
  53. def test_vectordb_factory_create(self, mocker, provider_name, config_data, expected_class):
  54. mocker.patch("embedchain.vectordb.opensearch.OpenSearchDB", autospec=True)
  55. vectordb_instance = VectorDBFactory.create(provider_name, config_data)
  56. assert isinstance(vectordb_instance, expected_class)