openai.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import json
  2. from typing import Dict, List, Optional
  3. from openai import OpenAI
  4. from mem0.llms.base import LLMBase
  5. from mem0.configs.llms.base import BaseLlmConfig
  6. class OpenAILLM(LLMBase):
  7. def __init__(self, config: Optional[BaseLlmConfig] = None):
  8. super().__init__(config)
  9. if not self.config.model:
  10. self.config.model="gpt-4o"
  11. self.client = OpenAI()
  12. def _parse_response(self, response, tools):
  13. """
  14. Process the response based on whether tools are used or not.
  15. Args:
  16. response: The raw response from API.
  17. tools: The list of tools provided in the request.
  18. Returns:
  19. str or dict: The processed response.
  20. """
  21. if tools:
  22. processed_response = {
  23. "content": response.choices[0].message.content,
  24. "tool_calls": []
  25. }
  26. if response.choices[0].message.tool_calls:
  27. for tool_call in response.choices[0].message.tool_calls:
  28. processed_response["tool_calls"].append({
  29. "name": tool_call.function.name,
  30. "arguments": json.loads(tool_call.function.arguments)
  31. })
  32. return processed_response
  33. else:
  34. return response.choices[0].message.content
  35. def generate_response(
  36. self,
  37. messages: List[Dict[str, str]],
  38. response_format=None,
  39. tools: Optional[List[Dict]] = None,
  40. tool_choice: str = "auto",
  41. ):
  42. """
  43. Generate a response based on the given messages using OpenAI.
  44. Args:
  45. messages (list): List of message dicts containing 'role' and 'content'.
  46. response_format (str or object, optional): Format of the response. Defaults to "text".
  47. tools (list, optional): List of tools that the model can call. Defaults to None.
  48. tool_choice (str, optional): Tool choice method. Defaults to "auto".
  49. Returns:
  50. str: The generated response.
  51. """
  52. params = {
  53. "model": self.config.model,
  54. "messages": messages,
  55. "temperature": self.config.temperature,
  56. "max_tokens": self.config.max_tokens,
  57. "top_p": self.config.top_p
  58. }
  59. if response_format:
  60. params["response_format"] = response_format
  61. if tools:
  62. params["tools"] = tools
  63. params["tool_choice"] = tool_choice
  64. response = self.client.chat.completions.create(**params)
  65. return self._parse_response(response, tools)