Sfoglia il codice sorgente

Refactor openai embedder (#1444)

Vatsal Rathod 1 anno fa
parent
commit
edaeb78ccf
3 file modificati con 32 aggiunte e 16 eliminazioni
  1. 22 0
      embedchain/embedder/azure_openai.py
  2. 9 15
      embedchain/embedder/openai.py
  3. 1 1
      embedchain/factory.py

+ 22 - 0
embedchain/embedder/azure_openai.py

@@ -0,0 +1,22 @@
+from typing import Optional
+
+from langchain_community.embeddings import AzureOpenAIEmbeddings
+
+from embedchain.config import BaseEmbedderConfig
+from embedchain.embedder.base import BaseEmbedder
+from embedchain.models import VectorDimensions
+
+
+class AzureOpenAIEmbedder(BaseEmbedder):
+    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
+        super().__init__(config=config)
+
+        if self.config.model is None:
+            self.config.model = "text-embedding-ada-002"
+
+        embeddings = AzureOpenAIEmbeddings(deployment=self.config.deployment_name)
+        embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
+
+        self.set_embedding_fn(embedding_fn=embedding_fn)
+        vector_dimension = self.config.vector_dimension or VectorDimensions.OPENAI.value
+        self.set_vector_dimension(vector_dimension=vector_dimension)

+ 9 - 15
embedchain/embedder/openai.py

@@ -2,7 +2,7 @@ import os
 from typing import Optional
 
 from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
-from langchain_openai.embeddings import AzureOpenAIEmbeddings
+
 
 from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder
@@ -19,20 +19,14 @@ class OpenAIEmbedder(BaseEmbedder):
         api_key = self.config.api_key or os.environ["OPENAI_API_KEY"]
         api_base = self.config.api_base or os.environ.get("OPENAI_API_BASE")
 
-        if self.config.deployment_name:
-            embeddings = AzureOpenAIEmbeddings(deployment=self.config.deployment_name)
-            embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
-        else:
-            if api_key is None and os.getenv("OPENAI_ORGANIZATION") is None:
-                raise ValueError(
-                    "OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided"
-                )  # noqa:E501
-            embedding_fn = OpenAIEmbeddingFunction(
-                api_key=api_key,
-                api_base=api_base,
-                organization_id=os.getenv("OPENAI_ORGANIZATION"),
-                model_name=self.config.model,
-            )
+        if api_key is None and os.getenv("OPENAI_ORGANIZATION") is None:
+            raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided")  # noqa:E501
+        embedding_fn = OpenAIEmbeddingFunction(
+            api_key=api_key,
+            api_base=api_base,
+            organization_id=os.getenv("OPENAI_ORGANIZATION"),
+            model_name=self.config.model,
+        )
         self.set_embedding_fn(embedding_fn=embedding_fn)
         vector_dimension = self.config.vector_dimension or VectorDimensions.OPENAI.value
         self.set_vector_dimension(vector_dimension=vector_dimension)

+ 1 - 1
embedchain/factory.py

@@ -50,7 +50,7 @@ class LlmFactory:
 
 class EmbedderFactory:
     provider_to_class = {
-        "azure_openai": "embedchain.embedder.openai.OpenAIEmbedder",
+        "azure_openai": "embedchain.embedder.azure_openai.AzureOpenAIEmbedder",
         "gpt4all": "embedchain.embedder.gpt4all.GPT4AllEmbedder",
         "huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder",
         "openai": "embedchain.embedder.openai.OpenAIEmbedder",