Forráskód Böngészése

[Bugfix] fix config validation for google llm config (#1088)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 éve
szülő
commit
a304ded500
2 módosított fájl, 9 hozzáadás és 2 törlés
  1. 5 2
      embedchain/embedchain.py
  2. 4 0
      embedchain/utils.py

+ 5 - 2
embedchain/embedchain.py

@@ -7,7 +7,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
 
-from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
+from embedchain.cache import (adapt, get_gptcache_session,
+                              gptcache_data_convert,
+                              gptcache_update_cache_callback)
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.base_app_config import BaseAppConfig
@@ -17,7 +19,8 @@ from embedchain.embedder.base import BaseEmbedder
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
+from embedchain.models.data_type import (DataType, DirectDataType,
+                                         IndirectDataType, SpecialDataType)
 from embedchain.telemetry.posthog import AnonymousTelemetry
 from embedchain.utils import detect_datatype, is_valid_json_string
 from embedchain.vectordb.base import BaseVectorDB

+ 4 - 0
embedchain/utils.py

@@ -420,6 +420,8 @@ def validate_config(config_data):
                     Optional("model"): Optional(str),
                     Optional("deployment_name"): Optional(str),
                     Optional("api_key"): str,
+                    Optional("title"): str,
+                    Optional("task_type"): str,
                 },
             },
             Optional("embedding_model"): {
@@ -428,6 +430,8 @@ def validate_config(config_data):
                     Optional("model"): str,
                     Optional("deployment_name"): str,
                     Optional("api_key"): str,
+                    Optional("title"): str,
+                    Optional("task_type"): str,
                 },
             },
             Optional("chunker"): {