factory.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. import importlib
  2. def load_class(class_type):
  3. module_path, class_name = class_type.rsplit(".", 1)
  4. module = importlib.import_module(module_path)
  5. return getattr(module, class_name)
  6. class LlmFactory:
  7. provider_to_class = {
  8. "anthropic": "embedchain.llm.anthropic.AnthropicLlm",
  9. "azure_openai": "embedchain.llm.azure_openai.AzureOpenAILlm",
  10. "cohere": "embedchain.llm.cohere.CohereLlm",
  11. "gpt4all": "embedchain.llm.gpt4all.GPT4ALLLlm",
  12. "ollama": "embedchain.llm.ollama.OllamaLlm",
  13. "huggingface": "embedchain.llm.huggingface.HuggingFaceLlm",
  14. "jina": "embedchain.llm.jina.JinaLlm",
  15. "llama2": "embedchain.llm.llama2.Llama2Llm",
  16. "openai": "embedchain.llm.openai.OpenAILlm",
  17. "vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
  18. "google": "embedchain.llm.google.GoogleLlm",
  19. }
  20. provider_to_config_class = {
  21. "embedchain": "embedchain.config.llm.base.BaseLlmConfig",
  22. "openai": "embedchain.config.llm.base.BaseLlmConfig",
  23. "anthropic": "embedchain.config.llm.base.BaseLlmConfig",
  24. }
  25. @classmethod
  26. def create(cls, provider_name, config_data):
  27. class_type = cls.provider_to_class.get(provider_name)
  28. # Default to embedchain base config if the provider is not in the config map
  29. config_name = "embedchain" if provider_name not in cls.provider_to_config_class else provider_name
  30. config_class_type = cls.provider_to_config_class.get(config_name)
  31. if class_type:
  32. llm_class = load_class(class_type)
  33. llm_config_class = load_class(config_class_type)
  34. return llm_class(config=llm_config_class(**config_data))
  35. else:
  36. raise ValueError(f"Unsupported Llm provider: {provider_name}")
  37. class EmbedderFactory:
  38. provider_to_class = {
  39. "azure_openai": "embedchain.embedder.openai.OpenAIEmbedder",
  40. "gpt4all": "embedchain.embedder.gpt4all.GPT4AllEmbedder",
  41. "huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder",
  42. "openai": "embedchain.embedder.openai.OpenAIEmbedder",
  43. "vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
  44. "google": "embedchain.embedder.google.GoogleAIEmbedder",
  45. }
  46. provider_to_config_class = {
  47. "azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
  48. "openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
  49. "gpt4all": "embedchain.config.embedder.base.BaseEmbedderConfig",
  50. "google": "embedchain.config.embedder.google.GoogleAIEmbedderConfig",
  51. }
  52. @classmethod
  53. def create(cls, provider_name, config_data):
  54. class_type = cls.provider_to_class.get(provider_name)
  55. # Default to openai config if the provider is not in the config map
  56. config_name = "openai" if provider_name not in cls.provider_to_config_class else provider_name
  57. config_class_type = cls.provider_to_config_class.get(config_name)
  58. if class_type:
  59. embedder_class = load_class(class_type)
  60. embedder_config_class = load_class(config_class_type)
  61. return embedder_class(config=embedder_config_class(**config_data))
  62. else:
  63. raise ValueError(f"Unsupported Embedder provider: {provider_name}")
  64. class VectorDBFactory:
  65. provider_to_class = {
  66. "chroma": "embedchain.vectordb.chroma.ChromaDB",
  67. "elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
  68. "opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
  69. "pinecone": "embedchain.vectordb.pinecone.PineconeDB",
  70. "qdrant": "embedchain.vectordb.qdrant.QdrantDB",
  71. "weaviate": "embedchain.vectordb.weaviate.WeaviateDB",
  72. "zilliz": "embedchain.vectordb.zilliz.ZillizVectorDB",
  73. }
  74. provider_to_config_class = {
  75. "chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
  76. "elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
  77. "opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
  78. "pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
  79. "qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig",
  80. "weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig",
  81. "zilliz": "embedchain.config.vectordb.zilliz.ZillizDBConfig",
  82. }
  83. @classmethod
  84. def create(cls, provider_name, config_data):
  85. class_type = cls.provider_to_class.get(provider_name)
  86. config_class_type = cls.provider_to_config_class.get(provider_name)
  87. if class_type:
  88. embedder_class = load_class(class_type)
  89. embedder_config_class = load_class(config_class_type)
  90. return embedder_class(config=embedder_config_class(**config_data))
  91. else:
  92. raise ValueError(f"Unsupported Embedder provider: {provider_name}")