Ver Fonte

Use either `embedder` or `embedding_model` as YAML key (#905)

Sidharth Mohanty há 1 ano atrás
pai
commit
9c4b023297

+ 3 - 3
embedchain/apps/app.py

@@ -130,7 +130,7 @@ class App(EmbedChain):
         app_config_data = config_data.get("app", {})
         llm_config_data = config_data.get("llm", {})
         db_config_data = config_data.get("vectordb", {})
-        embedder_config_data = config_data.get("embedder", {})
+        embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
 
         app_config = AppConfig(**app_config_data.get("config", {}))
 
@@ -140,6 +140,6 @@ class App(EmbedChain):
         db_provider = db_config_data.get("provider", "chroma")
         db = VectorDBFactory.create(db_provider, db_config_data.get("config", {}))
 
-        embedder_provider = embedder_config_data.get("provider", "openai")
-        embedder = EmbedderFactory.create(embedder_provider, embedder_config_data.get("config", {}))
+        embedder_provider = embedding_model_config_data.get("provider", "openai")
+        embedder = EmbedderFactory.create(embedder_provider, embedding_model_config_data.get("config", {}))
         return cls(config=app_config, llm=llm, db=db, embedder=embedder)

+ 1 - 1
embedchain/pipeline.py

@@ -359,7 +359,7 @@ class Pipeline(EmbedChain):
 
         pipeline_config_data = config_data.get("app", {}).get("config", {})
         db_config_data = config_data.get("vectordb", {})
-        embedding_model_config_data = config_data.get("embedding_model", {})
+        embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
         llm_config_data = config_data.get("llm", {})
 
         pipeline_config = PipelineConfig(**pipeline_config_data)

+ 2 - 2
examples/rest-api/requirements.txt

@@ -1,6 +1,6 @@
 fastapi==0.104.0
 uvicorn==0.23.2
-embedchain==0.0.90
-embedchain[streamlit, community, opensource, elasticsearch, opensearch, poe, discord, slack, whatsapp, weaviate, pinecone, qdrant, images, huggingface_hub, cohere, milvus, dataloaders, vertexai, llama2, gmail, json]==0.0.90
+embedchain==0.0.91
+embedchain[streamlit, community, opensource, elasticsearch, opensearch, poe, discord, slack, whatsapp, weaviate, pinecone, qdrant, images, huggingface_hub, cohere, milvus, dataloaders, vertexai, llama2, gmail, json]==0.0.91
 sqlalchemy==2.0.22
 python-multipart==0.0.6

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.0.90"
+version = "0.0.91"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",