Browse Source

[Bug fix] Fix Azure OpenAI related issue (#973)

Deshraj Yadav 1 year ago
parent
commit
33409140b4
4 changed files with 4 additions and 5 deletions
  1. 0 1
      configs/chroma.yaml
  2. 2 2
      embedchain/embedder/openai.py
  3. 1 1
      pyproject.toml
  4. 1 1
      tests/apps/test_apps.py

+ 0 - 1
configs/chroma.yaml

@@ -23,4 +23,3 @@ embedder:
   provider: openai
   config:
     model: 'text-embedding-ada-002'
-    deployment_name: 'test-deployment'

+ 2 - 2
embedchain/embedder/openai.py

@@ -2,7 +2,7 @@ import os
 from typing import Optional
 
 from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
-from langchain.embeddings import OpenAIEmbeddings
+from langchain.embeddings import AzureOpenAIEmbeddings
 
 from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder
@@ -17,7 +17,7 @@ class OpenAIEmbedder(BaseEmbedder):
             self.config.model = "text-embedding-ada-002"
 
         if self.config.deployment_name:
-            embeddings = OpenAIEmbeddings(deployment=self.config.deployment_name)
+            embeddings = AzureOpenAIEmbeddings(deployment=self.config.deployment_name)
             embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
         else:
             if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None:

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.19"
+version = "0.1.20"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",

+ 1 - 1
tests/apps/test_apps.py

@@ -119,7 +119,7 @@ class TestAppFromConfig:
         # Validate the Embedder config values
         embedder_config = config_data["embedder"]["config"]
         assert app.embedder.config.model == embedder_config["model"]
-        assert app.embedder.config.deployment_name == embedder_config["deployment_name"]
+        assert app.embedder.config.deployment_name == embedder_config.get("deployment_name")
 
     def test_from_opensource_config(self, mocker):
         mocker.patch("embedchain.vectordb.chroma.chromadb.Client")