
fix: llama2 - use config with specific defaults (#594)

cachho, 1 year ago
parent commit 2cb47938fd
1 file changed, 19 additions and 2 deletions
  1. embedchain/llm/llama2.py (+19, -2)

embedchain/llm/llama2.py (+19, -2)

@@ -13,6 +13,19 @@ class Llama2Llm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         if "REPLICATE_API_TOKEN" not in os.environ:
             raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")
+
+        # Set default config values specific to this llm
+        if not config:
+            config = BaseLlmConfig()
+            # Add variables to this block that have a default value in the parent class
+            config.max_tokens = 500
+            config.temperature = 0.75
+        # Add variables that are `None` by default to this block.
+        if not config.model:
+            config.model = (
+                "a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5"
+            )
+
         super().__init__(config=config)
 
     def get_llm_model_answer(self, prompt):
@@ -20,7 +33,11 @@ class Llama2Llm(BaseLlm):
         if self.config.system_prompt:
             raise ValueError("Llama2App does not support `system_prompt`")
         llm = Replicate(
-            model="a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5",
-            input={"temperature": self.config.temperature or 0.75, "max_length": 500, "top_p": self.config.top_p},
+            model=self.config.model,
+            input={
+                "temperature": self.config.temperature,
+                "max_length": self.config.max_tokens,
+                "top_p": self.config.top_p,
+            },
         )
         return llm(prompt)
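
With this patch, a bare `Llama2Llm()` picks up LLaMA-2-specific defaults (max_tokens=500, temperature=0.75, the a16z-infra llama13b-v2-chat model) instead of the values previously hard-coded inside `get_llm_model_answer`, while a caller-supplied config is only back-filled with the default model when `config.model` is None. A minimal usage sketch of both paths, assuming `REPLICATE_API_TOKEN` is set, that `BaseLlmConfig` is importable from `embedchain.config`, and mirroring the attribute-setting pattern used in the patch (the token placeholder and alternative model slug below are purely illustrative):

    import os

    from embedchain.config import BaseLlmConfig  # import path assumed from the library layout
    from embedchain.llm.llama2 import Llama2Llm

    # __init__ raises if the token is missing; the value here is a placeholder.
    os.environ.setdefault("REPLICATE_API_TOKEN", "r8_...")

    # No-config path: the constructor fills in max_tokens=500, temperature=0.75
    # and the a16z-infra llama13b-v2-chat model automatically.
    llm = Llama2Llm()

    # Explicit-config path: only `model` is back-filled when it is None, so
    # temperature/max_tokens come from BaseLlmConfig's own defaults or the caller,
    # not from the 500/0.75 used in the no-config path.
    config = BaseLlmConfig()
    config.temperature = 0.2
    config.model = "replicate-user/some-other-llama-model:versionhash"  # hypothetical slug
    custom_llm = Llama2Llm(config=config)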