
Download Ollama model if not present (#1397)

Dev Khant, 1 year ago
parent commit 73e53aaff1
3 changed files with 13 additions and 1 deletion
  1. Makefile (+1 -1)
  2. embedchain/llm/ollama.py (+10 -0)
  3. tests/llm/test_ollama.py (+2 -0)

Makefile (+1 -1)

@@ -11,7 +11,7 @@ install:
 
 install_all:
 	poetry install --all-extras
-	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]"
+	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama
 
 install_es:
 	poetry install --extras elasticsearch

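The new `ollama` dependency added above is the official Ollama Python client, which provides the `Client.list` and `Client.pull` calls used in the next file. A minimal smoke test of the install (a sketch, assuming `pip install ollama` has run and an Ollama server is listening on the default `http://localhost:11434`):

    from ollama import Client

    # Assumes `ollama serve` is running locally on the default port.
    client = Client(host="http://localhost:11434")

    # At the time of this commit, list() returned a plain dict like
    # {"models": [{"name": "llama2:latest", ...}, ...]}.
    print([m.get("name") for m in client.list()["models"]])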
embedchain/llm/ollama.py (+10 -0)

@@ -1,3 +1,4 @@
+import logging
 from collections.abc import Iterable
 from typing import Optional, Union
 
@@ -5,11 +6,14 @@ from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.stdout import StdOutCallbackHandler
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain_community.llms.ollama import Ollama
+from ollama import Client
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class OllamaLlm(BaseLlm):
@@ -18,6 +22,12 @@ class OllamaLlm(BaseLlm):
         if self.config.model is None:
             self.config.model = "llama2"
 
+        client = Client(host=config.base_url)
+        local_models = client.list()["models"]
+        if not any(model.get("name") == self.config.model for model in local_models):
+            logger.info(f"Pulling {self.config.model} from Ollama!")
+            client.pull(self.config.model)
+
     def get_llm_model_answer(self, prompt):
         return self._get_answer(prompt=prompt, config=self.config)
 

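With this change, simply constructing `OllamaLlm` pulls the configured model when it is missing from the local model list. A usage sketch, not part of the commit: it assumes `BaseLlmConfig` accepts `model` and `base_url` (the constructor reads `config.base_url` directly, so a config with `base_url` set is assumed) and that an Ollama server is reachable there:

    from embedchain.config import BaseLlmConfig
    from embedchain.llm.ollama import OllamaLlm

    # base_url must be set: the constructor passes config.base_url to Client.
    config = BaseLlmConfig(model="llama2", base_url="http://localhost:11434")

    # If "llama2" is not in the local model list, it is pulled here.
    llm = OllamaLlm(config)
    print(llm.get_llm_model_answer("Why is the sky blue?"))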
tests/llm/test_ollama.py (+2 -0)

@@ -11,6 +11,7 @@ def ollama_llm_config():
 
 
 def test_get_llm_model_answer(ollama_llm_config, mocker):
+    mocker.patch("embedchain.llm.ollama.Client.list", return_value={"models": [{"name": "llama2"}]})
     mocker.patch("embedchain.llm.ollama.OllamaLlm._get_answer", return_value="Test answer")
 
     llm = OllamaLlm(ollama_llm_config)
@@ -20,6 +21,7 @@ def test_get_llm_model_answer(ollama_llm_config, mocker):
 
 
 def test_get_answer_mocked_ollama(ollama_llm_config, mocker):
+    mocker.patch("embedchain.llm.ollama.Client.list", return_value={"models": [{"name": "llama2"}]})
     mocked_ollama = mocker.patch("embedchain.llm.ollama.Ollama")
     mock_instance = mocked_ollama.return_value
     mock_instance.invoke.return_value = "Mocked answer"
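
In both tests, the patched `Client.list` already reports the configured model as present, so the new pull branch is never exercised. A hypothetical companion test (not in the commit, assuming the same pytest-mock setup and that the `ollama_llm_config` fixture configures `model="llama2"`) could stub `Client.pull` and assert the branch fires when the model is absent:

    def test_pulls_missing_model(ollama_llm_config, mocker):
        # An empty model list forces the "not present" branch in __init__.
        mocker.patch("embedchain.llm.ollama.Client.list", return_value={"models": []})
        mocked_pull = mocker.patch("embedchain.llm.ollama.Client.pull")

        OllamaLlm(ollama_llm_config)

        # The constructor should have pulled the configured model.
        mocked_pull.assert_called_once_with("llama2")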