Quellcode durchsuchen

Add Ollama as a supported embedding provider (#1344)

Colin O'Brien vor 1 Jahr
Ursprung
Commit
a795798156

+ 3 - 2
configs/ollama.yaml

@@ -8,6 +8,7 @@ llm:
     base_url: http://localhost:11434
 
 embedder:
-  provider: huggingface
+  provider: ollama
   config:
-    model: 'BAAI/bge-small-en-v1.5'
+    model: 'mxbai-embed-large:latest'
+    base_url: http://localhost:11434

+ 1 - 0
embedchain/config/__init__.py

@@ -6,6 +6,7 @@ from .base_config import BaseConfig
 from .cache_config import CacheConfig
 from .embedder.base import BaseEmbedderConfig
 from .embedder.base import BaseEmbedderConfig as EmbedderConfig
+from .embedder.ollama import OllamaEmbedderConfig
 from .llm.base import BaseLlmConfig
 from .vectordb.chroma import ChromaDbConfig
 from .vectordb.elasticsearch import ElasticsearchDBConfig

+ 30 - 0
embedchain/config/embedder/ollama.py

@@ -0,0 +1,30 @@
+from typing import Optional
+
+from embedchain.config.embedder.base import BaseEmbedderConfig
+from embedchain.helpers.json_serializable import register_deserializable
+
+
+@register_deserializable
+class OllamaEmbedderConfig(BaseEmbedderConfig):
+    """Configuration for the Ollama embedder.
+
+    Args:
+        model: Name of the Ollama embedding model; forwarded to the base config.
+        base_url: URL of the Ollama server. Falls back to
+            ``http://127.0.0.1:11434`` when omitted or empty.
+        vector_dimension: Optional embedding size override; forwarded to the
+            base config.
+    """
+
+    def __init__(
+        self,
+        model: Optional[str] = None,
+        base_url: Optional[str] = None,
+        vector_dimension: Optional[int] = None,
+    ):
+        # validate_config allows `vector_dimension` under the embedder config,
+        # so accept it here to let such configs be passed through as keywords.
+        # NOTE(review): assumes BaseEmbedderConfig.__init__ takes a
+        # `vector_dimension` keyword - confirm against config/embedder/base.py.
+        super().__init__(model=model, vector_dimension=vector_dimension)
+        self.base_url = base_url or "http://127.0.0.1:11434"

BIN
embedchain/embedder/.ollama.py.swp


+ 40 - 0
embedchain/embedder/ollama.py

@@ -0,0 +1,40 @@
+from typing import Optional
+
+from langchain_community.embeddings import OllamaEmbeddings
+
+from embedchain.config import OllamaEmbedderConfig
+from embedchain.embedder.base import BaseEmbedder
+from embedchain.models import VectorDimensions
+
+
+class OllamaEmbedder(BaseEmbedder):
+    """Embedder backed by LangChain's ``OllamaEmbeddings`` client."""
+
+    def __init__(self, config: Optional[OllamaEmbedderConfig] = None):
+        """Wire up the embedding function and the vector dimension.
+
+        Args:
+            config: Ollama embedder settings. A default
+                ``OllamaEmbedderConfig`` is created when omitted so that
+                ``base_url`` is always available.
+        """
+        # Without this, a missing config leaves `self.config.base_url` to
+        # whatever the base class substitutes, which need not define it.
+        if config is None:
+            config = OllamaEmbedderConfig()
+        super().__init__(config=config)
+
+        # Pass `model` only when set so a bare default config does not hand
+        # None to the client. NOTE(review): assumes OllamaEmbeddings has its
+        # own default model name - confirm against the LangChain docs.
+        client_kwargs = {"base_url": self.config.base_url}
+        if self.config.model is not None:
+            client_kwargs["model"] = self.config.model
+        embeddings = OllamaEmbeddings(**client_kwargs)
+        embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
+        self.set_embedding_fn(embedding_fn=embedding_fn)
+
+        # An explicit config value wins; otherwise fall back to the single
+        # provider-wide default, which is not derived from the chosen model.
+        vector_dimension = self.config.vector_dimension or VectorDimensions.OLLAMA.value
+        self.set_vector_dimension(vector_dimension=vector_dimension)

+ 2 - 0
embedchain/factory.py

@@ -58,6 +58,7 @@ class EmbedderFactory:
         "mistralai": "embedchain.embedder.mistralai.MistralAIEmbedder",
         "nvidia": "embedchain.embedder.nvidia.NvidiaEmbedder",
         "cohere": "embedchain.embedder.cohere.CohereEmbedder",
+        "ollama": "embedchain.embedder.ollama.OllamaEmbedder",
     }
     provider_to_config_class = {
         "azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
@@ -65,6 +66,7 @@ class EmbedderFactory:
         "gpt4all": "embedchain.config.embedder.base.BaseEmbedderConfig",
         "huggingface": "embedchain.config.embedder.base.BaseEmbedderConfig",
         "openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
+        "ollama": "embedchain.config.embedder.ollama.OllamaEmbedderConfig",
     }
 
     @classmethod

+ 1 - 0
embedchain/models/embedding_functions.py

@@ -6,3 +6,4 @@ class EmbeddingFunctions(Enum):
     HUGGING_FACE = "HUGGING_FACE"
     VERTEX_AI = "VERTEX_AI"
     GPT4ALL = "GPT4ALL"
+    OLLAMA = "OLLAMA"

+ 1 - 0
embedchain/models/vector_dimensions.py

@@ -11,3 +11,4 @@ class VectorDimensions(Enum):
     MISTRAL_AI = 1024
     NVIDIA_AI = 1024
     COHERE = 384
+    OLLAMA = 384

+ 4 - 0
embedchain/utils/misc.py

@@ -449,6 +449,7 @@ def validate_config(config_data):
                     "google",
                     "mistralai",
                     "nvidia",
+                    "ollama",
                 ),
                 Optional("config"): {
                     Optional("model"): Optional(str),
@@ -458,6 +459,7 @@ def validate_config(config_data):
                     Optional("title"): str,
                     Optional("task_type"): str,
                     Optional("vector_dimension"): int,
+                    Optional("base_url"): str,
                 },
             },
             Optional("embedding_model"): {
@@ -470,6 +472,7 @@ def validate_config(config_data):
                     "google",
                     "mistralai",
                     "nvidia",
+                    "ollama",
                 ),
                 Optional("config"): {
                     Optional("model"): str,
@@ -478,6 +481,7 @@ def validate_config(config_data):
                     Optional("title"): str,
                     Optional("task_type"): str,
                     Optional("vector_dimension"): int,
+                    Optional("base_url"): str,
                 },
             },
             Optional("chunker"): {