# factory.py
  1. import importlib
  2. def load_class(class_type):
  3. module_path, class_name = class_type.rsplit(".", 1)
  4. module = importlib.import_module(module_path)
  5. return getattr(module, class_name)
  6. class LlmFactory:
  7. provider_to_class = {
  8. "anthropic": "embedchain.llm.anthropic.AnthropicLlm",
  9. "azure_openai": "embedchain.llm.azure_openai.AzureOpenAILlm",
  10. "cohere": "embedchain.llm.cohere.CohereLlm",
  11. "gpt4all": "embedchain.llm.gpt4all.GPT4ALLLlm",
  12. "huggingface": "embedchain.llm.huggingface.HuggingFaceLlm",
  13. "jina": "embedchain.llm.jina.JinaLlm",
  14. "llama2": "embedchain.llm.llama2.Llama2Llm",
  15. "openai": "embedchain.llm.openai.OpenAILlm",
  16. "vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
  17. }
  18. provider_to_config_class = {
  19. "embedchain": "embedchain.config.llm.base.BaseLlmConfig",
  20. "openai": "embedchain.config.llm.base.BaseLlmConfig",
  21. "anthropic": "embedchain.config.llm.base.BaseLlmConfig",
  22. }
  23. @classmethod
  24. def create(cls, provider_name, config_data):
  25. class_type = cls.provider_to_class.get(provider_name)
  26. # Default to embedchain base config if the provider is not in the config map
  27. config_name = "embedchain" if provider_name not in cls.provider_to_config_class else provider_name
  28. config_class_type = cls.provider_to_config_class.get(config_name)
  29. if class_type:
  30. llm_class = load_class(class_type)
  31. llm_config_class = load_class(config_class_type)
  32. return llm_class(config=llm_config_class(**config_data))
  33. else:
  34. raise ValueError(f"Unsupported Llm provider: {provider_name}")
  35. class EmbedderFactory:
  36. provider_to_class = {
  37. "gpt4all": "embedchain.embedder.gpt4all.GPT4AllEmbedder",
  38. "huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder",
  39. "vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
  40. "openai": "embedchain.embedder.openai.OpenAIEmbedder",
  41. }
  42. provider_to_config_class = {
  43. "openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
  44. }
  45. @classmethod
  46. def create(cls, provider_name, config_data):
  47. class_type = cls.provider_to_class.get(provider_name)
  48. # Default to openai config if the provider is not in the config map
  49. config_name = "openai" if provider_name not in cls.provider_to_config_class else provider_name
  50. config_class_type = cls.provider_to_config_class.get(config_name)
  51. if class_type:
  52. embedder_class = load_class(class_type)
  53. embedder_config_class = load_class(config_class_type)
  54. return embedder_class(config=embedder_config_class(**config_data))
  55. else:
  56. raise ValueError(f"Unsupported Embedder provider: {provider_name}")
  57. class VectorDBFactory:
  58. provider_to_class = {
  59. "chroma": "embedchain.vectordb.chroma.ChromaDB",
  60. "elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
  61. "opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
  62. }
  63. provider_to_config_class = {
  64. "chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
  65. "elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
  66. "opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
  67. }
  68. @classmethod
  69. def create(cls, provider_name, config_data):
  70. class_type = cls.provider_to_class.get(provider_name)
  71. config_class_type = cls.provider_to_config_class.get(provider_name)
  72. if class_type:
  73. embedder_class = load_class(class_type)
  74. embedder_config_class = load_class(config_class_type)
  75. return embedder_class(config=embedder_config_class(**config_data))
  76. else:
  77. raise ValueError(f"Unsupported Embedder provider: {provider_name}")