소스 검색

Show details for query tokens (#1392)

Dev Khant 1 년 전
부모
커밋
4880557d51

+ 1 - 1
Makefile

@@ -11,7 +11,7 @@ install:
 
 install_all:
 	poetry install --all-extras
-	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama deepgram-sdk==3.2.7 langchain-huggingface psutil
+	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama langchain_together==0.1.3 langchain_cohere==0.1.5 deepgram-sdk==3.2.7 langchain-huggingface psutil
 
 install_es:
 	poetry install --extras elasticsearch

+ 1 - 0
docs/api-reference/advanced/configuration.mdx

@@ -209,6 +209,7 @@ Alright, let's dive into what each key means in the yaml config above:
         - `top_p` (Float): Controls the diversity of word selection. A higher value (closer to 1) makes word selection more diverse.
         - `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
         - `online` (Boolean): Controls whether to use internet to get more context for answering query (set to false).
+        - `token_usage` (Boolean): Controls whether to use token usage for the querying models (set to false).
         - `prompt` (String): A prompt for the model to follow when generating responses, requires `$context` and `$query` variables.
         - `system_prompt` (String): A system prompt for the model to follow when generating responses, in this case, it's set to the style of William Shakespeare.
         - `number_documents` (Integer): Number of documents to pull from the vectordb as context, defaults to 1

+ 46 - 0
docs/components/llms.mdx

@@ -840,6 +840,52 @@ answer = app.query("What is the net worth of Elon Musk today?")
 ```
 </CodeGroup>
 
+## Token Usage
+
+You can get the cost of the query by setting `token_usage` to `True` in the config file. This will return the token details: `input_tokens`, `output_tokens`, `total_cost`.
+The list of paid LLMs that support token usage are:
+- OpenAI
+- Vertex AI
+- Anthropic
+- Cohere
+- Together
+- Groq
+- Mistral AI
+- NVIDIA AI
+
+Here is an example of how to use token usage:
+<CodeGroup>
+ 
+```python main.py
+os.environ["OPENAI_API_KEY"] = "xxx"
+
+app = App.from_config(config_path="config.yaml")
+
+app.add("https://www.forbes.com/profile/elon-musk")
+
+response, token_usage = app.query("what is the net worth of Elon Musk?")
+# Elon Musk's net worth is $209.9 billion as of 6/9/24.
+# {'input_tokens': 1228, 'output_tokens': 21, 'total_cost (USD)': 0.001884}
+
+
+response, token_usage = app.chat("Which companies did Elon Musk found?")
+# Elon Musk founded six companies, including Tesla, which is an electric car maker, SpaceX, a rocket producer, and the Boring Company, a tunneling startup.
+# {'input_tokens': 1616, 'output_tokens': 34, 'total_cost (USD)': 0.002492}
+```
+  
+```yaml config.yaml
+llm:
+  provider: openai
+  config:
+    model: gpt-3.5-turbo
+    temperature: 0.5
+    max_tokens: 1000
+    token_usage: true
+```
+</CodeGroup>
+
+If a model is missing and you'd like to add it to `model_prices_and_context_window.json`, please feel free to open a PR.
+
 <br/ >
 
 <Snippet file="missing-llm-tip.mdx" />

+ 10 - 0
embedchain/config/llm/base.py

@@ -1,3 +1,4 @@
+import json
 import logging
 import re
 from string import Template
@@ -92,6 +93,7 @@ class BaseLlmConfig(BaseConfig):
         top_p: float = 1,
         stream: bool = False,
         online: bool = False,
+        token_usage: bool = False,
         deployment_name: Optional[str] = None,
         system_prompt: Optional[str] = None,
         where: dict[str, Any] = None,
@@ -135,6 +137,8 @@ class BaseLlmConfig(BaseConfig):
         :type stream: bool, optional
         :param online: Controls whether to use internet for answering query, defaults to False
         :type online: bool, optional
+        :param token_usage: Controls whether to return token usage in response, defaults to False
+        :type token_usage: bool, optional
         :param deployment_name: t.b.a., defaults to None
         :type deployment_name: Optional[str], optional
         :param system_prompt: System prompt string, defaults to None
@@ -180,6 +184,8 @@ class BaseLlmConfig(BaseConfig):
         self.max_tokens = max_tokens
         self.model = model
         self.top_p = top_p
+        self.online = online
+        self.token_usage = token_usage
         self.deployment_name = deployment_name
         self.system_prompt = system_prompt
         self.query_type = query_type
@@ -197,6 +203,10 @@ class BaseLlmConfig(BaseConfig):
         self.online = online
         self.api_version = api_version
 
+        if token_usage:
+            f = open("model_prices_and_context_window.json")
+            self.model_pricing_map = json.load(f)
+
         if isinstance(prompt, str):
             prompt = Template(prompt)
 

+ 44 - 19
embedchain/embedchain.py

@@ -6,9 +6,7 @@ from typing import Any, Optional, Union
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
 
-from embedchain.cache import (adapt, get_gptcache_session,
-                              gptcache_data_convert,
-                              gptcache_update_cache_callback)
+from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.base_app_config import BaseAppConfig
@@ -18,8 +16,7 @@ from embedchain.embedder.base import BaseEmbedder
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import (DataType, DirectDataType,
-                                         IndirectDataType, SpecialDataType)
+from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
 from embedchain.utils.misc import detect_datatype, is_valid_json_string
 from embedchain.vectordb.base import BaseVectorDB
 
@@ -478,7 +475,7 @@ class EmbedChain(JSONSerializable):
         where: Optional[dict] = None,
         citations: bool = False,
         **kwargs: dict[str, Any],
-    ) -> Union[tuple[str, list[tuple[str, dict]]], str]:
+    ) -> Union[tuple[str, list[tuple[str, dict]]], str, dict[str, Any]]:
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -501,7 +498,9 @@ class EmbedChain(JSONSerializable):
         :type kwargs: dict[str, Any]
         :return: The answer to the query, with citations if the citation flag is True
         or the dry run result
-        :rtype: str, if citations is False, otherwise tuple[str, list[tuple[str,str,str]]]
+        :rtype: str, if citations is False and token_usage is False, otherwise if citations is true then
+        tuple[str, list[tuple[str,str,str]]] and if token_usage is true then
+        tuple[str, list[tuple[str,str,str]], dict[str, Any]]
         """
         contexts = self._retrieve_from_database(
             input_query=input_query, config=config, where=where, citations=citations, **kwargs
@@ -524,17 +523,29 @@ class EmbedChain(JSONSerializable):
                 dry_run=dry_run,
             )
         else:
-            answer = self.llm.query(
-                input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-            )
+            if self.llm.config.token_usage:
+                answer, token_info = self.llm.query(
+                    input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+                )
+            else:
+                answer = self.llm.query(
+                    input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+                )
 
         # Send anonymous telemetry
         self.telemetry.capture(event_name="query", properties=self._telemetry_props)
 
         if citations:
+            if self.llm.config.token_usage:
+                return {"answer": answer, "contexts": contexts, "usage": token_info}
             return answer, contexts
-        else:
-            return answer
+        if self.llm.config.token_usage:
+            return {"answer": answer, "usage": token_info}
+
+        logger.warning(
+            "Starting from v0.1.125 the return type of query method will be changed to tuple containing `answer`."
+        )
+        return answer
 
     def chat(
         self,
@@ -545,7 +556,7 @@ class EmbedChain(JSONSerializable):
         where: Optional[dict[str, str]] = None,
         citations: bool = False,
         **kwargs: dict[str, Any],
-    ) -> Union[tuple[str, list[tuple[str, dict]]], str]:
+    ) -> Union[tuple[str, list[tuple[str, dict]]], str, dict[str, Any]]:
         """
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -572,7 +583,9 @@ class EmbedChain(JSONSerializable):
         :type kwargs: dict[str, Any]
         :return: The answer to the query, with citations if the citation flag is True
         or the dry run result
-        :rtype: str, if citations is False, otherwise tuple[str, list[tuple[str,str,str]]]
+        :rtype: str, if citations is False and token_usage is False, otherwise if citations is true then
+        tuple[str, list[tuple[str,str,str]]] and if token_usage is true then
+        tuple[str, list[tuple[str,str,str]], dict[str, Any]]
         """
         contexts = self._retrieve_from_database(
             input_query=input_query, config=config, where=where, citations=citations, **kwargs
@@ -600,9 +613,14 @@ class EmbedChain(JSONSerializable):
             )
         else:
             logger.debug("Cache disabled. Running chat without cache.")
-            answer = self.llm.chat(
-                input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-            )
+            if self.llm.config.token_usage:
+                answer, token_info = self.llm.query(
+                    input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+                )
+            else:
+                answer = self.llm.query(
+                    input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+                )
 
         # add conversation in memory
         self.llm.add_history(self.config.id, input_query, answer, session_id=session_id)
@@ -611,9 +629,16 @@ class EmbedChain(JSONSerializable):
         self.telemetry.capture(event_name="chat", properties=self._telemetry_props)
 
         if citations:
+            if self.llm.config.token_usage:
+                return {"answer": answer, "contexts": contexts, "usage": token_info}
             return answer, contexts
-        else:
-            return answer
+        if self.llm.config.token_usage:
+            return {"answer": answer, "usage": token_info}
+
+        logger.warning(
+            "Starting from v0.1.125 the return type of query method will be changed to tuple containing `answer`."
+        )
+        return answer
 
     def search(self, query, num_documents=3, where=None, raw_filter=None, namespace=None):
         """

+ 3 - 3
embedchain/embedder/gpt4all.py

@@ -9,10 +9,10 @@ class GPT4AllEmbedder(BaseEmbedder):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         super().__init__(config=config)
 
-        from langchain.embeddings import \
-            GPT4AllEmbeddings as LangchainGPT4AllEmbeddings
+        from langchain_community.embeddings import GPT4AllEmbeddings as LangchainGPT4AllEmbeddings
 
-        embeddings = LangchainGPT4AllEmbeddings()
+        model_name = self.config.model or "all-MiniLM-L6-v2-f16.gguf"
+        embeddings = LangchainGPT4AllEmbeddings(model_name=model_name)
         embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
         self.set_embedding_fn(embedding_fn=embedding_fn)
 

+ 26 - 4
embedchain/llm/anthropic.py

@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import Optional
+from typing import Any, Optional
 
 try:
     from langchain_anthropic import ChatAnthropic
@@ -21,8 +21,27 @@ class AnthropicLlm(BaseLlm):
         if not self.config.api_key and "ANTHROPIC_API_KEY" not in os.environ:
             raise ValueError("Please set the ANTHROPIC_API_KEY environment variable or pass it in the config.")
 
-    def get_llm_model_answer(self, prompt):
-        return AnthropicLlm._get_answer(prompt=prompt, config=self.config)
+    def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
+        if self.config.token_usage:
+            response, token_info = self._get_answer(prompt, self.config)
+            model_name = "anthropic/" + self.config.model
+            if model_name not in self.config.model_pricing_map:
+                raise ValueError(
+                    f"Model {model_name} not found in `model_prices_and_context_window.json`. \
+                    You can disable token usage by setting `token_usage` to False."
+                )
+            total_cost = (
+                self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["input_tokens"]
+            ) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["output_tokens"]
+            response_token_info = {
+                "prompt_tokens": token_info["input_tokens"],
+                "completion_tokens": token_info["output_tokens"],
+                "total_tokens": token_info["input_tokens"] + token_info["output_tokens"],
+                "total_cost": round(total_cost, 10),
+                "cost_currency": "USD",
+            }
+            return response, response_token_info
+        return self._get_answer(prompt, self.config)
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
@@ -34,4 +53,7 @@ class AnthropicLlm(BaseLlm):
 
         messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
 
-        return chat(messages).content
+        chat_response = chat.invoke(messages)
+        if config.token_usage:
+            return chat_response.content, chat_response.response_metadata["token_usage"]
+        return chat_response.content

+ 14 - 5
embedchain/llm/base.py

@@ -164,7 +164,7 @@ class BaseLlm(JSONSerializable):
         return search.run(input_query)
 
     @staticmethod
-    def _stream_response(answer: Any) -> Generator[Any, Any, None]:
+    def _stream_response(answer: Any, token_info: Optional[dict[str, Any]] = None) -> Generator[Any, Any, None]:
         """Generator to be used as streaming response
 
         :param answer: Answer chunk from llm
@@ -177,6 +177,8 @@ class BaseLlm(JSONSerializable):
             streamed_answer = streamed_answer + chunk
             yield chunk
         logger.info(f"Answer: {streamed_answer}")
+        if token_info:
+            logger.info(f"Token Info: {token_info}")
 
     def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False):
         """
@@ -219,11 +221,18 @@ class BaseLlm(JSONSerializable):
             if dry_run:
                 return prompt
 
-            answer = self.get_answer_from_llm(prompt)
+            if self.config.token_usage:
+                answer, token_info = self.get_answer_from_llm(prompt)
+            else:
+                answer = self.get_answer_from_llm(prompt)
             if isinstance(answer, str):
                 logger.info(f"Answer: {answer}")
+                if self.config.token_usage:
+                    return answer, token_info
                 return answer
             else:
+                if self.config.token_usage:
+                    return self._stream_response(answer, token_info)
                 return self._stream_response(answer)
         finally:
             if config:
@@ -276,13 +285,13 @@ class BaseLlm(JSONSerializable):
             if dry_run:
                 return prompt
 
-            answer = self.get_answer_from_llm(prompt)
+            answer, token_info = self.get_answer_from_llm(prompt)
             if isinstance(answer, str):
                 logger.info(f"Answer: {answer}")
-                return answer
+                return answer, token_info
             else:
                 # this is a streamed response and needs to be handled differently.
-                return self._stream_response(answer)
+                return self._stream_response(answer, token_info)
         finally:
             if config:
                 # Restore previous config

+ 38 - 15
embedchain/llm/cohere.py

@@ -1,8 +1,8 @@
 import importlib
 import os
-from typing import Optional
+from typing import Any, Optional
 
-from langchain_community.llms.cohere import Cohere
+from langchain_cohere import ChatCohere
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -17,27 +17,50 @@ class CohereLlm(BaseLlm):
         except ModuleNotFoundError:
             raise ModuleNotFoundError(
                 "The required dependencies for Cohere are not installed."
-                'Please install with `pip install --upgrade "embedchain[cohere]"`'
+                "Please install with `pip install langchain_cohere==1.16.0`"
             ) from None
 
         super().__init__(config=config)
         if not self.config.api_key and "COHERE_API_KEY" not in os.environ:
             raise ValueError("Please set the COHERE_API_KEY environment variable or pass it in the config.")
 
-    def get_llm_model_answer(self, prompt):
+    def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
         if self.config.system_prompt:
             raise ValueError("CohereLlm does not support `system_prompt`")
-        return CohereLlm._get_answer(prompt=prompt, config=self.config)
+
+        if self.config.token_usage:
+            response, token_info = self._get_answer(prompt, self.config)
+            model_name = "cohere/" + self.config.model
+            if model_name not in self.config.model_pricing_map:
+                raise ValueError(
+                    f"Model {model_name} not found in `model_prices_and_context_window.json`. \
+                    You can disable token usage by setting `token_usage` to False."
+                )
+            total_cost = (
+                self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["input_tokens"]
+            ) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["output_tokens"]
+            response_token_info = {
+                "prompt_tokens": token_info["input_tokens"],
+                "completion_tokens": token_info["output_tokens"],
+                "total_tokens": token_info["input_tokens"] + token_info["output_tokens"],
+                "total_cost": round(total_cost, 10),
+                "cost_currency": "USD",
+            }
+            return response, response_token_info
+        return self._get_answer(prompt, self.config)
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
-        api_key = config.api_key or os.getenv("COHERE_API_KEY")
-        llm = Cohere(
-            cohere_api_key=api_key,
-            model=config.model,
-            max_tokens=config.max_tokens,
-            temperature=config.temperature,
-            p=config.top_p,
-        )
-
-        return llm.invoke(prompt)
+        api_key = config.api_key or os.environ["COHERE_API_KEY"]
+        kwargs = {
+            "model_name": config.model or "command-r",
+            "temperature": config.temperature,
+            "max_tokens": config.max_tokens,
+            "together_api_key": api_key,
+        }
+
+        chat = ChatCohere(**kwargs)
+        chat_response = chat.invoke(prompt)
+        if config.token_usage:
+            return chat_response.content, chat_response.response_metadata["token_count"]
+        return chat_response.content

+ 27 - 5
embedchain/llm/groq.py

@@ -1,5 +1,5 @@
 import os
-from typing import Optional
+from typing import Any, Optional
 
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.schema import HumanMessage, SystemMessage
@@ -22,9 +22,27 @@ class GroqLlm(BaseLlm):
         if not self.config.api_key and "GROQ_API_KEY" not in os.environ:
             raise ValueError("Please set the GROQ_API_KEY environment variable or pass it in the config.")
 
-    def get_llm_model_answer(self, prompt) -> str:
-        response = self._get_answer(prompt, self.config)
-        return response
+    def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
+        if self.config.token_usage:
+            response, token_info = self._get_answer(prompt, self.config)
+            model_name = "groq/" + self.config.model
+            if model_name not in self.config.model_pricing_map:
+                raise ValueError(
+                    f"Model {model_name} not found in `model_prices_and_context_window.json`. \
+                    You can disable token usage by setting `token_usage` to False."
+                )
+            total_cost = (
+                self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["prompt_tokens"]
+            ) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["completion_tokens"]
+            response_token_info = {
+                "prompt_tokens": token_info["prompt_tokens"],
+                "completion_tokens": token_info["completion_tokens"],
+                "total_tokens": token_info["prompt_tokens"] + token_info["completion_tokens"],
+                "total_cost": round(total_cost, 10),
+                "cost_currency": "USD",
+            }
+            return response, response_token_info
+        return self._get_answer(prompt, self.config)
 
     def _get_answer(self, prompt: str, config: BaseLlmConfig) -> str:
         messages = []
@@ -42,4 +60,8 @@ class GroqLlm(BaseLlm):
             chat = ChatGroq(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
         else:
             chat = ChatGroq(**kwargs)
-        return chat.invoke(messages).content
+
+        chat_response = chat.invoke(prompt)
+        if self.config.token_usage:
+            return chat_response.content, chat_response.response_metadata["token_usage"]
+        return chat_response.content

+ 26 - 6
embedchain/llm/mistralai.py

@@ -1,5 +1,5 @@
 import os
-from typing import Optional
+from typing import Any, Optional
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -13,8 +13,27 @@ class MistralAILlm(BaseLlm):
         if not self.config.api_key and "MISTRAL_API_KEY" not in os.environ:
             raise ValueError("Please set the MISTRAL_API_KEY environment variable or pass it in the config.")
 
-    def get_llm_model_answer(self, prompt):
-        return MistralAILlm._get_answer(prompt=prompt, config=self.config)
+    def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
+        if self.config.token_usage:
+            response, token_info = self._get_answer(prompt, self.config)
+            model_name = "mistralai/" + self.config.model
+            if model_name not in self.config.model_pricing_map:
+                raise ValueError(
+                    f"Model {model_name} not found in `model_prices_and_context_window.json`. \
+                    You can disable token usage by setting `token_usage` to False."
+                )
+            total_cost = (
+                self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["prompt_tokens"]
+            ) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["completion_tokens"]
+            response_token_info = {
+                "prompt_tokens": token_info["prompt_tokens"],
+                "completion_tokens": token_info["completion_tokens"],
+                "total_tokens": token_info["prompt_tokens"] + token_info["completion_tokens"],
+                "total_cost": round(total_cost, 10),
+                "cost_currency": "USD",
+            }
+            return response, response_token_info
+        return self._get_answer(prompt, self.config)
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig):
@@ -47,6 +66,7 @@ class MistralAILlm(BaseLlm):
                 answer += chunk.content
             return answer
         else:
-            response = client.invoke(**kwargs, input=messages)
-            answer = response.content
-            return answer
+            chat_response = client.invoke(**kwargs, input=messages)
+            if config.token_usage:
+                return chat_response.content, chat_response.response_metadata["token_usage"]
+            return chat_response.content

+ 26 - 4
embedchain/llm/nvidia.py

@@ -1,6 +1,6 @@
 import os
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.stdout import StdOutCallbackHandler
@@ -25,8 +25,27 @@ class NvidiaLlm(BaseLlm):
         if not self.config.api_key and "NVIDIA_API_KEY" not in os.environ:
             raise ValueError("Please set the NVIDIA_API_KEY environment variable or pass it in the config.")
 
-    def get_llm_model_answer(self, prompt):
-        return self._get_answer(prompt=prompt, config=self.config)
+    def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
+        if self.config.token_usage:
+            response, token_info = self._get_answer(prompt, self.config)
+            model_name = "nvidia/" + self.config.model
+            if model_name not in self.config.model_pricing_map:
+                raise ValueError(
+                    f"Model {model_name} not found in `model_prices_and_context_window.json`. \
+                    You can disable token usage by setting `token_usage` to False."
+                )
+            total_cost = (
+                self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["input_tokens"]
+            ) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["output_tokens"]
+            response_token_info = {
+                "prompt_tokens": token_info["input_tokens"],
+                "completion_tokens": token_info["output_tokens"],
+                "total_tokens": token_info["input_tokens"] + token_info["output_tokens"],
+                "total_cost": round(total_cost, 10),
+                "cost_currency": "USD",
+            }
+            return response, response_token_info
+        return self._get_answer(prompt, self.config)
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
@@ -43,4 +62,7 @@ class NvidiaLlm(BaseLlm):
         if labels:
             params["labels"] = labels
         llm = ChatNVIDIA(**params, callback_manager=CallbackManager(callback_manager))
-        return llm.invoke(prompt).content if labels is None else llm.invoke(prompt, labels=labels).content
+        chat_response = llm.invoke(prompt) if labels is None else llm.invoke(prompt, labels=labels)
+        if config.token_usage:
+            return chat_response.content, chat_response.response_metadata["token_usage"]
+        return chat_response.content

+ 26 - 4
embedchain/llm/openai.py

@@ -23,9 +23,28 @@ class OpenAILlm(BaseLlm):
         self.tools = tools
         super().__init__(config=config)
 
-    def get_llm_model_answer(self, prompt) -> str:
-        response = self._get_answer(prompt, self.config)
-        return response
+    def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
+        if self.config.token_usage:
+            response, token_info = self._get_answer(prompt, self.config)
+            model_name = "openai/" + self.config.model
+            if model_name not in self.config.model_pricing_map:
+                raise ValueError(
+                    f"Model {model_name} not found in `model_prices_and_context_window.json`. \
+                    You can disable token usage by setting `token_usage` to False."
+                )
+            total_cost = (
+                self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["prompt_tokens"]
+            ) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["completion_tokens"]
+            response_token_info = {
+                "prompt_tokens": token_info["prompt_tokens"],
+                "completion_tokens": token_info["completion_tokens"],
+                "total_tokens": token_info["prompt_tokens"] + token_info["completion_tokens"],
+                "total_cost": round(total_cost, 10),
+                "cost_currency": "USD",
+            }
+            return response, response_token_info
+
+        return self._get_answer(prompt, self.config)
 
     def _get_answer(self, prompt: str, config: BaseLlmConfig) -> str:
         messages = []
@@ -66,7 +85,10 @@ class OpenAILlm(BaseLlm):
         if self.tools:
             return self._query_function_call(chat, self.tools, messages)
 
-        return chat.invoke(messages).content
+        chat_response = chat.invoke(messages)
+        if self.config.token_usage:
+            return chat_response.content, chat_response.response_metadata["token_usage"]
+        return chat_response.content
 
     def _query_function_call(
         self,

+ 42 - 14
embedchain/llm/together.py

@@ -1,8 +1,13 @@
 import importlib
 import os
-from typing import Optional
+from typing import Any, Optional
 
-from langchain_community.llms import Together
+try:
+    from langchain_together import ChatTogether
+except ImportError:
+    raise ImportError(
+        "Please install the langchain_together package by running `pip install langchain_together==0.1.3`."
+    )
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -24,20 +29,43 @@ class TogetherLlm(BaseLlm):
         if not self.config.api_key and "TOGETHER_API_KEY" not in os.environ:
             raise ValueError("Please set the TOGETHER_API_KEY environment variable or pass it in the config.")
 
-    def get_llm_model_answer(self, prompt):
+    def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
         if self.config.system_prompt:
             raise ValueError("TogetherLlm does not support `system_prompt`")
-        return TogetherLlm._get_answer(prompt=prompt, config=self.config)
+
+        if self.config.token_usage:
+            response, token_info = self._get_answer(prompt, self.config)
+            model_name = "together/" + self.config.model
+            if model_name not in self.config.model_pricing_map:
+                raise ValueError(
+                    f"Model {model_name} not found in `model_prices_and_context_window.json`. \
+                    You can disable token usage by setting `token_usage` to False."
+                )
+            total_cost = (
+                self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["prompt_tokens"]
+            ) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["completion_tokens"]
+            response_token_info = {
+                "prompt_tokens": token_info["prompt_tokens"],
+                "completion_tokens": token_info["completion_tokens"],
+                "total_tokens": token_info["prompt_tokens"] + token_info["completion_tokens"],
+                "total_cost": round(total_cost, 10),
+                "cost_currency": "USD",
+            }
+            return response, response_token_info
+        return self._get_answer(prompt, self.config)
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
-        api_key = config.api_key or os.getenv("TOGETHER_API_KEY")
-        llm = Together(
-            together_api_key=api_key,
-            model=config.model,
-            max_tokens=config.max_tokens,
-            temperature=config.temperature,
-            top_p=config.top_p,
-        )
-
-        return llm.invoke(prompt)
+        api_key = config.api_key or os.environ["TOGETHER_API_KEY"]
+        kwargs = {
+            "model_name": config.model or "mixtral-8x7b-32768",
+            "temperature": config.temperature,
+            "max_tokens": config.max_tokens,
+            "together_api_key": api_key,
+        }
+
+        chat = ChatTogether(**kwargs)
+        chat_response = chat.invoke(prompt)
+        if config.token_usage:
+            return chat_response.content, chat_response.response_metadata["token_usage"]
+        return chat_response.content

+ 29 - 6
embedchain/llm/vertex_ai.py

@@ -1,6 +1,6 @@
 import importlib
 import logging
-from typing import Optional
+from typing import Any, Optional
 
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain_google_vertexai import ChatVertexAI
@@ -24,16 +24,35 @@ class VertexAILlm(BaseLlm):
             ) from None
         super().__init__(config=config)
 
-    def get_llm_model_answer(self, prompt):
-        return VertexAILlm._get_answer(prompt=prompt, config=self.config)
+    def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
+        if self.config.token_usage:
+            response, token_info = self._get_answer(prompt, self.config)
+            model_name = "vertexai/" + self.config.model
+            if model_name not in self.config.model_pricing_map:
+                raise ValueError(
+                    f"Model {model_name} not found in `model_prices_and_context_window.json`. \
+                    You can disable token usage by setting `token_usage` to False."
+                )
+            total_cost = (
+                self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["prompt_token_count"]
+            ) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info[
+                "candidates_token_count"
+            ]
+            response_token_info = {
+                "prompt_tokens": token_info["prompt_token_count"],
+                "completion_tokens": token_info["candidates_token_count"],
+                "total_tokens": token_info["prompt_token_count"] + token_info["candidates_token_count"],
+                "total_cost": round(total_cost, 10),
+                "cost_currency": "USD",
+            }
+            return response, response_token_info
+        return self._get_answer(prompt, self.config)
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
         if config.top_p and config.top_p != 1:
             logger.warning("Config option `top_p` is not supported by this model.")
 
-        messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
-
         if config.stream:
             callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
             llm = ChatVertexAI(
@@ -42,4 +61,8 @@ class VertexAILlm(BaseLlm):
         else:
             llm = ChatVertexAI(temperature=config.temperature, model=config.model)
 
-        return llm.invoke(messages).content
+        messages = VertexAILlm._get_messages(prompt)
+        chat_response = llm.invoke(messages)
+        if config.token_usage:
+            return chat_response.content, chat_response.response_metadata["usage_metadata"]
+        return chat_response.content

+ 1 - 0
embedchain/utils/misc.py

@@ -428,6 +428,7 @@ def validate_config(config_data):
                     Optional("top_p"): Or(float, int),
                     Optional("stream"): bool,
                     Optional("online"): bool,
+                    Optional("token_usage"): bool,
                     Optional("template"): str,
                     Optional("prompt"): str,
                     Optional("system_prompt"): str,

+ 803 - 0
model_prices_and_context_window.json

@@ -0,0 +1,803 @@
+{
+    "openai/gpt-4": {
+        "max_tokens": 4096, 
+        "max_input_tokens": 8192,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00003,
+        "output_cost_per_token": 0.00006
+    },
+    "openai/gpt-4o": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000005,
+        "output_cost_per_token": 0.000015
+    },
+    "openai/gpt-4o-2024-05-13": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000005,
+        "output_cost_per_token": 0.000015
+    },
+    "openai/gpt-4-turbo-preview": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003
+    },
+    "openai/gpt-4-0314": {
+        "max_tokens": 4096,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00003,
+        "output_cost_per_token": 0.00006
+    },
+    "openai/gpt-4-0613": {
+        "max_tokens": 4096,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00003,
+        "output_cost_per_token": 0.00006
+    },
+    "openai/gpt-4-32k": {
+        "max_tokens": 4096,
+        "max_input_tokens": 32768,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00006,
+        "output_cost_per_token": 0.00012
+    },
+    "openai/gpt-4-32k-0314": {
+        "max_tokens": 4096,
+        "max_input_tokens": 32768,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00006,
+        "output_cost_per_token": 0.00012
+    },
+    "openai/gpt-4-32k-0613": {
+        "max_tokens": 4096,
+        "max_input_tokens": 32768,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00006,
+        "output_cost_per_token": 0.00012
+    },
+    "openai/gpt-4-turbo": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003
+    },
+    "openai/gpt-4-turbo-2024-04-09": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003
+    },
+    "openai/gpt-4-1106-preview": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003
+    },
+    "openai/gpt-4-0125-preview": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003
+    },
+    "openai/gpt-3.5-turbo": {
+        "max_tokens": 4097,
+        "max_input_tokens": 16385,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000015,
+        "output_cost_per_token": 0.000002
+    },
+    "openai/gpt-3.5-turbo-0301": {
+        "max_tokens": 4097,
+        "max_input_tokens": 4097,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000015,
+        "output_cost_per_token": 0.000002
+    },
+    "openai/gpt-3.5-turbo-0613": {
+        "input_cost_per_token": 0.0000015,
+        "output_cost_per_token": 0.000002
+    },
+    "openai/gpt-3.5-turbo-1106": {
+        "max_tokens": 16385,
+        "max_input_tokens": 16385,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000010,
+        "output_cost_per_token": 0.0000020
+    },
+    "openai/gpt-3.5-turbo-0125": {
+        "max_tokens": 16385,
+        "max_input_tokens": 16385,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000005,
+        "output_cost_per_token": 0.0000015
+    },
+    "openai/gpt-3.5-turbo-16k": {
+        "max_tokens": 16385,
+        "max_input_tokens": 16385,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000003,
+        "output_cost_per_token": 0.000004
+    },
+    "openai/gpt-3.5-turbo-16k-0613": {
+        "max_tokens": 16385,
+        "max_input_tokens": 16385,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000003,
+        "output_cost_per_token": 0.000004
+    },
+    "openai/text-embedding-3-large": {
+        "max_tokens": 8191,
+        "max_input_tokens": 8191,
+        "output_vector_size": 3072,
+        "input_cost_per_token": 0.00000013,
+        "output_cost_per_token": 0.000000
+    },
+    "openai/text-embedding-3-small": {
+        "max_tokens": 8191,
+        "max_input_tokens": 8191,
+        "output_vector_size": 1536,
+        "input_cost_per_token": 0.00000002,
+        "output_cost_per_token": 0.000000
+    },
+    "openai/text-embedding-ada-002": {
+        "max_tokens": 8191,
+        "max_input_tokens": 8191,
+        "output_vector_size": 1536, 
+        "input_cost_per_token": 0.0000001,
+        "output_cost_per_token": 0.000000
+    },
+    "openai/text-embedding-ada-002-v2": {
+        "max_tokens": 8191,
+        "max_input_tokens": 8191,
+        "input_cost_per_token": 0.0000001,
+        "output_cost_per_token": 0.000000
+    },
+    "openai/babbage-002": {
+        "max_tokens": 16384,
+        "max_input_tokens": 16384,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000004,
+        "output_cost_per_token": 0.0000004
+    },
+    "openai/davinci-002": {
+        "max_tokens": 16384,
+        "max_input_tokens": 16384,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000002,
+        "output_cost_per_token": 0.000002
+    },    
+    "openai/gpt-3.5-turbo-instruct": {
+        "max_tokens": 4096,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000015,
+        "output_cost_per_token": 0.000002
+    },
+    "openai/gpt-3.5-turbo-instruct-0914": {
+        "max_tokens": 4097,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 4097,
+        "input_cost_per_token": 0.0000015,
+        "output_cost_per_token": 0.000002
+    },
+    "azure/gpt-4o": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000005,
+        "output_cost_per_token": 0.000015
+    },
+    "azure/gpt-4-turbo-2024-04-09": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003
+    },
+    "azure/gpt-4-0125-preview": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003
+    },
+    "azure/gpt-4-1106-preview": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003
+    },
+    "azure/gpt-4-0613": {
+        "max_tokens": 4096,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00003,
+        "output_cost_per_token": 0.00006
+    },
+    "azure/gpt-4-32k-0613": {
+        "max_tokens": 4096,
+        "max_input_tokens": 32768,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00006,
+        "output_cost_per_token": 0.00012
+    },
+    "azure/gpt-4-32k": {
+        "max_tokens": 4096,
+        "max_input_tokens": 32768,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00006,
+        "output_cost_per_token": 0.00012
+    },
+    "azure/gpt-4": {
+        "max_tokens": 4096,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00003,
+        "output_cost_per_token": 0.00006
+    },
+    "azure/gpt-4-turbo": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003
+    },
+    "azure/gpt-4-turbo-vision-preview": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003
+    },
+    "azure/gpt-3.5-turbo-16k-0613": {
+        "max_tokens": 4096,
+        "max_input_tokens": 16385,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000003,
+        "output_cost_per_token": 0.000004
+    },
+    "azure/gpt-3.5-turbo-1106": {
+        "max_tokens": 4096,
+        "max_input_tokens": 16384,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000015,
+        "output_cost_per_token": 0.000002
+    },
+    "azure/gpt-3.5-turbo-0125": {
+        "max_tokens": 4096,
+        "max_input_tokens": 16384,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000005,
+        "output_cost_per_token": 0.0000015
+    },
+    "azure/gpt-3.5-turbo-16k": {
+        "max_tokens": 4096,
+        "max_input_tokens": 16385,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000003,
+        "output_cost_per_token": 0.000004
+    },
+    "azure/gpt-3.5-turbo": {
+        "max_tokens": 4096,
+        "max_input_tokens": 4097,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000005,
+        "output_cost_per_token": 0.0000015
+    },
+    "azure/gpt-3.5-turbo-instruct-0914": {
+        "max_tokens": 4097,
+        "max_input_tokens": 4097,
+        "input_cost_per_token": 0.0000015,
+        "output_cost_per_token": 0.000002
+    },
+    "azure/gpt-3.5-turbo-instruct": {
+        "max_tokens": 4097,
+        "max_input_tokens": 4097,
+        "input_cost_per_token": 0.0000015,
+        "output_cost_per_token": 0.000002
+    },
+    "azure/text-embedding-ada-002": {
+        "max_tokens": 8191,
+        "max_input_tokens": 8191,
+        "input_cost_per_token": 0.0000001,
+        "output_cost_per_token": 0.000000
+    },
+    "azure/text-embedding-3-large": {
+        "max_tokens": 8191,
+        "max_input_tokens": 8191,
+        "input_cost_per_token": 0.00000013,
+        "output_cost_per_token": 0.000000
+    },
+    "azure/text-embedding-3-small": {
+        "max_tokens": 8191,
+        "max_input_tokens": 8191,
+        "input_cost_per_token": 0.00000002,
+        "output_cost_per_token": 0.000000
+    }, 
+    "mistralai/mistral-tiny": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.00000025,
+        "output_cost_per_token": 0.00000025
+    },
+    "mistralai/mistral-small": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.000001,
+        "output_cost_per_token": 0.000003
+    },
+    "mistralai/mistral-small-latest": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.000001,
+        "output_cost_per_token": 0.000003
+    },
+    "mistralai/mistral-medium": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.0000027,
+        "output_cost_per_token": 0.0000081
+    },
+    "mistralai/mistral-medium-latest": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.0000027,
+        "output_cost_per_token": 0.0000081
+    },
+    "mistralai/mistral-medium-2312": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.0000027,
+        "output_cost_per_token": 0.0000081
+    },
+    "mistralai/mistral-large-latest": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.000004,
+        "output_cost_per_token": 0.000012
+    },
+    "mistralai/mistral-large-2402": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.000004,
+        "output_cost_per_token": 0.000012
+    },
+    "mistralai/open-mistral-7b": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.00000025,
+        "output_cost_per_token": 0.00000025
+    },
+    "mistralai/open-mixtral-8x7b": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.0000007,
+        "output_cost_per_token": 0.0000007
+    },
+    "mistralai/open-mixtral-8x22b": {
+        "max_tokens": 8191,
+        "max_input_tokens": 64000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.000002,
+        "output_cost_per_token": 0.000006
+    },
+    "mistralai/codestral-latest": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.000001,
+        "output_cost_per_token": 0.000003
+    },
+    "mistralai/codestral-2405": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.000001,
+        "output_cost_per_token": 0.000003
+    },
+    "mistralai/mistral-embed": {
+        "max_tokens": 8192,
+        "max_input_tokens": 8192,
+        "input_cost_per_token": 0.0000001,
+        "output_cost_per_token": 0.0
+    },
+    "groq/llama2-70b-4096": {
+        "max_tokens": 4096,
+        "max_input_tokens": 4096,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00000070,
+        "output_cost_per_token": 0.00000080
+    },
+    "groq/llama3-8b-8192": {
+        "max_tokens": 8192,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000010,
+        "output_cost_per_token": 0.00000010
+    },
+    "groq/llama3-70b-8192": {
+        "max_tokens": 8192,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000064,
+        "output_cost_per_token": 0.00000080
+    },
+    "groq/mixtral-8x7b-32768": {
+        "max_tokens": 32768,
+        "max_input_tokens": 32768,
+        "max_output_tokens": 32768,
+        "input_cost_per_token": 0.00000027,
+        "output_cost_per_token": 0.00000027
+    },
+    "groq/gemma-7b-it": {
+        "max_tokens": 8192,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000010,
+        "output_cost_per_token": 0.00000010
+    },
+    "anthropic/claude-instant-1": {
+        "max_tokens": 8191,
+        "max_input_tokens": 100000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.00000163,
+        "output_cost_per_token": 0.00000551
+    },
+    "anthropic/claude-instant-1.2": {
+        "max_tokens": 8191,
+        "max_input_tokens": 100000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.000000163,
+        "output_cost_per_token": 0.000000551
+    },
+    "anthropic/claude-2": {
+        "max_tokens": 8191,
+        "max_input_tokens": 100000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.000008,
+        "output_cost_per_token": 0.000024
+    },
+    "anthropic/claude-2.1": {
+        "max_tokens": 8191,
+        "max_input_tokens": 200000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.000008,
+        "output_cost_per_token": 0.000024
+    },
+    "anthropic/claude-3-haiku-20240307": {
+        "max_tokens": 4096,
+        "max_input_tokens": 200000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00000025,
+        "output_cost_per_token": 0.00000125
+    },
+    "anthropic/claude-3-opus-20240229": {
+        "max_tokens": 4096,
+        "max_input_tokens": 200000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000015,
+        "output_cost_per_token": 0.000075
+    },
+    "anthropic/claude-3-sonnet-20240229": {
+        "max_tokens": 4096,
+        "max_input_tokens": 200000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000003,
+        "output_cost_per_token": 0.000015
+    },
+    "vertexai/chat-bison": {
+        "max_tokens": 4096,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/chat-bison@001": {
+        "max_tokens": 4096,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/chat-bison@002": {
+        "max_tokens": 4096,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/chat-bison-32k": {
+        "max_tokens": 8192,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/code-bison": {
+        "max_tokens": 1024,
+        "max_input_tokens": 6144,
+        "max_output_tokens": 1024,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/code-bison@001": {
+        "max_tokens": 1024,
+        "max_input_tokens": 6144,
+        "max_output_tokens": 1024,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/code-gecko@001": {
+        "max_tokens": 64,
+        "max_input_tokens": 2048,
+        "max_output_tokens": 64,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/code-gecko@002": {
+        "max_tokens": 64,
+        "max_input_tokens": 2048,
+        "max_output_tokens": 64,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/code-gecko": {
+        "max_tokens": 64,
+        "max_input_tokens": 2048,
+        "max_output_tokens": 64,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/codechat-bison": {
+        "max_tokens": 1024,
+        "max_input_tokens": 6144,
+        "max_output_tokens": 1024,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/codechat-bison@001": {
+        "max_tokens": 1024,
+        "max_input_tokens": 6144,
+        "max_output_tokens": 1024,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/codechat-bison-32k": {
+        "max_tokens": 8192,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/gemini-pro": {
+        "max_tokens": 8192,
+        "max_input_tokens": 32760,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000025, 
+        "output_cost_per_token": 0.0000005
+    },
+    "vertexai/gemini-1.0-pro": { 
+        "max_tokens": 8192,
+        "max_input_tokens": 32760,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000025, 
+        "output_cost_per_token": 0.0000005
+    },
+    "vertexai/gemini-1.0-pro-001": { 
+        "max_tokens": 8192,
+        "max_input_tokens": 32760,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000025, 
+        "output_cost_per_token": 0.0000005
+    },
+    "vertexai/gemini-1.0-pro-002": { 
+        "max_tokens": 8192,
+        "max_input_tokens": 32760,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000025, 
+        "output_cost_per_token": 0.0000005
+    },
+    "vertexai/gemini-1.5-pro": { 
+        "max_tokens": 8192,
+        "max_input_tokens": 1000000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.000000625, 
+        "output_cost_per_token": 0.000001875
+    },
+    "vertexai/gemini-1.5-flash-001": {
+        "max_tokens": 8192,
+        "max_input_tokens": 1000000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0, 
+        "output_cost_per_token": 0
+    },
+    "vertexai/gemini-1.5-flash-preview-0514": {
+        "max_tokens": 8192,
+        "max_input_tokens": 1000000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0, 
+        "output_cost_per_token": 0
+    },
+    "vertexai/gemini-1.5-pro-001": { 
+        "max_tokens": 8192,
+        "max_input_tokens": 1000000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.000000625, 
+        "output_cost_per_token": 0.000001875
+    },
+    "vertexai/gemini-1.5-pro-preview-0514": { 
+        "max_tokens": 8192,
+        "max_input_tokens": 1000000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.000000625, 
+        "output_cost_per_token": 0.000001875
+    },
+    "vertexai/gemini-1.5-pro-preview-0215": { 
+        "max_tokens": 8192,
+        "max_input_tokens": 1000000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.000000625, 
+        "output_cost_per_token": 0.000001875
+    },
+    "vertexai/gemini-1.5-pro-preview-0409": {
+        "max_tokens": 8192,
+        "max_input_tokens": 1000000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.000000625, 
+        "output_cost_per_token": 0.000001875
+    },
+    "vertexai/gemini-experimental": {
+        "max_tokens": 8192,
+        "max_input_tokens": 1000000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0,
+        "output_cost_per_token": 0
+    },
+    "vertexai/gemini-pro-vision": {
+        "max_tokens": 2048,
+        "max_input_tokens": 16384,
+        "max_output_tokens": 2048,
+        "max_images_per_prompt": 16,
+        "max_videos_per_prompt": 1,
+        "max_video_length": 2,
+        "input_cost_per_token": 0.00000025, 
+        "output_cost_per_token": 0.0000005
+    },
+    "vertexai/gemini-1.0-pro-vision": {
+        "max_tokens": 2048,
+        "max_input_tokens": 16384,
+        "max_output_tokens": 2048,
+        "max_images_per_prompt": 16,
+        "max_videos_per_prompt": 1,
+        "max_video_length": 2,
+        "input_cost_per_token": 0.00000025, 
+        "output_cost_per_token": 0.0000005
+    },
+    "vertexai/gemini-1.0-pro-vision-001": {
+        "max_tokens": 2048,
+        "max_input_tokens": 16384,
+        "max_output_tokens": 2048,
+        "max_images_per_prompt": 16,
+        "max_videos_per_prompt": 1,
+        "max_video_length": 2,
+        "input_cost_per_token": 0.00000025, 
+        "output_cost_per_token": 0.0000005
+    },
+    "vertexai/claude-3-sonnet@20240229": {
+        "max_tokens": 4096,
+        "max_input_tokens": 200000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000003,
+        "output_cost_per_token": 0.000015
+    },
+    "vertexai/claude-3-haiku@20240307": {
+        "max_tokens": 4096, 
+        "max_input_tokens": 200000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00000025,
+        "output_cost_per_token": 0.00000125
+    },
+    "vertexai/claude-3-opus@20240229": {
+        "max_tokens": 4096,
+        "max_input_tokens": 200000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000015,
+        "output_cost_per_token": 0.000075
+    },
+    "cohere/command-r": {
+        "max_tokens": 4096, 
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00000050,
+        "output_cost_per_token": 0.0000015
+    },
+    "cohere/command-light": {
+        "max_tokens": 4096, 
+        "max_input_tokens": 4096,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000015,
+        "output_cost_per_token": 0.000015
+    },
+    "cohere/command-r-plus": {
+        "max_tokens": 4096, 
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000003,
+        "output_cost_per_token": 0.000015
+    },
+    "cohere/command-nightly": {
+        "max_tokens": 4096, 
+        "max_input_tokens": 4096,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000015,
+        "output_cost_per_token": 0.000015
+    },
+     "cohere/command": {
+        "max_tokens": 4096, 
+        "max_input_tokens": 4096,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000015,
+        "output_cost_per_token": 0.000015
+    },
+     "cohere/command-medium-beta": {
+        "max_tokens": 4096, 
+        "max_input_tokens": 4096,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000015,
+        "output_cost_per_token": 0.000015
+    },
+     "cohere/command-xlarge-beta": {
+        "max_tokens": 4096, 
+        "max_input_tokens": 4096,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000015,
+        "output_cost_per_token": 0.000015
+    },
+    "together/together-ai-up-to-3b": {
+        "input_cost_per_token": 0.0000001,
+        "output_cost_per_token": 0.0000001
+    },
+    "together/together-ai-3.1b-7b": {
+        "input_cost_per_token": 0.0000002,
+        "output_cost_per_token": 0.0000002
+    },
+    "together/together-ai-7.1b-20b": {
+        "max_tokens": 1000,
+        "input_cost_per_token": 0.0000004,
+        "output_cost_per_token": 0.0000004
+    },
+    "together/together-ai-20.1b-40b": {
+        "input_cost_per_token": 0.0000008,
+        "output_cost_per_token": 0.0000008
+    },
+    "together/together-ai-40.1b-70b": {
+        "input_cost_per_token": 0.0000009,
+        "output_cost_per_token": 0.0000009
+    },
+    "together/mistralai/Mixtral-8x7B-Instruct-v0.1": {
+        "input_cost_per_token": 0.0000006,
+        "output_cost_per_token": 0.0000006
+    }
+}

파일 크기가 너무 크기때문에 변경 상태를 표시하지 않습니다.
+ 441 - 380
poetry.lock


+ 2 - 2
pyproject.toml

@@ -156,6 +156,7 @@ langchain-google-vertexai = { version = "^1.0.6", optional = true }
 sqlalchemy = "^2.0.27"
 alembic = "^1.13.1"
 langchain-cohere = "^0.1.4"
+langchain-community = "^0.2.6"
 
 [tool.poetry.group.dev.dependencies]
 black = "^23.3.0"
@@ -183,9 +184,8 @@ slack = ["slack-sdk", "flask"]
 whatsapp = ["twilio", "flask"]
 weaviate = ["weaviate-client"]
 qdrant = ["qdrant-client"]
-huggingface_hub=["huggingface_hub"]
-cohere = ["cohere"]
 together = ["together"]
+huggingface_hub=["huggingface_hub"]
 milvus = ["pymilvus"]
 dataloaders=[
     "youtube-transcript-api",

+ 23 - 2
tests/llm/test_anthrophic.py

@@ -11,7 +11,7 @@ from embedchain.llm.anthropic import AnthropicLlm
 @pytest.fixture
 def anthropic_llm():
     os.environ["ANTHROPIC_API_KEY"] = "test_api_key"
-    config = BaseLlmConfig(temperature=0.5, model="gpt2")
+    config = BaseLlmConfig(temperature=0.5, model="claude-instant-1", token_usage=False)
     return AnthropicLlm(config)
 
 
@@ -20,7 +20,7 @@ def test_get_llm_model_answer(anthropic_llm):
         prompt = "Test Prompt"
         response = anthropic_llm.get_llm_model_answer(prompt)
         assert response == "Test Response"
-        mock_method.assert_called_once_with(prompt=prompt, config=anthropic_llm.config)
+        mock_method.assert_called_once_with(prompt, anthropic_llm.config)
 
 
 def test_get_messages(anthropic_llm):
@@ -31,3 +31,24 @@ def test_get_messages(anthropic_llm):
         SystemMessage(content="Test System Prompt", additional_kwargs={}),
         HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
     ]
+
+
+def test_get_llm_model_answer_with_token_usage(anthropic_llm):
+    test_config = BaseLlmConfig(
+        temperature=anthropic_llm.config.temperature, model=anthropic_llm.config.model, token_usage=True
+    )
+    anthropic_llm.config = test_config
+    with patch.object(
+        AnthropicLlm, "_get_answer", return_value=("Test Response", {"input_tokens": 1, "output_tokens": 2})
+    ) as mock_method:
+        prompt = "Test Prompt"
+        response, token_info = anthropic_llm.get_llm_model_answer(prompt)
+        assert response == "Test Response"
+        assert token_info == {
+            "prompt_tokens": 1,
+            "completion_tokens": 2,
+            "total_tokens": 3,
+            "total_cost": 1.265e-05,
+            "cost_currency": "USD",
+        }
+        mock_method.assert_called_once_with(prompt, anthropic_llm.config)

+ 29 - 4
tests/llm/test_cohere.py

@@ -9,7 +9,7 @@ from embedchain.llm.cohere import CohereLlm
 @pytest.fixture
 def cohere_llm_config():
     os.environ["COHERE_API_KEY"] = "test_api_key"
-    config = BaseLlmConfig(model="gptd-instruct-tft", max_tokens=50, temperature=0.7, top_p=0.8)
+    config = BaseLlmConfig(model="command-r", max_tokens=100, temperature=0.7, top_p=0.8, token_usage=False)
     yield config
     os.environ.pop("COHERE_API_KEY")
 
@@ -36,10 +36,35 @@ def test_get_llm_model_answer(cohere_llm_config, mocker):
     assert answer == "Test answer"
 
 
+def test_get_llm_model_answer_with_token_usage(cohere_llm_config, mocker):
+    test_config = BaseLlmConfig(
+        temperature=cohere_llm_config.temperature,
+        max_tokens=cohere_llm_config.max_tokens,
+        top_p=cohere_llm_config.top_p,
+        model=cohere_llm_config.model,
+        token_usage=True,
+    )
+    mocker.patch(
+        "embedchain.llm.cohere.CohereLlm._get_answer",
+        return_value=("Test answer", {"input_tokens": 1, "output_tokens": 2}),
+    )
+
+    llm = CohereLlm(test_config)
+    answer, token_info = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+    assert token_info == {
+        "prompt_tokens": 1,
+        "completion_tokens": 2,
+        "total_tokens": 3,
+        "total_cost": 3.5e-06,
+        "cost_currency": "USD",
+    }
+
+
 def test_get_answer_mocked_cohere(cohere_llm_config, mocker):
-    mocked_cohere = mocker.patch("embedchain.llm.cohere.Cohere")
-    mock_instance = mocked_cohere.return_value
-    mock_instance.invoke.return_value = "Mocked answer"
+    mocked_cohere = mocker.patch("embedchain.llm.cohere.ChatCohere")
+    mocked_cohere.return_value.invoke.return_value.content = "Mocked answer"
 
     llm = CohereLlm(cohere_llm_config)
     prompt = "Test query"

+ 31 - 4
tests/llm/test_mistralai.py

@@ -24,7 +24,7 @@ def test_mistralai_llm_init(monkeypatch):
 
 
 def test_get_llm_model_answer(monkeypatch, mistralai_llm_config):
-    def mock_get_answer(prompt, config):
+    def mock_get_answer(self, prompt, config):
         return "Generated Text"
 
     monkeypatch.setattr(MistralAILlm, "_get_answer", mock_get_answer)
@@ -36,7 +36,7 @@ def test_get_llm_model_answer(monkeypatch, mistralai_llm_config):
 
 def test_get_llm_model_answer_with_system_prompt(monkeypatch, mistralai_llm_config):
     mistralai_llm_config.system_prompt = "Test system prompt"
-    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
+    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
     llm = MistralAILlm(config=mistralai_llm_config)
     result = llm.get_llm_model_answer("test prompt")
 
@@ -44,7 +44,7 @@ def test_get_llm_model_answer_with_system_prompt(monkeypatch, mistralai_llm_conf
 
 
 def test_get_llm_model_answer_empty_prompt(monkeypatch, mistralai_llm_config):
-    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
+    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
     llm = MistralAILlm(config=mistralai_llm_config)
     result = llm.get_llm_model_answer("")
 
@@ -53,8 +53,35 @@ def test_get_llm_model_answer_empty_prompt(monkeypatch, mistralai_llm_config):
 
 def test_get_llm_model_answer_without_system_prompt(monkeypatch, mistralai_llm_config):
     mistralai_llm_config.system_prompt = None
-    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
+    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
     llm = MistralAILlm(config=mistralai_llm_config)
     result = llm.get_llm_model_answer("test prompt")
 
     assert result == "Generated Text"
+
+
+def test_get_llm_model_answer_with_token_usage(monkeypatch, mistralai_llm_config):
+    test_config = BaseLlmConfig(
+        temperature=mistralai_llm_config.temperature,
+        max_tokens=mistralai_llm_config.max_tokens,
+        top_p=mistralai_llm_config.top_p,
+        model=mistralai_llm_config.model,
+        token_usage=True,
+    )
+    monkeypatch.setattr(
+        MistralAILlm,
+        "_get_answer",
+        lambda self, prompt, config: ("Generated Text", {"prompt_tokens": 1, "completion_tokens": 2}),
+    )
+
+    llm = MistralAILlm(test_config)
+    answer, token_info = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Generated Text"
+    assert token_info == {
+        "prompt_tokens": 1,
+        "completion_tokens": 2,
+        "total_tokens": 3,
+        "total_cost": 7.5e-07,
+        "cost_currency": "USD",
+    }

+ 29 - 0
tests/llm/test_openai.py

@@ -62,6 +62,35 @@ def test_get_llm_model_answer_empty_prompt(config, mocker):
     mocked_get_answer.assert_called_once_with("", config)
 
 
+def test_get_llm_model_answer_with_token_usage(config, mocker):
+    test_config = BaseLlmConfig(
+        temperature=config.temperature,
+        max_tokens=config.max_tokens,
+        top_p=config.top_p,
+        stream=config.stream,
+        system_prompt=config.system_prompt,
+        model=config.model,
+        token_usage=True,
+    )
+    mocked_get_answer = mocker.patch(
+        "embedchain.llm.openai.OpenAILlm._get_answer",
+        return_value=("Test answer", {"prompt_tokens": 1, "completion_tokens": 2}),
+    )
+
+    llm = OpenAILlm(test_config)
+    answer, token_info = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+    assert token_info == {
+        "prompt_tokens": 1,
+        "completion_tokens": 2,
+        "total_tokens": 3,
+        "total_cost": 5.5e-06,
+        "cost_currency": "USD",
+    }
+    mocked_get_answer.assert_called_once_with("Test query", test_config)
+
+
 def test_get_llm_model_answer_with_streaming(config, mocker):
     config.stream = True
     mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")

+ 29 - 3
tests/llm/test_together.py

@@ -9,7 +9,7 @@ from embedchain.llm.together import TogetherLlm
 @pytest.fixture
 def together_llm_config():
     os.environ["TOGETHER_API_KEY"] = "test_api_key"
-    config = BaseLlmConfig(model="togethercomputer/RedPajama-INCITE-7B-Base", max_tokens=50, temperature=0.7, top_p=0.8)
+    config = BaseLlmConfig(model="together-ai-up-to-3b", max_tokens=50, temperature=0.7, top_p=0.8)
     yield config
     os.environ.pop("TOGETHER_API_KEY")
 
@@ -36,10 +36,36 @@ def test_get_llm_model_answer(together_llm_config, mocker):
     assert answer == "Test answer"
 
 
+def test_get_llm_model_answer_with_token_usage(together_llm_config, mocker):
+    test_config = BaseLlmConfig(
+        temperature=together_llm_config.temperature,
+        max_tokens=together_llm_config.max_tokens,
+        top_p=together_llm_config.top_p,
+        model=together_llm_config.model,
+        token_usage=True,
+    )
+    mocker.patch(
+        "embedchain.llm.together.TogetherLlm._get_answer",
+        return_value=("Test answer", {"prompt_tokens": 1, "completion_tokens": 2}),
+    )
+
+    llm = TogetherLlm(test_config)
+    answer, token_info = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+    assert token_info == {
+        "prompt_tokens": 1,
+        "completion_tokens": 2,
+        "total_tokens": 3,
+        "total_cost": 3e-07,
+        "cost_currency": "USD",
+    }
+
+
 def test_get_answer_mocked_together(together_llm_config, mocker):
-    mocked_together = mocker.patch("embedchain.llm.together.Together")
+    mocked_together = mocker.patch("embedchain.llm.together.ChatTogether")
     mock_instance = mocked_together.return_value
-    mock_instance.invoke.return_value = "Mocked answer"
+    mock_instance.invoke.return_value.content = "Mocked answer"
 
     llm = TogetherLlm(together_llm_config)
     prompt = "Test query"

+ 26 - 1
tests/llm/test_vertex_ai.py

@@ -24,7 +24,32 @@ def test_get_llm_model_answer(vertexai_llm):
         prompt = "Test Prompt"
         response = vertexai_llm.get_llm_model_answer(prompt)
         assert response == "Test Response"
-        mock_method.assert_called_once_with(prompt=prompt, config=vertexai_llm.config)
+        mock_method.assert_called_once_with(prompt, vertexai_llm.config)
+
+
+def test_get_llm_model_answer_with_token_usage(vertexai_llm):
+    test_config = BaseLlmConfig(
+        temperature=vertexai_llm.config.temperature,
+        max_tokens=vertexai_llm.config.max_tokens,
+        top_p=vertexai_llm.config.top_p,
+        model=vertexai_llm.config.model,
+        token_usage=True,
+    )
+    vertexai_llm.config = test_config
+    with patch.object(
+        VertexAILlm,
+        "_get_answer",
+        return_value=("Test Response", {"prompt_token_count": 1, "candidates_token_count": 2}),
+    ):
+        response, token_info = vertexai_llm.get_llm_model_answer("Test Query")
+        assert response == "Test Response"
+        assert token_info == {
+            "prompt_tokens": 1,
+            "completion_tokens": 2,
+            "total_tokens": 3,
+            "total_cost": 3.75e-07,
+            "cost_currency": "USD",
+        }
 
 
 @patch("embedchain.llm.vertex_ai.ChatVertexAI")

이 변경점에서 너무 많은 파일들이 변경되어 몇몇 파일들은 표시되지 않았습니다.