
Show details for query tokens (#1392)

Dev Khant 1 year ago
parent
commit
4880557d51

+ 1 - 1
Makefile

@@ -11,7 +11,7 @@ install:
 
 install_all:
 	poetry install --all-extras
-	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama deepgram-sdk==3.2.7 langchain-huggingface psutil
+	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama langchain_together==0.1.3 langchain_cohere==0.1.5 deepgram-sdk==3.2.7 langchain-huggingface psutil
 
 install_es:
 	poetry install --extras elasticsearch

+ 1 - 0
docs/api-reference/advanced/configuration.mdx

@@ -209,6 +209,7 @@ Alright, let's dive into what each key means in the yaml config above:
        - `top_p` (Float): Controls the diversity of word selection. A higher value (closer to 1) makes word selection more diverse.
        - `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
        - `online` (Boolean): Controls whether to use internet to get more context for answering query (set to false).
+        - `token_usage` (Boolean): Controls whether to return token usage details for the querying model (set to false).
        - `prompt` (String): A prompt for the model to follow when generating responses, requires `$context` and `$query` variables.
        - `system_prompt` (String): A system prompt for the model to follow when generating responses, in this case, it's set to the style of William Shakespeare.
        - `number_documents` (Integer): Number of documents to pull from the vectordb as context, defaults to 1

+ 46 - 0
docs/components/llms.mdx

@@ -840,6 +840,52 @@ answer = app.query("What is the net worth of Elon Musk today?")
 ```
 </CodeGroup>
 
+## Token Usage
+
+You can get the cost of the query by setting `token_usage` to `True` in the config file. This will return the token details: `input_tokens`, `output_tokens`, `total_cost`.
+The paid LLMs that support token usage are:
+- OpenAI
+- Vertex AI
+- Anthropic
+- Cohere
+- Together
+- Groq
+- Mistral AI
+- NVIDIA AI
+
+Here is an example of how to use token usage:
+<CodeGroup>
+ 
+```python main.py
+os.environ["OPENAI_API_KEY"] = "xxx"
+
+app = App.from_config(config_path="config.yaml")
+
+app.add("https://www.forbes.com/profile/elon-musk")
+
+response, token_usage = app.query("what is the net worth of Elon Musk?")
+# Elon Musk's net worth is $209.9 billion as of 6/9/24.
+# {'input_tokens': 1228, 'output_tokens': 21, 'total_cost (USD)': 0.001884}
+
+
+response, token_usage = app.chat("Which companies did Elon Musk found?")
+# Elon Musk founded six companies, including Tesla, which is an electric car maker, SpaceX, a rocket producer, and the Boring Company, a tunneling startup.
+# {'input_tokens': 1616, 'output_tokens': 34, 'total_cost (USD)': 0.002492}
+```
+  
+```yaml config.yaml
+llm:
+  provider: openai
+  config:
+    model: gpt-3.5-turbo
+    temperature: 0.5
+    max_tokens: 1000
+    token_usage: true
+```
+</CodeGroup>
+
+If a model is missing and you'd like to add it to `model_prices_and_context_window.json`, please feel free to open a PR.
+
 <br/ >
 
 <Snippet file="missing-llm-tip.mdx" />

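The Python tab of the docs example above assumes `import os` and `from embedchain import App` are already in scope, and the `embedchain.py` changes later in this commit return a dict (`{"answer": ..., "usage": ...}`) rather than a tuple when `token_usage` is enabled without citations. A self-contained sketch of the same flow that handles either shape (the `config.yaml` is the one shown above; the key value is a placeholder):

```python
# Sketch only: a self-contained variant of the docs example above.
import os

from embedchain import App

os.environ["OPENAI_API_KEY"] = "xxx"  # placeholder key, as in the docs example

app = App.from_config(config_path="config.yaml")  # config.yaml with token_usage: true, shown above
app.add("https://www.forbes.com/profile/elon-musk")

result = app.query("What is the net worth of Elon Musk?")
# Per the embedchain.py changes in this commit, with token_usage enabled and
# citations disabled, the result is a dict: {"answer": ..., "usage": {...}}.
if isinstance(result, dict):
    answer, usage = result["answer"], result["usage"]
else:  # tuple/str return shapes used when token_usage is off
    answer, usage = result, None

print(answer)
print(usage)  # prompt_tokens, completion_tokens, total_tokens, total_cost, cost_currency
```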
+ 10 - 0
embedchain/config/llm/base.py

@@ -1,3 +1,4 @@
+import json
 import logging
 import re
 from string import Template
@@ -92,6 +93,7 @@ class BaseLlmConfig(BaseConfig):
         top_p: float = 1,
         stream: bool = False,
         online: bool = False,
+        token_usage: bool = False,
         deployment_name: Optional[str] = None,
         system_prompt: Optional[str] = None,
         where: dict[str, Any] = None,
@@ -135,6 +137,8 @@ class BaseLlmConfig(BaseConfig):
         :type stream: bool, optional
         :param online: Controls whether to use internet for answering query, defaults to False
         :type online: bool, optional
+        :param token_usage: Controls whether to return token usage in response, defaults to False
+        :type token_usage: bool, optional
         :param deployment_name: t.b.a., defaults to None
         :type deployment_name: Optional[str], optional
         :param system_prompt: System prompt string, defaults to None
@@ -180,6 +184,8 @@ class BaseLlmConfig(BaseConfig):
         self.max_tokens = max_tokens
         self.model = model
         self.top_p = top_p
+        self.online = online
+        self.token_usage = token_usage
         self.deployment_name = deployment_name
         self.system_prompt = system_prompt
         self.query_type = query_type
@@ -197,6 +203,10 @@ class BaseLlmConfig(BaseConfig):
         self.online = online
         self.api_version = api_version
 
+        if token_usage:
+            f = open("model_prices_and_context_window.json")
+            self.model_pricing_map = json.load(f)
+
         if isinstance(prompt, str):
             prompt = Template(prompt)
 

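For reference, the cost arithmetic that the provider classes later in this commit run against this pricing map is a straight per-token multiply-and-add. A minimal standalone sketch, assuming a pricing file shaped like the `model_prices_and_context_window.json` added at the end of this commit (keys of the form `provider/model` with `input_cost_per_token` and `output_cost_per_token`); the helper name is illustrative and not part of the commit:

```python
import json


def estimate_cost(model_key: str, input_tokens: int, output_tokens: int,
                  pricing_path: str = "model_prices_and_context_window.json") -> float:
    """Sketch of the cost formula the provider classes in this commit apply."""
    with open(pricing_path) as f:
        pricing = json.load(f)
    if model_key not in pricing:
        # Mirrors the ValueError the provider classes raise for unknown models.
        raise ValueError(f"Model {model_key} not found in `{pricing_path}`.")
    entry = pricing[model_key]
    cost = (entry["input_cost_per_token"] * input_tokens
            + entry["output_cost_per_token"] * output_tokens)
    return round(cost, 10)


# e.g. estimate_cost("openai/gpt-3.5-turbo", 1228, 21) -> 0.001884, matching the docs example
```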
+ 44 - 19
embedchain/embedchain.py

@@ -6,9 +6,7 @@ from typing import Any, Optional, Union
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
 
-from embedchain.cache import (adapt, get_gptcache_session,
-                              gptcache_data_convert,
-                              gptcache_update_cache_callback)
+from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.base_app_config import BaseAppConfig
@@ -18,8 +16,7 @@ from embedchain.embedder.base import BaseEmbedder
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import (DataType, DirectDataType,
-                                         IndirectDataType, SpecialDataType)
+from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
 from embedchain.utils.misc import detect_datatype, is_valid_json_string
 from embedchain.vectordb.base import BaseVectorDB
 
@@ -478,7 +475,7 @@ class EmbedChain(JSONSerializable):
         where: Optional[dict] = None,
         citations: bool = False,
         **kwargs: dict[str, Any],
-    ) -> Union[tuple[str, list[tuple[str, dict]]], str]:
+    ) -> Union[tuple[str, list[tuple[str, dict]]], str, dict[str, Any]]:
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -501,7 +498,9 @@ class EmbedChain(JSONSerializable):
         :type kwargs: dict[str, Any]
         :return: The answer to the query, with citations if the citation flag is True
         or the dry run result
-        :rtype: str, if citations is False, otherwise tuple[str, list[tuple[str,str,str]]]
+        :rtype: str if both `citations` and `token_usage` are False; tuple[str, list[tuple[str,str,str]]] if
+        `citations` is True and `token_usage` is False; otherwise a dict containing the answer, the token usage
+        details and, when `citations` is True, the contexts
         """
         contexts = self._retrieve_from_database(
             input_query=input_query, config=config, where=where, citations=citations, **kwargs
@@ -524,17 +523,29 @@ class EmbedChain(JSONSerializable):
                 dry_run=dry_run,
             )
         else:
-            answer = self.llm.query(
-                input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-            )
+            if self.llm.config.token_usage:
+                answer, token_info = self.llm.query(
+                    input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+                )
+            else:
+                answer = self.llm.query(
+                    input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+                )
 
         # Send anonymous telemetry
         self.telemetry.capture(event_name="query", properties=self._telemetry_props)
 
         if citations:
+            if self.llm.config.token_usage:
+                return {"answer": answer, "contexts": contexts, "usage": token_info}
             return answer, contexts
-        else:
-            return answer
+        if self.llm.config.token_usage:
+            return {"answer": answer, "usage": token_info}
+
+        logger.warning(
+            "Starting from v0.1.125 the return type of query method will be changed to tuple containing `answer`."
+        )
+        return answer
 
     def chat(
         self,
@@ -545,7 +556,7 @@ class EmbedChain(JSONSerializable):
         where: Optional[dict[str, str]] = None,
         citations: bool = False,
         **kwargs: dict[str, Any],
-    ) -> Union[tuple[str, list[tuple[str, dict]]], str]:
+    ) -> Union[tuple[str, list[tuple[str, dict]]], str, dict[str, Any]]:
         """
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -572,7 +583,9 @@ class EmbedChain(JSONSerializable):
         :type kwargs: dict[str, Any]
         :return: The answer to the query, with citations if the citation flag is True
         or the dry run result
-        :rtype: str, if citations is False, otherwise tuple[str, list[tuple[str,str,str]]]
+        :rtype: str if both `citations` and `token_usage` are False; tuple[str, list[tuple[str,str,str]]] if
+        `citations` is True and `token_usage` is False; otherwise a dict containing the answer, the token usage
+        details and, when `citations` is True, the contexts
         """
         contexts = self._retrieve_from_database(
             input_query=input_query, config=config, where=where, citations=citations, **kwargs
@@ -600,9 +613,14 @@ class EmbedChain(JSONSerializable):
             )
         else:
             logger.debug("Cache disabled. Running chat without cache.")
-            answer = self.llm.chat(
-                input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-            )
+            if self.llm.config.token_usage:
+                answer, token_info = self.llm.query(
+                    input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+                )
+            else:
+                answer = self.llm.query(
+                    input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+                )
 
         # add conversation in memory
         self.llm.add_history(self.config.id, input_query, answer, session_id=session_id)
@@ -611,9 +629,16 @@ class EmbedChain(JSONSerializable):
         self.telemetry.capture(event_name="chat", properties=self._telemetry_props)
 
         if citations:
+            if self.llm.config.token_usage:
+                return {"answer": answer, "contexts": contexts, "usage": token_info}
             return answer, contexts
-        else:
-            return answer
+        if self.llm.config.token_usage:
+            return {"answer": answer, "usage": token_info}
+
+        logger.warning(
+            "Starting from v0.1.125 the return type of query method will be changed to tuple containing `answer`."
+        )
+        return answer
 
     def search(self, query, num_documents=3, where=None, raw_filter=None, namespace=None):
         """

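Taken together, the `query`/`chat` branches above produce four possible return shapes depending on the `citations` and `token_usage` flags. A hypothetical helper, not part of the commit, that callers could use to normalize them:

```python
from typing import Any, Optional, Union


def unpack_response(result: Union[str, tuple, dict]) -> tuple[str, Optional[list], Optional[dict[str, Any]]]:
    """Normalize the return shapes of query()/chat() shown above into (answer, contexts, usage).

    citations=True,  token_usage=True  -> {"answer": ..., "contexts": ..., "usage": ...}
    citations=True,  token_usage=False -> (answer, contexts)
    citations=False, token_usage=True  -> {"answer": ..., "usage": ...}
    citations=False, token_usage=False -> answer (plain string)
    """
    if isinstance(result, dict):
        return result["answer"], result.get("contexts"), result.get("usage")
    if isinstance(result, tuple):
        answer, contexts = result
        return answer, contexts, None
    return result, None, None
```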
+ 3 - 3
embedchain/embedder/gpt4all.py

@@ -9,10 +9,10 @@ class GPT4AllEmbedder(BaseEmbedder):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         super().__init__(config=config)
 
-        from langchain.embeddings import \
-            GPT4AllEmbeddings as LangchainGPT4AllEmbeddings
+        from langchain_community.embeddings import GPT4AllEmbeddings as LangchainGPT4AllEmbeddings
 
-        embeddings = LangchainGPT4AllEmbeddings()
+        model_name = self.config.model or "all-MiniLM-L6-v2-f16.gguf"
+        embeddings = LangchainGPT4AllEmbeddings(model_name=model_name)
         embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
         self.set_embedding_fn(embedding_fn=embedding_fn)
 

+ 26 - 4
embedchain/llm/anthropic.py

@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import Optional
+from typing import Any, Optional
 
 try:
     from langchain_anthropic import ChatAnthropic
@@ -21,8 +21,27 @@ class AnthropicLlm(BaseLlm):
         if not self.config.api_key and "ANTHROPIC_API_KEY" not in os.environ:
             raise ValueError("Please set the ANTHROPIC_API_KEY environment variable or pass it in the config.")
 
-    def get_llm_model_answer(self, prompt):
-        return AnthropicLlm._get_answer(prompt=prompt, config=self.config)
+    def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
+        if self.config.token_usage:
+            response, token_info = self._get_answer(prompt, self.config)
+            model_name = "anthropic/" + self.config.model
+            if model_name not in self.config.model_pricing_map:
+                raise ValueError(
+                    f"Model {model_name} not found in `model_prices_and_context_window.json`. \
+                    You can disable token usage by setting `token_usage` to False."
+                )
+            total_cost = (
+                self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["input_tokens"]
+            ) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["output_tokens"]
+            response_token_info = {
+                "prompt_tokens": token_info["input_tokens"],
+                "completion_tokens": token_info["output_tokens"],
+                "total_tokens": token_info["input_tokens"] + token_info["output_tokens"],
+                "total_cost": round(total_cost, 10),
+                "cost_currency": "USD",
+            }
+            return response, response_token_info
+        return self._get_answer(prompt, self.config)
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
@@ -34,4 +53,7 @@ class AnthropicLlm(BaseLlm):
 
         messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
 
-        return chat(messages).content
+        chat_response = chat.invoke(messages)
+        if config.token_usage:
+            return chat_response.content, chat_response.response_metadata["token_usage"]
+        return chat_response.content

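Every provider class in this commit (Anthropic above; Cohere, Groq, Mistral AI, NVIDIA, OpenAI, Together and Vertex AI below) maps whatever its LangChain wrapper reports into the same normalized usage dict. Its shape, with illustrative numbers borrowed from the docs example:

```python
# Illustrative values only; the keys are the ones each provider class builds in this commit.
response_token_info = {
    "prompt_tokens": 1228,
    "completion_tokens": 21,
    "total_tokens": 1249,     # prompt_tokens + completion_tokens
    "total_cost": 0.001884,   # round(input_cost_per_token * prompt + output_cost_per_token * completion, 10)
    "cost_currency": "USD",
}
```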
+ 14 - 5
embedchain/llm/base.py

@@ -164,7 +164,7 @@ class BaseLlm(JSONSerializable):
         return search.run(input_query)
 
     @staticmethod
-    def _stream_response(answer: Any) -> Generator[Any, Any, None]:
+    def _stream_response(answer: Any, token_info: Optional[dict[str, Any]] = None) -> Generator[Any, Any, None]:
         """Generator to be used as streaming response
 
         :param answer: Answer chunk from llm
@@ -177,6 +177,8 @@ class BaseLlm(JSONSerializable):
             streamed_answer = streamed_answer + chunk
             yield chunk
         logger.info(f"Answer: {streamed_answer}")
+        if token_info:
+            logger.info(f"Token Info: {token_info}")
 
     def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False):
         """
@@ -219,11 +221,18 @@ class BaseLlm(JSONSerializable):
             if dry_run:
                 return prompt
 
-            answer = self.get_answer_from_llm(prompt)
+            if self.config.token_usage:
+                answer, token_info = self.get_answer_from_llm(prompt)
+            else:
+                answer = self.get_answer_from_llm(prompt)
             if isinstance(answer, str):
                 logger.info(f"Answer: {answer}")
+                if self.config.token_usage:
+                    return answer, token_info
                 return answer
             else:
+                if self.config.token_usage:
+                    return self._stream_response(answer, token_info)
                 return self._stream_response(answer)
         finally:
             if config:
@@ -276,13 +285,13 @@ class BaseLlm(JSONSerializable):
             if dry_run:
                 return prompt
 
-            answer = self.get_answer_from_llm(prompt)
+            answer, token_info = self.get_answer_from_llm(prompt)
             if isinstance(answer, str):
                 logger.info(f"Answer: {answer}")
-                return answer
+                return answer, token_info
             else:
                 # this is a streamed response and needs to be handled differently.
-                return self._stream_response(answer)
+                return self._stream_response(answer, token_info)
         finally:
             if config:
                 # Restore previous config

+ 38 - 15
embedchain/llm/cohere.py

@@ -1,8 +1,8 @@
 import importlib
 import os
-from typing import Optional
+from typing import Any, Optional
 
-from langchain_community.llms.cohere import Cohere
+from langchain_cohere import ChatCohere
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -17,27 +17,50 @@ class CohereLlm(BaseLlm):
         except ModuleNotFoundError:
             raise ModuleNotFoundError(
                 "The required dependencies for Cohere are not installed."
-                'Please install with `pip install --upgrade "embedchain[cohere]"`'
+                "Please install with `pip install langchain_cohere==1.16.0`"
             ) from None
 
         super().__init__(config=config)
         if not self.config.api_key and "COHERE_API_KEY" not in os.environ:
             raise ValueError("Please set the COHERE_API_KEY environment variable or pass it in the config.")
 
-    def get_llm_model_answer(self, prompt):
+    def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
         if self.config.system_prompt:
             raise ValueError("CohereLlm does not support `system_prompt`")
-        return CohereLlm._get_answer(prompt=prompt, config=self.config)
+
+        if self.config.token_usage:
+            response, token_info = self._get_answer(prompt, self.config)
+            model_name = "cohere/" + self.config.model
+            if model_name not in self.config.model_pricing_map:
+                raise ValueError(
+                    f"Model {model_name} not found in `model_prices_and_context_window.json`. \
+                    You can disable token usage by setting `token_usage` to False."
+                )
+            total_cost = (
+                self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["input_tokens"]
+            ) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["output_tokens"]
+            response_token_info = {
+                "prompt_tokens": token_info["input_tokens"],
+                "completion_tokens": token_info["output_tokens"],
+                "total_tokens": token_info["input_tokens"] + token_info["output_tokens"],
+                "total_cost": round(total_cost, 10),
+                "cost_currency": "USD",
+            }
+            return response, response_token_info
+        return self._get_answer(prompt, self.config)
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
-        api_key = config.api_key or os.getenv("COHERE_API_KEY")
-        llm = Cohere(
-            cohere_api_key=api_key,
-            model=config.model,
-            max_tokens=config.max_tokens,
-            temperature=config.temperature,
-            p=config.top_p,
-        )
-
-        return llm.invoke(prompt)
+        api_key = config.api_key or os.environ["COHERE_API_KEY"]
+        kwargs = {
+            "model_name": config.model or "command-r",
+            "temperature": config.temperature,
+            "max_tokens": config.max_tokens,
+            "together_api_key": api_key,
+        }
+
+        chat = ChatCohere(**kwargs)
+        chat_response = chat.invoke(prompt)
+        if config.token_usage:
+            return chat_response.content, chat_response.response_metadata["token_count"]
+        return chat_response.content

+ 27 - 5
embedchain/llm/groq.py

@@ -1,5 +1,5 @@
 import os
-from typing import Optional
+from typing import Any, Optional
 
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.schema import HumanMessage, SystemMessage
@@ -22,9 +22,27 @@ class GroqLlm(BaseLlm):
         if not self.config.api_key and "GROQ_API_KEY" not in os.environ:
             raise ValueError("Please set the GROQ_API_KEY environment variable or pass it in the config.")
 
-    def get_llm_model_answer(self, prompt) -> str:
-        response = self._get_answer(prompt, self.config)
-        return response
+    def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
+        if self.config.token_usage:
+            response, token_info = self._get_answer(prompt, self.config)
+            model_name = "groq/" + self.config.model
+            if model_name not in self.config.model_pricing_map:
+                raise ValueError(
+                    f"Model {model_name} not found in `model_prices_and_context_window.json`. \
+                    You can disable token usage by setting `token_usage` to False."
+                )
+            total_cost = (
+                self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["prompt_tokens"]
+            ) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["completion_tokens"]
+            response_token_info = {
+                "prompt_tokens": token_info["prompt_tokens"],
+                "completion_tokens": token_info["completion_tokens"],
+                "total_tokens": token_info["prompt_tokens"] + token_info["completion_tokens"],
+                "total_cost": round(total_cost, 10),
+                "cost_currency": "USD",
+            }
+            return response, response_token_info
+        return self._get_answer(prompt, self.config)
 
     def _get_answer(self, prompt: str, config: BaseLlmConfig) -> str:
         messages = []
@@ -42,4 +60,8 @@ class GroqLlm(BaseLlm):
             chat = ChatGroq(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
         else:
             chat = ChatGroq(**kwargs)
-        return chat.invoke(messages).content
+
+        chat_response = chat.invoke(prompt)
+        if self.config.token_usage:
+            return chat_response.content, chat_response.response_metadata["token_usage"]
+        return chat_response.content

+ 26 - 6
embedchain/llm/mistralai.py

@@ -1,5 +1,5 @@
 import os
-from typing import Optional
+from typing import Any, Optional
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -13,8 +13,27 @@ class MistralAILlm(BaseLlm):
         if not self.config.api_key and "MISTRAL_API_KEY" not in os.environ:
             raise ValueError("Please set the MISTRAL_API_KEY environment variable or pass it in the config.")
 
-    def get_llm_model_answer(self, prompt):
-        return MistralAILlm._get_answer(prompt=prompt, config=self.config)
+    def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
+        if self.config.token_usage:
+            response, token_info = self._get_answer(prompt, self.config)
+            model_name = "mistralai/" + self.config.model
+            if model_name not in self.config.model_pricing_map:
+                raise ValueError(
+                    f"Model {model_name} not found in `model_prices_and_context_window.json`. \
+                    You can disable token usage by setting `token_usage` to False."
+                )
+            total_cost = (
+                self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["prompt_tokens"]
+            ) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["completion_tokens"]
+            response_token_info = {
+                "prompt_tokens": token_info["prompt_tokens"],
+                "completion_tokens": token_info["completion_tokens"],
+                "total_tokens": token_info["prompt_tokens"] + token_info["completion_tokens"],
+                "total_cost": round(total_cost, 10),
+                "cost_currency": "USD",
+            }
+            return response, response_token_info
+        return self._get_answer(prompt, self.config)
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig):
@@ -47,6 +66,7 @@ class MistralAILlm(BaseLlm):
                 answer += chunk.content
             return answer
         else:
-            response = client.invoke(**kwargs, input=messages)
-            answer = response.content
-            return answer
+            chat_response = client.invoke(**kwargs, input=messages)
+            if config.token_usage:
+                return chat_response.content, chat_response.response_metadata["token_usage"]
+            return chat_response.content

+ 26 - 4
embedchain/llm/nvidia.py

@@ -1,6 +1,6 @@
 import os
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.stdout import StdOutCallbackHandler
@@ -25,8 +25,27 @@ class NvidiaLlm(BaseLlm):
         if not self.config.api_key and "NVIDIA_API_KEY" not in os.environ:
             raise ValueError("Please set the NVIDIA_API_KEY environment variable or pass it in the config.")
 
-    def get_llm_model_answer(self, prompt):
-        return self._get_answer(prompt=prompt, config=self.config)
+    def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
+        if self.config.token_usage:
+            response, token_info = self._get_answer(prompt, self.config)
+            model_name = "nvidia/" + self.config.model
+            if model_name not in self.config.model_pricing_map:
+                raise ValueError(
+                    f"Model {model_name} not found in `model_prices_and_context_window.json`. \
+                    You can disable token usage by setting `token_usage` to False."
+                )
+            total_cost = (
+                self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["input_tokens"]
+            ) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["output_tokens"]
+            response_token_info = {
+                "prompt_tokens": token_info["input_tokens"],
+                "completion_tokens": token_info["output_tokens"],
+                "total_tokens": token_info["input_tokens"] + token_info["output_tokens"],
+                "total_cost": round(total_cost, 10),
+                "cost_currency": "USD",
+            }
+            return response, response_token_info
+        return self._get_answer(prompt, self.config)
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
@@ -43,4 +62,7 @@ class NvidiaLlm(BaseLlm):
         if labels:
             params["labels"] = labels
         llm = ChatNVIDIA(**params, callback_manager=CallbackManager(callback_manager))
-        return llm.invoke(prompt).content if labels is None else llm.invoke(prompt, labels=labels).content
+        chat_response = llm.invoke(prompt) if labels is None else llm.invoke(prompt, labels=labels)
+        if config.token_usage:
+            return chat_response.content, chat_response.response_metadata["token_usage"]
+        return chat_response.content

+ 26 - 4
embedchain/llm/openai.py

@@ -23,9 +23,28 @@ class OpenAILlm(BaseLlm):
         self.tools = tools
         super().__init__(config=config)
 
-    def get_llm_model_answer(self, prompt) -> str:
-        response = self._get_answer(prompt, self.config)
-        return response
+    def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
+        if self.config.token_usage:
+            response, token_info = self._get_answer(prompt, self.config)
+            model_name = "openai/" + self.config.model
+            if model_name not in self.config.model_pricing_map:
+                raise ValueError(
+                    f"Model {model_name} not found in `model_prices_and_context_window.json`. \
+                    You can disable token usage by setting `token_usage` to False."
+                )
+            total_cost = (
+                self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["prompt_tokens"]
+            ) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["completion_tokens"]
+            response_token_info = {
+                "prompt_tokens": token_info["prompt_tokens"],
+                "completion_tokens": token_info["completion_tokens"],
+                "total_tokens": token_info["prompt_tokens"] + token_info["completion_tokens"],
+                "total_cost": round(total_cost, 10),
+                "cost_currency": "USD",
+            }
+            return response, response_token_info
+
+        return self._get_answer(prompt, self.config)
 
     def _get_answer(self, prompt: str, config: BaseLlmConfig) -> str:
         messages = []
@@ -66,7 +85,10 @@ class OpenAILlm(BaseLlm):
         if self.tools:
             return self._query_function_call(chat, self.tools, messages)
 
-        return chat.invoke(messages).content
+        chat_response = chat.invoke(messages)
+        if self.config.token_usage:
+            return chat_response.content, chat_response.response_metadata["token_usage"]
+        return chat_response.content
 
     def _query_function_call(
         self,

+ 42 - 14
embedchain/llm/together.py

@@ -1,8 +1,13 @@
 import importlib
 import os
-from typing import Optional
+from typing import Any, Optional
 
-from langchain_community.llms import Together
+try:
+    from langchain_together import ChatTogether
+except ImportError:
+    raise ImportError(
+        "Please install the langchain_together package by running `pip install langchain_together==0.1.3`."
+    )
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -24,20 +29,43 @@ class TogetherLlm(BaseLlm):
         if not self.config.api_key and "TOGETHER_API_KEY" not in os.environ:
             raise ValueError("Please set the TOGETHER_API_KEY environment variable or pass it in the config.")
 
-    def get_llm_model_answer(self, prompt):
+    def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
         if self.config.system_prompt:
             raise ValueError("TogetherLlm does not support `system_prompt`")
-        return TogetherLlm._get_answer(prompt=prompt, config=self.config)
+
+        if self.config.token_usage:
+            response, token_info = self._get_answer(prompt, self.config)
+            model_name = "together/" + self.config.model
+            if model_name not in self.config.model_pricing_map:
+                raise ValueError(
+                    f"Model {model_name} not found in `model_prices_and_context_window.json`. \
+                    You can disable token usage by setting `token_usage` to False."
+                )
+            total_cost = (
+                self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["prompt_tokens"]
+            ) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["completion_tokens"]
+            response_token_info = {
+                "prompt_tokens": token_info["prompt_tokens"],
+                "completion_tokens": token_info["completion_tokens"],
+                "total_tokens": token_info["prompt_tokens"] + token_info["completion_tokens"],
+                "total_cost": round(total_cost, 10),
+                "cost_currency": "USD",
+            }
+            return response, response_token_info
+        return self._get_answer(prompt, self.config)
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
-        api_key = config.api_key or os.getenv("TOGETHER_API_KEY")
-        llm = Together(
-            together_api_key=api_key,
-            model=config.model,
-            max_tokens=config.max_tokens,
-            temperature=config.temperature,
-            top_p=config.top_p,
-        )
-
-        return llm.invoke(prompt)
+        api_key = config.api_key or os.environ["TOGETHER_API_KEY"]
+        kwargs = {
+            "model_name": config.model or "mixtral-8x7b-32768",
+            "temperature": config.temperature,
+            "max_tokens": config.max_tokens,
+            "together_api_key": api_key,
+        }
+
+        chat = ChatTogether(**kwargs)
+        chat_response = chat.invoke(prompt)
+        if config.token_usage:
+            return chat_response.content, chat_response.response_metadata["token_usage"]
+        return chat_response.content

+ 29 - 6
embedchain/llm/vertex_ai.py

@@ -1,6 +1,6 @@
 import importlib
 import logging
-from typing import Optional
+from typing import Any, Optional
 
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain_google_vertexai import ChatVertexAI
@@ -24,16 +24,35 @@ class VertexAILlm(BaseLlm):
             ) from None
         super().__init__(config=config)
 
-    def get_llm_model_answer(self, prompt):
-        return VertexAILlm._get_answer(prompt=prompt, config=self.config)
+    def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
+        if self.config.token_usage:
+            response, token_info = self._get_answer(prompt, self.config)
+            model_name = "vertexai/" + self.config.model
+            if model_name not in self.config.model_pricing_map:
+                raise ValueError(
+                    f"Model {model_name} not found in `model_prices_and_context_window.json`. \
+                    You can disable token usage by setting `token_usage` to False."
+                )
+            total_cost = (
+                self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["prompt_token_count"]
+            ) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info[
+                "candidates_token_count"
+            ]
+            response_token_info = {
+                "prompt_tokens": token_info["prompt_token_count"],
+                "completion_tokens": token_info["candidates_token_count"],
+                "total_tokens": token_info["prompt_token_count"] + token_info["candidates_token_count"],
+                "total_cost": round(total_cost, 10),
+                "cost_currency": "USD",
+            }
+            return response, response_token_info
+        return self._get_answer(prompt, self.config)
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
         if config.top_p and config.top_p != 1:
             logger.warning("Config option `top_p` is not supported by this model.")
 
-        messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
-
         if config.stream:
         if config.stream:
             callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
             llm = ChatVertexAI(
         else:
         else:
             llm = ChatVertexAI(temperature=config.temperature, model=config.model)
 
+        messages = VertexAILlm._get_messages(prompt)
+        chat_response = llm.invoke(messages)
+        if config.token_usage:
+            return chat_response.content, chat_response.response_metadata["usage_metadata"]
+        return chat_response.content

+ 1 - 0
embedchain/utils/misc.py

@@ -428,6 +428,7 @@ def validate_config(config_data):
                     Optional("top_p"): Or(float, int),
                     Optional("stream"): bool,
                     Optional("online"): bool,
+                    Optional("token_usage"): bool,
                     Optional("template"): str,
                     Optional("prompt"): str,
                     Optional("system_prompt"): str,

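With the schema entry above, a config dict that sets `token_usage` should pass validation. A minimal sketch, assuming the other top-level sections of the schema remain optional; the dict mirrors the `config.yaml` from the docs change earlier in this commit:

```python
# Minimal sketch: exercising the new Optional("token_usage") schema key above.
from embedchain.utils.misc import validate_config

config = {
    "llm": {
        "provider": "openai",
        "config": {
            "model": "gpt-3.5-turbo",
            "temperature": 0.5,
            "max_tokens": 1000,
            "token_usage": True,
        },
    },
}

validate_config(config)  # raises if the config does not match the schema
```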
+ 803 - 0
model_prices_and_context_window.json

@@ -0,0 +1,803 @@
+{
+    "openai/gpt-4": {
+        "max_tokens": 4096, 
+        "max_input_tokens": 8192,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00003,
+        "output_cost_per_token": 0.00006
+    },
+    "openai/gpt-4o": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000005,
+        "output_cost_per_token": 0.000015
+    },
+    "openai/gpt-4o-2024-05-13": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000005,
+        "output_cost_per_token": 0.000015
+    },
+    "openai/gpt-4-turbo-preview": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003
+    },
+    "openai/gpt-4-0314": {
+        "max_tokens": 4096,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00003,
+        "output_cost_per_token": 0.00006
+    },
+    "openai/gpt-4-0613": {
+        "max_tokens": 4096,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00003,
+        "output_cost_per_token": 0.00006
+    },
+    "openai/gpt-4-32k": {
+        "max_tokens": 4096,
+        "max_input_tokens": 32768,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00006,
+        "output_cost_per_token": 0.00012
+    },
+    "openai/gpt-4-32k-0314": {
+        "max_tokens": 4096,
+        "max_input_tokens": 32768,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00006,
+        "output_cost_per_token": 0.00012
+    },
+    "openai/gpt-4-32k-0613": {
+        "max_tokens": 4096,
+        "max_input_tokens": 32768,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00006,
+        "output_cost_per_token": 0.00012
+    },
+    "openai/gpt-4-turbo": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003
+    },
+    "openai/gpt-4-turbo-2024-04-09": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003
+    },
+    "openai/gpt-4-1106-preview": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003
+    },
+    "openai/gpt-4-0125-preview": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003
+    },
+    "openai/gpt-3.5-turbo": {
+        "max_tokens": 4097,
+        "max_input_tokens": 16385,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000015,
+        "output_cost_per_token": 0.000002
+    },
+    "openai/gpt-3.5-turbo-0301": {
+        "max_tokens": 4097,
+        "max_input_tokens": 4097,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000015,
+        "output_cost_per_token": 0.000002
+    },
+    "openai/gpt-3.5-turbo-0613": {
+        "input_cost_per_token": 0.0000015,
+        "output_cost_per_token": 0.000002
+    },
+    "openai/gpt-3.5-turbo-1106": {
+        "max_tokens": 16385,
+        "max_input_tokens": 16385,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000010,
+        "output_cost_per_token": 0.0000020
+    },
+    "openai/gpt-3.5-turbo-0125": {
+        "max_tokens": 16385,
+        "max_input_tokens": 16385,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000005,
+        "output_cost_per_token": 0.0000015
+    },
+    "openai/gpt-3.5-turbo-16k": {
+        "max_tokens": 16385,
+        "max_input_tokens": 16385,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000003,
+        "output_cost_per_token": 0.000004
+    },
+    "openai/gpt-3.5-turbo-16k-0613": {
+        "max_tokens": 16385,
+        "max_input_tokens": 16385,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000003,
+        "output_cost_per_token": 0.000004
+    },
+    "openai/text-embedding-3-large": {
+        "max_tokens": 8191,
+        "max_input_tokens": 8191,
+        "output_vector_size": 3072,
+        "input_cost_per_token": 0.00000013,
+        "output_cost_per_token": 0.000000
+    },
+    "openai/text-embedding-3-small": {
+        "max_tokens": 8191,
+        "max_input_tokens": 8191,
+        "output_vector_size": 1536,
+        "input_cost_per_token": 0.00000002,
+        "output_cost_per_token": 0.000000
+    },
+    "openai/text-embedding-ada-002": {
+        "max_tokens": 8191,
+        "max_input_tokens": 8191,
+        "output_vector_size": 1536, 
+        "input_cost_per_token": 0.0000001,
+        "output_cost_per_token": 0.000000
+    },
+    "openai/text-embedding-ada-002-v2": {
+        "max_tokens": 8191,
+        "max_input_tokens": 8191,
+        "input_cost_per_token": 0.0000001,
+        "output_cost_per_token": 0.000000
+    },
+    "openai/babbage-002": {
+        "max_tokens": 16384,
+        "max_input_tokens": 16384,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000004,
+        "output_cost_per_token": 0.0000004
+    },
+    "openai/davinci-002": {
+        "max_tokens": 16384,
+        "max_input_tokens": 16384,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000002,
+        "output_cost_per_token": 0.000002
+    },    
+    "openai/gpt-3.5-turbo-instruct": {
+        "max_tokens": 4096,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000015,
+        "output_cost_per_token": 0.000002
+    },
+    "openai/gpt-3.5-turbo-instruct-0914": {
+        "max_tokens": 4097,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 4097,
+        "input_cost_per_token": 0.0000015,
+        "output_cost_per_token": 0.000002
+    },
+    "azure/gpt-4o": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000005,
+        "output_cost_per_token": 0.000015
+    },
+    "azure/gpt-4-turbo-2024-04-09": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003
+    },
+    "azure/gpt-4-0125-preview": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003
+    },
+    "azure/gpt-4-1106-preview": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003
+    },
+    "azure/gpt-4-0613": {
+        "max_tokens": 4096,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00003,
+        "output_cost_per_token": 0.00006
+    },
+    "azure/gpt-4-32k-0613": {
+        "max_tokens": 4096,
+        "max_input_tokens": 32768,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00006,
+        "output_cost_per_token": 0.00012
+    },
+    "azure/gpt-4-32k": {
+        "max_tokens": 4096,
+        "max_input_tokens": 32768,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00006,
+        "output_cost_per_token": 0.00012
+    },
+    "azure/gpt-4": {
+        "max_tokens": 4096,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00003,
+        "output_cost_per_token": 0.00006
+    },
+    "azure/gpt-4-turbo": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003
+    },
+    "azure/gpt-4-turbo-vision-preview": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003
+    },
+    "azure/gpt-3.5-turbo-16k-0613": {
+        "max_tokens": 4096,
+        "max_input_tokens": 16385,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000003,
+        "output_cost_per_token": 0.000004
+    },
+    "azure/gpt-3.5-turbo-1106": {
+        "max_tokens": 4096,
+        "max_input_tokens": 16384,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000015,
+        "output_cost_per_token": 0.000002
+    },
+    "azure/gpt-3.5-turbo-0125": {
+        "max_tokens": 4096,
+        "max_input_tokens": 16384,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000005,
+        "output_cost_per_token": 0.0000015
+    },
+    "azure/gpt-3.5-turbo-16k": {
+        "max_tokens": 4096,
+        "max_input_tokens": 16385,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000003,
+        "output_cost_per_token": 0.000004
+    },
+    "azure/gpt-3.5-turbo": {
+        "max_tokens": 4096,
+        "max_input_tokens": 4097,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000005,
+        "output_cost_per_token": 0.0000015
+    },
+    "azure/gpt-3.5-turbo-instruct-0914": {
+        "max_tokens": 4097,
+        "max_input_tokens": 4097,
+        "input_cost_per_token": 0.0000015,
+        "output_cost_per_token": 0.000002
+    },
+    "azure/gpt-3.5-turbo-instruct": {
+        "max_tokens": 4097,
+        "max_input_tokens": 4097,
+        "input_cost_per_token": 0.0000015,
+        "output_cost_per_token": 0.000002
+    },
+    "azure/text-embedding-ada-002": {
+        "max_tokens": 8191,
+        "max_input_tokens": 8191,
+        "input_cost_per_token": 0.0000001,
+        "output_cost_per_token": 0.000000
+    },
+    "azure/text-embedding-3-large": {
+        "max_tokens": 8191,
+        "max_input_tokens": 8191,
+        "input_cost_per_token": 0.00000013,
+        "output_cost_per_token": 0.000000
+    },
+    "azure/text-embedding-3-small": {
+        "max_tokens": 8191,
+        "max_input_tokens": 8191,
+        "input_cost_per_token": 0.00000002,
+        "output_cost_per_token": 0.000000
+    }, 
+    "mistralai/mistral-tiny": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.00000025,
+        "output_cost_per_token": 0.00000025
+    },
+    "mistralai/mistral-small": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.000001,
+        "output_cost_per_token": 0.000003
+    },
+    "mistralai/mistral-small-latest": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.000001,
+        "output_cost_per_token": 0.000003
+    },
+    "mistralai/mistral-medium": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.0000027,
+        "output_cost_per_token": 0.0000081
+    },
+    "mistralai/mistral-medium-latest": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.0000027,
+        "output_cost_per_token": 0.0000081
+    },
+    "mistralai/mistral-medium-2312": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.0000027,
+        "output_cost_per_token": 0.0000081
+    },
+    "mistralai/mistral-large-latest": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.000004,
+        "output_cost_per_token": 0.000012
+    },
+    "mistralai/mistral-large-2402": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.000004,
+        "output_cost_per_token": 0.000012
+    },
+    "mistralai/open-mistral-7b": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.00000025,
+        "output_cost_per_token": 0.00000025
+    },
+    "mistralai/open-mixtral-8x7b": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.0000007,
+        "output_cost_per_token": 0.0000007
+    },
+    "mistralai/open-mixtral-8x22b": {
+        "max_tokens": 8191,
+        "max_input_tokens": 64000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.000002,
+        "output_cost_per_token": 0.000006
+    },
+    "mistralai/codestral-latest": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.000001,
+        "output_cost_per_token": 0.000003
+    },
+    "mistralai/codestral-2405": {
+        "max_tokens": 8191,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.000001,
+        "output_cost_per_token": 0.000003
+    },
+    "mistralai/mistral-embed": {
+        "max_tokens": 8192,
+        "max_input_tokens": 8192,
+        "input_cost_per_token": 0.0000001,
+        "output_cost_per_token": 0.0
+    },
+    "groq/llama2-70b-4096": {
+        "max_tokens": 4096,
+        "max_input_tokens": 4096,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00000070,
+        "output_cost_per_token": 0.00000080
+    },
+    "groq/llama3-8b-8192": {
+        "max_tokens": 8192,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000010,
+        "output_cost_per_token": 0.00000010
+    },
+    "groq/llama3-70b-8192": {
+        "max_tokens": 8192,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000064,
+        "output_cost_per_token": 0.00000080
+    },
+    "groq/mixtral-8x7b-32768": {
+        "max_tokens": 32768,
+        "max_input_tokens": 32768,
+        "max_output_tokens": 32768,
+        "input_cost_per_token": 0.00000027,
+        "output_cost_per_token": 0.00000027
+    },
+    "groq/gemma-7b-it": {
+        "max_tokens": 8192,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000010,
+        "output_cost_per_token": 0.00000010
+    },
+    "anthropic/claude-instant-1": {
+        "max_tokens": 8191,
+        "max_input_tokens": 100000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.00000163,
+        "output_cost_per_token": 0.00000551
+    },
+    "anthropic/claude-instant-1.2": {
+        "max_tokens": 8191,
+        "max_input_tokens": 100000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.000000163,
+        "output_cost_per_token": 0.000000551
+    },
+    "anthropic/claude-2": {
+        "max_tokens": 8191,
+        "max_input_tokens": 100000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.000008,
+        "output_cost_per_token": 0.000024
+    },
+    "anthropic/claude-2.1": {
+        "max_tokens": 8191,
+        "max_input_tokens": 200000,
+        "max_output_tokens": 8191,
+        "input_cost_per_token": 0.000008,
+        "output_cost_per_token": 0.000024
+    },
+    "anthropic/claude-3-haiku-20240307": {
+        "max_tokens": 4096,
+        "max_input_tokens": 200000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00000025,
+        "output_cost_per_token": 0.00000125
+    },
+    "anthropic/claude-3-opus-20240229": {
+        "max_tokens": 4096,
+        "max_input_tokens": 200000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000015,
+        "output_cost_per_token": 0.000075
+    },
+    "anthropic/claude-3-sonnet-20240229": {
+        "max_tokens": 4096,
+        "max_input_tokens": 200000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000003,
+        "output_cost_per_token": 0.000015
+    },
+    "vertexai/chat-bison": {
+        "max_tokens": 4096,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/chat-bison@001": {
+        "max_tokens": 4096,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/chat-bison@002": {
+        "max_tokens": 4096,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/chat-bison-32k": {
+        "max_tokens": 8192,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/code-bison": {
+        "max_tokens": 1024,
+        "max_input_tokens": 6144,
+        "max_output_tokens": 1024,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/code-bison@001": {
+        "max_tokens": 1024,
+        "max_input_tokens": 6144,
+        "max_output_tokens": 1024,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/code-gecko@001": {
+        "max_tokens": 64,
+        "max_input_tokens": 2048,
+        "max_output_tokens": 64,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/code-gecko@002": {
+        "max_tokens": 64,
+        "max_input_tokens": 2048,
+        "max_output_tokens": 64,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/code-gecko": {
+        "max_tokens": 64,
+        "max_input_tokens": 2048,
+        "max_output_tokens": 64,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/codechat-bison": {
+        "max_tokens": 1024,
+        "max_input_tokens": 6144,
+        "max_output_tokens": 1024,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/codechat-bison@001": {
+        "max_tokens": 1024,
+        "max_input_tokens": 6144,
+        "max_output_tokens": 1024,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/codechat-bison-32k": {
+        "max_tokens": 8192,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.000000125,
+        "output_cost_per_token": 0.000000125
+    },
+    "vertexai/gemini-pro": {
+        "max_tokens": 8192,
+        "max_input_tokens": 32760,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000025, 
+        "output_cost_per_token": 0.0000005
+    },
+    "vertexai/gemini-1.0-pro": { 
+        "max_tokens": 8192,
+        "max_input_tokens": 32760,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000025, 
+        "output_cost_per_token": 0.0000005
+    },
+    "vertexai/gemini-1.0-pro-001": { 
+        "max_tokens": 8192,
+        "max_input_tokens": 32760,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000025, 
+        "output_cost_per_token": 0.0000005
+    },
+    "vertexai/gemini-1.0-pro-002": { 
+        "max_tokens": 8192,
+        "max_input_tokens": 32760,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000025, 
+        "output_cost_per_token": 0.0000005
+    },
+    "vertexai/gemini-1.5-pro": { 
+        "max_tokens": 8192,
+        "max_input_tokens": 1000000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.000000625, 
+        "output_cost_per_token": 0.000001875
+    },
+    "vertexai/gemini-1.5-flash-001": {
+        "max_tokens": 8192,
+        "max_input_tokens": 1000000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0, 
+        "output_cost_per_token": 0
+    },
+    "vertexai/gemini-1.5-flash-preview-0514": {
+        "max_tokens": 8192,
+        "max_input_tokens": 1000000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0, 
+        "output_cost_per_token": 0
+    },
+    "vertexai/gemini-1.5-pro-001": { 
+        "max_tokens": 8192,
+        "max_input_tokens": 1000000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.000000625, 
+        "output_cost_per_token": 0.000001875
+    },
+    "vertexai/gemini-1.5-pro-preview-0514": { 
+        "max_tokens": 8192,
+        "max_input_tokens": 1000000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.000000625, 
+        "output_cost_per_token": 0.000001875
+    },
+    "vertexai/gemini-1.5-pro-preview-0215": { 
+        "max_tokens": 8192,
+        "max_input_tokens": 1000000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.000000625, 
+        "output_cost_per_token": 0.000001875
+    },
+    "vertexai/gemini-1.5-pro-preview-0409": {
+        "max_tokens": 8192,
+        "max_input_tokens": 1000000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.000000625, 
+        "output_cost_per_token": 0.000001875
+    },
+    "vertexai/gemini-experimental": {
+        "max_tokens": 8192,
+        "max_input_tokens": 1000000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0,
+        "output_cost_per_token": 0
+    },
+    "vertexai/gemini-pro-vision": {
+        "max_tokens": 2048,
+        "max_input_tokens": 16384,
+        "max_output_tokens": 2048,
+        "max_images_per_prompt": 16,
+        "max_videos_per_prompt": 1,
+        "max_video_length": 2,
+        "input_cost_per_token": 0.00000025, 
+        "output_cost_per_token": 0.0000005
+    },
+    "vertexai/gemini-1.0-pro-vision": {
+        "max_tokens": 2048,
+        "max_input_tokens": 16384,
+        "max_output_tokens": 2048,
+        "max_images_per_prompt": 16,
+        "max_videos_per_prompt": 1,
+        "max_video_length": 2,
+        "input_cost_per_token": 0.00000025, 
+        "output_cost_per_token": 0.0000005
+    },
+    "vertexai/gemini-1.0-pro-vision-001": {
+        "max_tokens": 2048,
+        "max_input_tokens": 16384,
+        "max_output_tokens": 2048,
+        "max_images_per_prompt": 16,
+        "max_videos_per_prompt": 1,
+        "max_video_length": 2,
+        "input_cost_per_token": 0.00000025, 
+        "output_cost_per_token": 0.0000005
+    },
+    "vertexai/claude-3-sonnet@20240229": {
+        "max_tokens": 4096,
+        "max_input_tokens": 200000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000003,
+        "output_cost_per_token": 0.000015
+    },
+    "vertexai/claude-3-haiku@20240307": {
+        "max_tokens": 4096, 
+        "max_input_tokens": 200000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00000025,
+        "output_cost_per_token": 0.00000125
+    },
+    "vertexai/claude-3-opus@20240229": {
+        "max_tokens": 4096,
+        "max_input_tokens": 200000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000015,
+        "output_cost_per_token": 0.000075
+    },
+    "cohere/command-r": {
+        "max_tokens": 4096, 
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00000050,
+        "output_cost_per_token": 0.0000015
+    },
+    "cohere/command-light": {
+        "max_tokens": 4096, 
+        "max_input_tokens": 4096,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000015,
+        "output_cost_per_token": 0.000015
+    },
+    "cohere/command-r-plus": {
+        "max_tokens": 4096, 
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000003,
+        "output_cost_per_token": 0.000015
+    },
+    "cohere/command-nightly": {
+        "max_tokens": 4096, 
+        "max_input_tokens": 4096,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000015,
+        "output_cost_per_token": 0.000015
+    },
+     "cohere/command": {
+        "max_tokens": 4096, 
+        "max_input_tokens": 4096,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000015,
+        "output_cost_per_token": 0.000015
+    },
+     "cohere/command-medium-beta": {
+        "max_tokens": 4096, 
+        "max_input_tokens": 4096,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000015,
+        "output_cost_per_token": 0.000015
+    },
+     "cohere/command-xlarge-beta": {
+        "max_tokens": 4096, 
+        "max_input_tokens": 4096,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000015,
+        "output_cost_per_token": 0.000015
+    },
+    "together/together-ai-up-to-3b": {
+        "input_cost_per_token": 0.0000001,
+        "output_cost_per_token": 0.0000001
+    },
+    "together/together-ai-3.1b-7b": {
+        "input_cost_per_token": 0.0000002,
+        "output_cost_per_token": 0.0000002
+    },
+    "together/together-ai-7.1b-20b": {
+        "max_tokens": 1000,
+        "input_cost_per_token": 0.0000004,
+        "output_cost_per_token": 0.0000004
+    },
+    "together/together-ai-20.1b-40b": {
+        "input_cost_per_token": 0.0000008,
+        "output_cost_per_token": 0.0000008
+    },
+    "together/together-ai-40.1b-70b": {
+        "input_cost_per_token": 0.0000009,
+        "output_cost_per_token": 0.0000009
+    },
+    "together/mistralai/Mixtral-8x7B-Instruct-v0.1": {
+        "input_cost_per_token": 0.0000006,
+        "output_cost_per_token": 0.0000006
+    }
+}
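
Each entry above pairs a model identifier with its context-window limits and per-token input/output prices. As a point of reference, here is a minimal sketch of how such a table can be turned into a query cost; the file path and helper name are assumptions for illustration, not the library's own API:

```python
import json

def estimate_cost(model_key: str, prompt_tokens: int, completion_tokens: int,
                  prices_path: str = "model_prices.json") -> float:
    """Cost in USD: token counts multiplied by the per-token prices from the table above."""
    with open(prices_path) as f:
        prices = json.load(f)
    entry = prices[model_key]  # e.g. "groq/llama3-8b-8192"
    return (prompt_tokens * entry["input_cost_per_token"]
            + completion_tokens * entry["output_cost_per_token"])
```

The token-usage tests further down in this diff assert exactly this arithmetic against the rates listed here.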

+ 441 - 380
poetry.lock
File diff suppressed because it is too large


+ 2 - 2
pyproject.toml

@@ -156,6 +156,7 @@ langchain-google-vertexai = { version = "^1.0.6", optional = true }
 sqlalchemy = "^2.0.27"
 alembic = "^1.13.1"
 langchain-cohere = "^0.1.4"
+langchain-community = "^0.2.6"
 
 [tool.poetry.group.dev.dependencies]
 black = "^23.3.0"
@@ -183,9 +184,8 @@ slack = ["slack-sdk", "flask"]
 whatsapp = ["twilio", "flask"]
 weaviate = ["weaviate-client"]
 qdrant = ["qdrant-client"]
-huggingface_hub=["huggingface_hub"]
-cohere = ["cohere"]
 together = ["together"]
+huggingface_hub=["huggingface_hub"]
 milvus = ["pymilvus"]
 dataloaders=[
     "youtube-transcript-api",

+ 23 - 2
tests/llm/test_anthrophic.py

@@ -11,7 +11,7 @@ from embedchain.llm.anthropic import AnthropicLlm
 @pytest.fixture
 def anthropic_llm():
     os.environ["ANTHROPIC_API_KEY"] = "test_api_key"
-    config = BaseLlmConfig(temperature=0.5, model="gpt2")
+    config = BaseLlmConfig(temperature=0.5, model="claude-instant-1", token_usage=False)
     return AnthropicLlm(config)
 
 
@@ -20,7 +20,7 @@ def test_get_llm_model_answer(anthropic_llm):
         prompt = "Test Prompt"
         response = anthropic_llm.get_llm_model_answer(prompt)
         assert response == "Test Response"
-        mock_method.assert_called_once_with(prompt=prompt, config=anthropic_llm.config)
+        mock_method.assert_called_once_with(prompt, anthropic_llm.config)
 
 
 def test_get_messages(anthropic_llm):
@@ -31,3 +31,24 @@ def test_get_messages(anthropic_llm):
         SystemMessage(content="Test System Prompt", additional_kwargs={}),
         HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
     ]
+
+
+def test_get_llm_model_answer_with_token_usage(anthropic_llm):
+    test_config = BaseLlmConfig(
+        temperature=anthropic_llm.config.temperature, model=anthropic_llm.config.model, token_usage=True
+    )
+    anthropic_llm.config = test_config
+    with patch.object(
+        AnthropicLlm, "_get_answer", return_value=("Test Response", {"input_tokens": 1, "output_tokens": 2})
+    ) as mock_method:
+        prompt = "Test Prompt"
+        response, token_info = anthropic_llm.get_llm_model_answer(prompt)
+        assert response == "Test Response"
+        assert token_info == {
+            "prompt_tokens": 1,
+            "completion_tokens": 2,
+            "total_tokens": 3,
+            "total_cost": 1.265e-05,
+            "cost_currency": "USD",
+        }
+        mock_method.assert_called_once_with(prompt, anthropic_llm.config)
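
A quick sanity check of the asserted `total_cost`: with the `anthropic/claude-instant-1` rates from the pricing table earlier in this diff, 1 input token and 2 output tokens come out to the expected value:

```python
# 1 * input_cost_per_token + 2 * output_cost_per_token for anthropic/claude-instant-1
total_cost = 1 * 0.00000163 + 2 * 0.00000551
assert abs(total_cost - 1.265e-05) < 1e-12  # matches the expected value in the test above
```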

+ 29 - 4
tests/llm/test_cohere.py

@@ -9,7 +9,7 @@ from embedchain.llm.cohere import CohereLlm
 @pytest.fixture
 def cohere_llm_config():
     os.environ["COHERE_API_KEY"] = "test_api_key"
-    config = BaseLlmConfig(model="gptd-instruct-tft", max_tokens=50, temperature=0.7, top_p=0.8)
+    config = BaseLlmConfig(model="command-r", max_tokens=100, temperature=0.7, top_p=0.8, token_usage=False)
     yield config
     os.environ.pop("COHERE_API_KEY")
 
@@ -36,10 +36,35 @@ def test_get_llm_model_answer(cohere_llm_config, mocker):
     assert answer == "Test answer"
 
 
+def test_get_llm_model_answer_with_token_usage(cohere_llm_config, mocker):
+    test_config = BaseLlmConfig(
+        temperature=cohere_llm_config.temperature,
+        max_tokens=cohere_llm_config.max_tokens,
+        top_p=cohere_llm_config.top_p,
+        model=cohere_llm_config.model,
+        token_usage=True,
+    )
+    mocker.patch(
+        "embedchain.llm.cohere.CohereLlm._get_answer",
+        return_value=("Test answer", {"input_tokens": 1, "output_tokens": 2}),
+    )
+
+    llm = CohereLlm(test_config)
+    answer, token_info = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+    assert token_info == {
+        "prompt_tokens": 1,
+        "completion_tokens": 2,
+        "total_tokens": 3,
+        "total_cost": 3.5e-06,
+        "cost_currency": "USD",
+    }
+
+
 def test_get_answer_mocked_cohere(cohere_llm_config, mocker):
 def test_get_answer_mocked_cohere(cohere_llm_config, mocker):
-    mocked_cohere = mocker.patch("embedchain.llm.cohere.Cohere")
-    mock_instance = mocked_cohere.return_value
-    mock_instance.invoke.return_value = "Mocked answer"
+    mocked_cohere = mocker.patch("embedchain.llm.cohere.ChatCohere")
+    mocked_cohere.return_value.invoke.return_value.content = "Mocked answer"
 
     llm = CohereLlm(cohere_llm_config)
     prompt = "Test query"
+ 31 - 4
tests/llm/test_mistralai.py

@@ -24,7 +24,7 @@ def test_mistralai_llm_init(monkeypatch):
 
 
 def test_get_llm_model_answer(monkeypatch, mistralai_llm_config):
-    def mock_get_answer(prompt, config):
+    def mock_get_answer(self, prompt, config):
         return "Generated Text"
 
     monkeypatch.setattr(MistralAILlm, "_get_answer", mock_get_answer)
@@ -36,7 +36,7 @@ def test_get_llm_model_answer(monkeypatch, mistralai_llm_config):
 
 def test_get_llm_model_answer_with_system_prompt(monkeypatch, mistralai_llm_config):
     mistralai_llm_config.system_prompt = "Test system prompt"
-    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
+    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
     llm = MistralAILlm(config=mistralai_llm_config)
     result = llm.get_llm_model_answer("test prompt")
 
@@ -44,7 +44,7 @@ def test_get_llm_model_answer_with_system_prompt(monkeypatch, mistralai_llm_conf
 
 
 def test_get_llm_model_answer_empty_prompt(monkeypatch, mistralai_llm_config):
-    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
+    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
     llm = MistralAILlm(config=mistralai_llm_config)
     result = llm.get_llm_model_answer("")
 
@@ -53,8 +53,35 @@ def test_get_llm_model_answer_empty_prompt(monkeypatch, mistralai_llm_config):
 
 def test_get_llm_model_answer_without_system_prompt(monkeypatch, mistralai_llm_config):
     mistralai_llm_config.system_prompt = None
-    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
+    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
     llm = MistralAILlm(config=mistralai_llm_config)
     result = llm.get_llm_model_answer("test prompt")
 
     assert result == "Generated Text"
+
+
+def test_get_llm_model_answer_with_token_usage(monkeypatch, mistralai_llm_config):
+    test_config = BaseLlmConfig(
+        temperature=mistralai_llm_config.temperature,
+        max_tokens=mistralai_llm_config.max_tokens,
+        top_p=mistralai_llm_config.top_p,
+        model=mistralai_llm_config.model,
+        token_usage=True,
+    )
+    monkeypatch.setattr(
+        MistralAILlm,
+        "_get_answer",
+        lambda self, prompt, config: ("Generated Text", {"prompt_tokens": 1, "completion_tokens": 2}),
+    )
+
+    llm = MistralAILlm(test_config)
+    answer, token_info = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Generated Text"
+    assert token_info == {
+        "prompt_tokens": 1,
+        "completion_tokens": 2,
+        "total_tokens": 3,
+        "total_cost": 7.5e-07,
+        "cost_currency": "USD",
+    }
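
The extra `self` parameter in the patched replacements is needed because `monkeypatch.setattr(MistralAILlm, "_get_answer", ...)` installs a plain function on the class, which Python then binds as an instance method, so every call passes the instance as the first argument. A self-contained illustration (the `Demo` class is hypothetical):

```python
class Demo:
    def _get_answer(self, prompt, config):
        return "real"

# Same shape as the test's patch: the replacement must accept the bound instance.
Demo._get_answer = lambda self, prompt, config: "Generated Text"

assert Demo()._get_answer("test prompt", config=None) == "Generated Text"
```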

+ 29 - 0
tests/llm/test_openai.py

@@ -62,6 +62,35 @@ def test_get_llm_model_answer_empty_prompt(config, mocker):
     mocked_get_answer.assert_called_once_with("", config)
 
 
+def test_get_llm_model_answer_with_token_usage(config, mocker):
+    test_config = BaseLlmConfig(
+        temperature=config.temperature,
+        max_tokens=config.max_tokens,
+        top_p=config.top_p,
+        stream=config.stream,
+        system_prompt=config.system_prompt,
+        model=config.model,
+        token_usage=True,
+    )
+    mocked_get_answer = mocker.patch(
+        "embedchain.llm.openai.OpenAILlm._get_answer",
+        return_value=("Test answer", {"prompt_tokens": 1, "completion_tokens": 2}),
+    )
+
+    llm = OpenAILlm(test_config)
+    answer, token_info = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+    assert token_info == {
+        "prompt_tokens": 1,
+        "completion_tokens": 2,
+        "total_tokens": 3,
+        "total_cost": 5.5e-06,
+        "cost_currency": "USD",
+    }
+    mocked_get_answer.assert_called_once_with("Test query", test_config)
+
+
 def test_get_llm_model_answer_with_streaming(config, mocker):
     config.stream = True
     mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
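
Taken together, these tests exercise two return shapes: a bare answer string when `token_usage` is off, and an `(answer, token_info)` tuple when it is on. A hedged sketch of how a caller might normalize the two shapes; the helper name is made up for illustration:

```python
def unpack_answer(result, token_usage: bool):
    """Return (answer, token_info), with token_info None when usage tracking is off."""
    if token_usage:
        answer, token_info = result
        return answer, token_info
    return result, None

answer, info = unpack_answer(("Test answer", {"total_tokens": 3, "total_cost": 5.5e-06}), token_usage=True)
assert answer == "Test answer" and info["total_tokens"] == 3
```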

+ 29 - 3
tests/llm/test_together.py

@@ -9,7 +9,7 @@ from embedchain.llm.together import TogetherLlm
 @pytest.fixture
 def together_llm_config():
     os.environ["TOGETHER_API_KEY"] = "test_api_key"
-    config = BaseLlmConfig(model="togethercomputer/RedPajama-INCITE-7B-Base", max_tokens=50, temperature=0.7, top_p=0.8)
+    config = BaseLlmConfig(model="together-ai-up-to-3b", max_tokens=50, temperature=0.7, top_p=0.8)
     yield config
     os.environ.pop("TOGETHER_API_KEY")
 
@@ -36,10 +36,36 @@ def test_get_llm_model_answer(together_llm_config, mocker):
     assert answer == "Test answer"
 
 
+def test_get_llm_model_answer_with_token_usage(together_llm_config, mocker):
+    test_config = BaseLlmConfig(
+        temperature=together_llm_config.temperature,
+        max_tokens=together_llm_config.max_tokens,
+        top_p=together_llm_config.top_p,
+        model=together_llm_config.model,
+        token_usage=True,
+    )
+    mocker.patch(
+        "embedchain.llm.together.TogetherLlm._get_answer",
+        return_value=("Test answer", {"prompt_tokens": 1, "completion_tokens": 2}),
+    )
+
+    llm = TogetherLlm(test_config)
+    answer, token_info = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+    assert token_info == {
+        "prompt_tokens": 1,
+        "completion_tokens": 2,
+        "total_tokens": 3,
+        "total_cost": 3e-07,
+        "cost_currency": "USD",
+    }
+
+
 def test_get_answer_mocked_together(together_llm_config, mocker):
-    mocked_together = mocker.patch("embedchain.llm.together.Together")
+    mocked_together = mocker.patch("embedchain.llm.together.ChatTogether")
     mock_instance = mocked_together.return_value
-    mock_instance.invoke.return_value = "Mocked answer"
+    mock_instance.invoke.return_value.content = "Mocked answer"
 
     llm = TogetherLlm(together_llm_config)
     prompt = "Test query"

+ 26 - 1
tests/llm/test_vertex_ai.py

@@ -24,7 +24,32 @@ def test_get_llm_model_answer(vertexai_llm):
         prompt = "Test Prompt"
         prompt = "Test Prompt"
         response = vertexai_llm.get_llm_model_answer(prompt)
         response = vertexai_llm.get_llm_model_answer(prompt)
         assert response == "Test Response"
         assert response == "Test Response"
-        mock_method.assert_called_once_with(prompt=prompt, config=vertexai_llm.config)
+        mock_method.assert_called_once_with(prompt, vertexai_llm.config)
+
+
+def test_get_llm_model_answer_with_token_usage(vertexai_llm):
+    test_config = BaseLlmConfig(
+        temperature=vertexai_llm.config.temperature,
+        max_tokens=vertexai_llm.config.max_tokens,
+        top_p=vertexai_llm.config.top_p,
+        model=vertexai_llm.config.model,
+        token_usage=True,
+    )
+    vertexai_llm.config = test_config
+    with patch.object(
+        VertexAILlm,
+        "_get_answer",
+        return_value=("Test Response", {"prompt_token_count": 1, "candidates_token_count": 2}),
+    ):
+        response, token_info = vertexai_llm.get_llm_model_answer("Test Query")
+        assert response == "Test Response"
+        assert token_info == {
+            "prompt_tokens": 1,
+            "completion_tokens": 2,
+            "total_tokens": 3,
+            "total_cost": 3.75e-07,
+            "cost_currency": "USD",
+        }
 
 
 @patch("embedchain.llm.vertex_ai.ChatVertexAI")

Some files were not shown because too many files changed in this diff