factory.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import importlib
  2. def load_class(class_type):
  3. module_path, class_name = class_type.rsplit(".", 1)
  4. module = importlib.import_module(module_path)
  5. return getattr(module, class_name)
  6. class LlmFactory:
  7. provider_to_class = {
  8. "anthropic": "embedchain.llm.anthropic.AnthropicLlm",
  9. "azure_openai": "embedchain.llm.azure_openai.AzureOpenAILlm",
  10. "cohere": "embedchain.llm.cohere.CohereLlm",
  11. "together": "embedchain.llm.together.TogetherLlm",
  12. "gpt4all": "embedchain.llm.gpt4all.GPT4ALLLlm",
  13. "ollama": "embedchain.llm.ollama.OllamaLlm",
  14. "huggingface": "embedchain.llm.huggingface.HuggingFaceLlm",
  15. "jina": "embedchain.llm.jina.JinaLlm",
  16. "llama2": "embedchain.llm.llama2.Llama2Llm",
  17. "openai": "embedchain.llm.openai.OpenAILlm",
  18. "vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
  19. "google": "embedchain.llm.google.GoogleLlm",
  20. "aws_bedrock": "embedchain.llm.aws_bedrock.AWSBedrockLlm",
  21. "mistralai": "embedchain.llm.mistralai.MistralAILlm",
  22. "clarifai": "embedchain.llm.clarifai.ClarifaiLlm",
  23. "groq": "embedchain.llm.groq.GroqLlm",
  24. "nvidia": "embedchain.llm.nvidia.NvidiaLlm",
  25. "vllm": "embedchain.llm.vllm.VLLM",
  26. }
  27. provider_to_config_class = {
  28. "embedchain": "embedchain.config.llm.base.BaseLlmConfig",
  29. "openai": "embedchain.config.llm.base.BaseLlmConfig",
  30. "anthropic": "embedchain.config.llm.base.BaseLlmConfig",
  31. }
  32. @classmethod
  33. def create(cls, provider_name, config_data):
  34. class_type = cls.provider_to_class.get(provider_name)
  35. # Default to embedchain base config if the provider is not in the config map
  36. config_name = "embedchain" if provider_name not in cls.provider_to_config_class else provider_name
  37. config_class_type = cls.provider_to_config_class.get(config_name)
  38. if class_type:
  39. llm_class = load_class(class_type)
  40. llm_config_class = load_class(config_class_type)
  41. return llm_class(config=llm_config_class(**config_data))
  42. else:
  43. raise ValueError(f"Unsupported Llm provider: {provider_name}")
  44. class EmbedderFactory:
  45. provider_to_class = {
  46. "azure_openai": "embedchain.embedder.azure_openai.AzureOpenAIEmbedder",
  47. "gpt4all": "embedchain.embedder.gpt4all.GPT4AllEmbedder",
  48. "huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder",
  49. "openai": "embedchain.embedder.openai.OpenAIEmbedder",
  50. "vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
  51. "google": "embedchain.embedder.google.GoogleAIEmbedder",
  52. "mistralai": "embedchain.embedder.mistralai.MistralAIEmbedder",
  53. "clarifai": "embedchain.embedder.clarifai.ClarifaiEmbedder",
  54. "nvidia": "embedchain.embedder.nvidia.NvidiaEmbedder",
  55. "cohere": "embedchain.embedder.cohere.CohereEmbedder",
  56. "ollama": "embedchain.embedder.ollama.OllamaEmbedder",
  57. }
  58. provider_to_config_class = {
  59. "azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
  60. "google": "embedchain.config.embedder.google.GoogleAIEmbedderConfig",
  61. "gpt4all": "embedchain.config.embedder.base.BaseEmbedderConfig",
  62. "huggingface": "embedchain.config.embedder.base.BaseEmbedderConfig",
  63. "clarifai": "embedchain.config.embedder.base.BaseEmbedderConfig",
  64. "openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
  65. "ollama": "embedchain.config.embedder.ollama.OllamaEmbedderConfig",
  66. }
  67. @classmethod
  68. def create(cls, provider_name, config_data):
  69. class_type = cls.provider_to_class.get(provider_name)
  70. # Default to openai config if the provider is not in the config map
  71. config_name = "openai" if provider_name not in cls.provider_to_config_class else provider_name
  72. config_class_type = cls.provider_to_config_class.get(config_name)
  73. if class_type:
  74. embedder_class = load_class(class_type)
  75. embedder_config_class = load_class(config_class_type)
  76. return embedder_class(config=embedder_config_class(**config_data))
  77. else:
  78. raise ValueError(f"Unsupported Embedder provider: {provider_name}")
  79. class VectorDBFactory:
  80. provider_to_class = {
  81. "chroma": "embedchain.vectordb.chroma.ChromaDB",
  82. "elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
  83. "opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
  84. "lancedb": "embedchain.vectordb.lancedb.LanceDB",
  85. "pinecone": "embedchain.vectordb.pinecone.PineconeDB",
  86. "qdrant": "embedchain.vectordb.qdrant.QdrantDB",
  87. "weaviate": "embedchain.vectordb.weaviate.WeaviateDB",
  88. "zilliz": "embedchain.vectordb.zilliz.ZillizVectorDB",
  89. }
  90. provider_to_config_class = {
  91. "chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
  92. "elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
  93. "opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
  94. "lancedb": "embedchain.config.vectordb.lancedb.LanceDBConfig",
  95. "pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
  96. "qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig",
  97. "weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig",
  98. "zilliz": "embedchain.config.vectordb.zilliz.ZillizDBConfig",
  99. }
  100. @classmethod
  101. def create(cls, provider_name, config_data):
  102. class_type = cls.provider_to_class.get(provider_name)
  103. config_class_type = cls.provider_to_config_class.get(provider_name)
  104. if class_type:
  105. embedder_class = load_class(class_type)
  106. embedder_config_class = load_class(config_class_type)
  107. return embedder_class(config=embedder_config_class(**config_data))
  108. else:
  109. raise ValueError(f"Unsupported Embedder provider: {provider_name}")