Ver Fonte

chore: linting (#229)

cachho há 2 anos atrás
pai
commit
1c4f629fc0

+ 32 - 10
embedchain/config/ChatConfig.py

@@ -23,26 +23,48 @@ class ChatConfig(QueryConfig):
     """
     Config for the `chat` method, inherits from `QueryConfig`.
     """
-    def __init__(self, template: Template = None, model = None, temperature = None, max_tokens = None, top_p = None, stream: bool = False):
+
+    def __init__(
+        self,
+        template: Template = None,
+        model=None,
+        temperature=None,
+        max_tokens=None,
+        top_p=None,
+        stream: bool = False,
+    ):
         """
         Initializes the ChatConfig instance.
 
-        :param template: Optional. The `Template` instance to use as a template for prompt.
+        :param template: Optional. The `Template` instance to use as a template for
+        prompt.
         :param model: Optional. Controls the OpenAI model used.
-        :param temperature: Optional. Controls the randomness of the model's output. 
-                            Higher values (closer to 1) make output more random, lower values make it more deterministic.
+        :param temperature: Optional. Controls the randomness of the model's output.
+        Higher values (closer to 1) make output more random,lower values make it more
+        deterministic.
         :param max_tokens: Optional. Controls how many tokens are generated.
-        :param top_p: Optional. Controls the diversity of words. Higher values (closer to 1) make word selection more diverse, lower values make words less diverse.
+        :param top_p: Optional. Controls the diversity of words.Higher values
+        (closer to 1) make word selection more diverse, lower values make words less
+        diverse.
         :param stream: Optional. Control if response is streamed back to the user
-        :raises ValueError: If the template is not valid as template should contain $context and $query and $history
+        :raises ValueError: If the template is not valid as template should contain
+        $context and $query and $history
         """
         if template is None:
             template = DEFAULT_PROMPT_TEMPLATE
 
-
-        # History is set as 0 to ensure that there is always a history, that way, there don't have to be two templates.
-        # Having two templates would make it complicated because the history is not user controlled.
-        super().__init__(template, model=model, temperature=temperature, max_tokens=max_tokens, top_p=top_p, history=[0], stream=stream)
+        # History is set as 0 to ensure that there is always a history, that way,
+        # there don't have to be two templates. Having two templates would make it
+        # complicated because the history is not user controlled.
+        super().__init__(
+            template,
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            top_p=top_p,
+            history=[0],
+            stream=stream,
+        )
 
     def set_history(self, history):
         """

+ 19 - 6
embedchain/config/QueryConfig.py

@@ -34,21 +34,35 @@ query_re = re.compile(r"\$\{*query\}*")
 context_re = re.compile(r"\$\{*context\}*")
 history_re = re.compile(r"\$\{*history\}*")
 
+
 class QueryConfig(BaseConfig):
     """
     Config for the `query` method.
     """
 
-    def __init__(self, template: Template = None, model = None, temperature = None, max_tokens = None, top_p = None, history = None, stream: bool = False):
+    def __init__(
+        self,
+        template: Template = None,
+        model=None,
+        temperature=None,
+        max_tokens=None,
+        top_p=None,
+        history=None,
+        stream: bool = False,
+    ):
         """
         Initializes the QueryConfig instance.
 
-        :param template: Optional. The `Template` instance to use as a template for prompt.
+        :param template: Optional. The `Template` instance to use as a template for
+        prompt.
         :param model: Optional. Controls the OpenAI model used.
-        :param temperature: Optional. Controls the randomness of the model's output. 
-                            Higher values (closer to 1) make output more random, lower values make it more deterministic.
+        :param temperature: Optional. Controls the randomness of the model's output.
+        Higher values (closer to 1) make output more random, lower values make it more
+        deterministic.
         :param max_tokens: Optional. Controls how many tokens are generated.
-        :param top_p: Optional. Controls the diversity of words. Higher values (closer to 1) make word selection more diverse, lower values make words less diverse.
+        :param top_p: Optional. Controls the diversity of words. Higher values
+        (closer to 1) make word selection more diverse, lower values make words less
+        diverse.
         :param history: Optional. A list of strings to consider as history.
         :param stream: Optional. Control if response is streamed back to user
         :raises ValueError: If the template is not valid as template should
@@ -68,7 +82,6 @@ class QueryConfig(BaseConfig):
             else:
                 template = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE
 
-      
         self.temperature = temperature if temperature else 0
         self.max_tokens = max_tokens if max_tokens else 1000
         self.model = model if model else "gpt-3.5-turbo-0613"

+ 9 - 5
embedchain/embedchain.py

@@ -114,7 +114,11 @@ class EmbedChain:
         chunks_before_addition = self.count()
         self.collection.add(documents=documents, metadatas=list(metadatas), ids=ids)
         print(
-            f"Successfully saved {src}. New chunks count: {self.count() - chunks_before_addition}")  # noqa:E501
+            (
+                f"Successfully saved {src}. New chunks count: "
+                f"{self.count() - chunks_before_addition}"
+            )
+        )
 
     def _format_result(self, results):
         return [
@@ -310,12 +314,12 @@ class App(EmbedChain):
         messages = []
         messages.append({"role": "user", "content": prompt})
         response = openai.ChatCompletion.create(
-            model = config.model,
+            model=config.model,
             messages=messages,
-            temperature = config.temperature,
-            max_tokens = config.max_tokens,
+            temperature=config.temperature,
+            max_tokens=config.max_tokens,
             top_p=config.top_p,
-            stream=config.stream
+            stream=config.stream,
         )
 
         if config.stream: