123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- import importlib
def load_class(class_type):
    """Import and return the attribute named by a dotted path.

    Args:
        class_type: Dotted path of the form ``"package.module.Name"``. The
            text after the last dot is the attribute name; everything
            before it is the module to import.

    Returns:
        The object found under that name in the imported module.
    """
    # Split only on the final dot so nested packages stay in the module path.
    dotted_module, attr_name = class_type.rsplit(".", maxsplit=1)
    return getattr(importlib.import_module(dotted_module), attr_name)
class LlmFactory:
    """Create LLM instances from a provider name and a config mapping."""

    # Provider name -> dotted path of the LLM implementation class.
    provider_to_class = {
        "anthropic": "embedchain.llm.anthropic.AnthropicLlm",
        "azure_openai": "embedchain.llm.azure_openai.AzureOpenAILlm",
        "cohere": "embedchain.llm.cohere.CohereLlm",
        "gpt4all": "embedchain.llm.gpt4all.GPT4ALLLlm",
        "huggingface": "embedchain.llm.huggingface.HuggingFaceLlm",
        "jina": "embedchain.llm.jina.JinaLlm",
        "llama2": "embedchain.llm.llama2.Llama2Llm",
        "openai": "embedchain.llm.openai.OpenAILlm",
        "vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
    }

    # Provider name -> dotted path of the config class handed to the LLM.
    # Providers missing here use the "embedchain" entry.
    provider_to_config_class = {
        "embedchain": "embedchain.config.llm.base.BaseLlmConfig",
        "openai": "embedchain.config.llm.base.BaseLlmConfig",
        "anthropic": "embedchain.config.llm.base.BaseLlmConfig",
    }

    @classmethod
    def create(cls, provider_name, config_data):
        """Instantiate the LLM registered for ``provider_name``.

        Args:
            provider_name: Key into ``provider_to_class``.
            config_data: Mapping expanded as keyword arguments into the
                provider's config class.

        Returns:
            The LLM class called with ``config=<config instance>``.

        Raises:
            ValueError: If ``provider_name`` has no entry in
                ``provider_to_class``.
        """
        llm_path = cls.provider_to_class.get(provider_name)
        if not llm_path:
            raise ValueError(f"Unsupported Llm provider: {provider_name}")

        # Use the provider's own config entry when it has one; otherwise
        # fall back to the embedchain base config.
        if provider_name in cls.provider_to_config_class:
            config_key = provider_name
        else:
            config_key = "embedchain"
        config_path = cls.provider_to_config_class.get(config_key)

        llm_cls = load_class(llm_path)
        config_cls = load_class(config_path)
        llm_config = config_cls(**config_data)
        return llm_cls(config=llm_config)
class EmbedderFactory:
    """Create embedder instances from a provider name and a config mapping."""

    # Provider name -> dotted path of the embedder implementation class.
    provider_to_class = {
        "azure_openai": "embedchain.embedder.openai.OpenAIEmbedder",
        "gpt4all": "embedchain.embedder.gpt4all.GPT4AllEmbedder",
        "huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder",
        "openai": "embedchain.embedder.openai.OpenAIEmbedder",
        "vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
    }

    # Provider name -> dotted path of the config class handed to the embedder.
    # Providers missing here use the "openai" entry.
    provider_to_config_class = {
        "azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
        "openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
        "gpt4all": "embedchain.config.embedder.base.BaseEmbedderConfig",
    }

    @classmethod
    def create(cls, provider_name, config_data):
        """Instantiate the embedder registered for ``provider_name``.

        Args:
            provider_name: Key into ``provider_to_class``.
            config_data: Mapping expanded as keyword arguments into the
                provider's config class.

        Returns:
            The embedder class called with ``config=<config instance>``.

        Raises:
            ValueError: If ``provider_name`` has no entry in
                ``provider_to_class``.
        """
        embedder_path = cls.provider_to_class.get(provider_name)
        if not embedder_path:
            raise ValueError(f"Unsupported Embedder provider: {provider_name}")

        # Use the provider's own config entry when it has one; otherwise
        # fall back to the openai config.
        if provider_name in cls.provider_to_config_class:
            config_key = provider_name
        else:
            config_key = "openai"
        config_path = cls.provider_to_config_class.get(config_key)

        embedder_cls = load_class(embedder_path)
        config_cls = load_class(config_path)
        embedder_config = config_cls(**config_data)
        return embedder_cls(config=embedder_config)
class VectorDBFactory:
    """Create vector database instances from a provider name and config mapping."""

    # Provider name -> dotted path of the vector database implementation class.
    provider_to_class = {
        "chroma": "embedchain.vectordb.chroma.ChromaDB",
        "elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
        "opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
        "pinecone": "embedchain.vectordb.pinecone.PineconeDB",
        "qdrant": "embedchain.vectordb.qdrant.QdrantDB",
        "weaviate": "embedchain.vectordb.weaviate.WeaviateDB",
        "zilliz": "embedchain.vectordb.zilliz.ZillizVectorDB",
    }

    # Provider name -> dotted path of the config class handed to the database.
    # Unlike a fallback scheme, this map is looked up directly by provider
    # name, so it needs an entry for every key in ``provider_to_class``.
    provider_to_config_class = {
        "chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
        "elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
        "opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
        "pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
        "qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig",
        "weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig",
        "zilliz": "embedchain.config.vectordb.zilliz.ZillizDBConfig",
    }

    @classmethod
    def create(cls, provider_name, config_data):
        """Instantiate the vector database registered for ``provider_name``.

        Args:
            provider_name: Key into ``provider_to_class`` and
                ``provider_to_config_class``.
            config_data: Mapping expanded as keyword arguments into the
                provider's config class.

        Returns:
            The vector database class called with ``config=<config instance>``.

        Raises:
            ValueError: If ``provider_name`` has no entry in
                ``provider_to_class``.
        """
        class_type = cls.provider_to_class.get(provider_name)
        if not class_type:
            # Name the right component: the message previously said
            # "Embedder", a copy-paste slip that misdirected debugging.
            raise ValueError(f"Unsupported VectorDB provider: {provider_name}")

        config_class_type = cls.provider_to_config_class.get(provider_name)
        vectordb_class = load_class(class_type)
        vectordb_config_class = load_class(config_class_type)
        return vectordb_class(config=vectordb_config_class(**config_data))
|