chroma_embeddings.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. """
  2. Note that this file is copied from the Chroma repository. We will remove this file once the fix
  3. lands in ChromaDB's repository.
  4. """
  5. from typing import Optional
  6. from chromadb.api.types import Documents, Embeddings
  class OpenAIEmbeddingFunction:
      """Callable that embeds documents through the OpenAI embeddings API.

      Two generations of the ``openai`` package are supported: the legacy
      module-level ``openai.Embedding`` interface (releases before 1.0) and the
      client-object interface (``openai.OpenAI`` / ``openai.AzureOpenAI``) used
      by 1.0 and every later major release.
      """

      def __init__(
          self,
          api_key: Optional[str] = None,
          model_name: str = "text-embedding-ada-002",
          organization_id: Optional[str] = None,
          api_base: Optional[str] = None,
          api_type: Optional[str] = None,
          api_version: Optional[str] = None,
          deployment_id: Optional[str] = None,
      ):
          """
          Initialize the OpenAIEmbeddingFunction.

          Args:
              api_key (str, optional): Your API key for the OpenAI API. If not
                  provided, the key already set on ``openai.api_key`` is used;
                  if neither is set, a ValueError is raised.
              model_name (str, optional): The name of the model to use for text
                  embeddings. Defaults to "text-embedding-ada-002".
              organization_id (str, optional): The OpenAI organization ID if
                  applicable.
              api_base (str, optional): The base path for the API. If not provided,
                  it will use the base path for the OpenAI API. This can be used to
                  point to a different deployment, such as an Azure deployment.
              api_type (str, optional): The type of the API deployment. This can be
                  used to specify a different deployment, such as 'azure'. If not
                  provided, it will use the default OpenAI deployment.
              api_version (str, optional): The api version for the API. If not
                  provided, it will use the api version for the OpenAI API. This
                  can be used to point to a different deployment, such as an Azure
                  deployment.
              deployment_id (str, optional): Deployment ID for Azure OpenAI.

          Raises:
              ValueError: If the ``openai`` package is not installed, or if no API
                  key was passed and none is configured on the ``openai`` module.
          """
          try:
              import openai
          except ImportError as err:
              # ValueError is kept for backward compatibility; chain the real cause.
              raise ValueError(
                  "The openai python package is not installed. Please install it with `pip install openai`"
              ) from err

          if api_key is not None:
              openai.api_key = api_key
          # If the api key is still not set, raise an error
          elif openai.api_key is None:
              raise ValueError(
                  "Please provide an OpenAI API key. You can get one at https://platform.openai.com/account/api-keys"
              )
          else:
              # Reuse the key already configured on the module so that the
              # explicit client built below receives it instead of None.
              api_key = openai.api_key

          # Module-level settings are what the legacy (pre-1.0) interface reads.
          if api_base is not None:
              openai.api_base = api_base

          if api_version is not None:
              openai.api_version = api_version

          self._api_type = api_type
          if api_type is not None:
              openai.api_type = api_type

          if organization_id is not None:
              openai.organization = organization_id

          self._v1 = self._uses_client_api(openai.__version__)
          if self._v1:
              # Explicitly constructed clients take their configuration from
              # constructor arguments, so the key and organization are forwarded.
              if api_type == "azure":
                  self._client = openai.AzureOpenAI(
                      api_key=api_key,
                      api_version=api_version,
                      azure_endpoint=api_base,
                      organization=organization_id,
                  ).embeddings
              else:
                  self._client = openai.OpenAI(
                      api_key=api_key,
                      base_url=api_base,
                      organization=organization_id,
                  ).embeddings
          else:
              self._client = openai.Embedding

          self._model_name = model_name
          self._deployment_id = deployment_id

      @staticmethod
      def _uses_client_api(version: str) -> bool:
          """Return True when *version* is openai 1.0 or any later major release.

          A plain ``startswith("1.")`` test would send 2.x and later down the
          legacy code path. A version whose major component is not numeric is
          treated as legacy.
          """
          major = version.split(".", 1)[0]
          return major.isdigit() and int(major) >= 1

      def __call__(self, input: Documents) -> Embeddings:
          """Embed *input* and return one vector per document, in input order.

          Args:
              input: The documents to embed. The parameter name shadows the
                  builtin but is kept because callers may pass it by keyword.

          Returns:
              The embeddings, ordered by the ``index`` field of the API response.
          """
          # replace newlines, which can negatively affect performance.
          texts = [t.replace("\n", " ") for t in input]

          # Call the OpenAI Embedding API
          if self._v1:
              response = self._client.create(input=texts, model=self._deployment_id or self._model_name)
              # Sort resulting embeddings by index, then return just the vectors
              ordered = sorted(response.data, key=lambda e: e.index)  # type: ignore
              return [item.embedding for item in ordered]

          # Legacy interface: Azure addresses the deployment through ``engine``.
          if self._api_type == "azure":
              response = self._client.create(input=texts, engine=self._deployment_id or self._model_name)
          else:
              response = self._client.create(input=texts, model=self._model_name)

          # Sort resulting embeddings by index, then return just the vectors
          ordered = sorted(response["data"], key=lambda e: e["index"])  # type: ignore
          return [item["embedding"] for item in ordered]