pinecone.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
import logging
import os
from typing import Optional, Union

try:
    import pinecone
except ImportError:
    raise ImportError(
        "Pinecone requires extra dependencies. Install with `pip install --upgrade 'embedchain[pinecone]'`"
    ) from None

from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.utils.misc import chunks
from embedchain.vectordb.base import BaseVectorDB
  13. @register_deserializable
  14. class PineconeDB(BaseVectorDB):
  15. """
  16. Pinecone as vector database
  17. """
  18. BATCH_SIZE = 100
  19. def __init__(
  20. self,
  21. config: Optional[PineconeDBConfig] = None,
  22. ):
  23. """Pinecone as vector database.
  24. :param config: Pinecone database config, defaults to None
  25. :type config: PineconeDBConfig, optional
  26. :raises ValueError: No config provided
  27. """
  28. if config is None:
  29. self.config = PineconeDBConfig()
  30. else:
  31. if not isinstance(config, PineconeDBConfig):
  32. raise TypeError(
  33. "config is not a `PineconeDBConfig` instance. "
  34. "Please make sure the type is right and that you are passing an instance."
  35. )
  36. self.config = config
  37. self._setup_pinecone_index()
  38. # Call parent init here because embedder is needed
  39. super().__init__(config=self.config)
  40. def _initialize(self):
  41. """
  42. This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
  43. """
  44. if not self.embedder:
  45. raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
  46. def _setup_pinecone_index(self):
  47. """
  48. Loads the Pinecone index or creates it if not present.
  49. """
  50. api_key = self.config.api_key or os.environ.get("PINECONE_API_KEY")
  51. if not api_key:
  52. raise ValueError("Please set the PINECONE_API_KEY environment variable or pass it in config.")
  53. self.client = pinecone.Pinecone(api_key=api_key, **self.config.extra_params)
  54. indexes = self.client.list_indexes().names()
  55. if indexes is None or self.config.index_name not in indexes:
  56. if self.config.pod_config:
  57. spec = pinecone.PodSpec(**self.config.pod_config)
  58. elif self.config.serverless_config:
  59. spec = pinecone.ServerlessSpec(**self.config.serverless_config)
  60. else:
  61. raise ValueError("No pod_config or serverless_config found.")
  62. self.client.create_index(
  63. name=self.config.index_name,
  64. metric=self.config.metric,
  65. dimension=self.config.vector_dimension,
  66. spec=spec,
  67. )
  68. self.pinecone_index = self.client.Index(self.config.index_name)
  69. def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
  70. """
  71. Get existing doc ids present in vector database
  72. :param ids: _list of doc ids to check for existence
  73. :type ids: list[str]
  74. :param where: to filter data
  75. :type where: dict[str, any]
  76. :return: ids
  77. :rtype: Set[str]
  78. """
  79. existing_ids = list()
  80. metadatas = []
  81. if ids is not None:
  82. for i in range(0, len(ids), 1000):
  83. result = self.pinecone_index.fetch(ids=ids[i : i + 1000])
  84. vectors = result.get("vectors")
  85. batch_existing_ids = list(vectors.keys())
  86. existing_ids.extend(batch_existing_ids)
  87. metadatas.extend([vectors.get(ids).get("metadata") for ids in batch_existing_ids])
  88. return {"ids": existing_ids, "metadatas": metadatas}
  89. def add(
  90. self,
  91. documents: list[str],
  92. metadatas: list[object],
  93. ids: list[str],
  94. **kwargs: Optional[dict[str, any]],
  95. ):
  96. """add data in vector database
  97. :param documents: list of texts to add
  98. :type documents: list[str]
  99. :param metadatas: list of metadata associated with docs
  100. :type metadatas: list[object]
  101. :param ids: ids of docs
  102. :type ids: list[str]
  103. """
  104. docs = []
  105. embeddings = self.embedder.embedding_fn(documents)
  106. for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
  107. docs.append(
  108. {
  109. "id": id,
  110. "values": embedding,
  111. "metadata": {**metadata, "text": text},
  112. }
  113. )
  114. for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches"):
  115. self.pinecone_index.upsert(chunk, **kwargs)
  116. def query(
  117. self,
  118. input_query: list[str],
  119. n_results: int,
  120. where: Optional[dict[str, any]] = None,
  121. raw_filter: Optional[dict[str, any]] = None,
  122. citations: bool = False,
  123. app_id: Optional[str] = None,
  124. **kwargs: Optional[dict[str, any]],
  125. ) -> Union[list[tuple[str, dict]], list[str]]:
  126. """
  127. Query contents from vector database based on vector similarity.
  128. Args:
  129. input_query (list[str]): List of query strings.
  130. n_results (int): Number of similar documents to fetch from the database.
  131. where (dict[str, any], optional): Filter criteria for the search.
  132. raw_filter (dict[str, any], optional): Advanced raw filter criteria for the search.
  133. citations (bool, optional): Flag to return context along with metadata. Defaults to False.
  134. app_id (str, optional): Application ID to be passed to Pinecone.
  135. Returns:
  136. Union[list[tuple[str, dict]], list[str]]: List of document contexts, optionally with metadata.
  137. """
  138. query_filter = raw_filter if raw_filter is not None else self._generate_filter(where)
  139. if app_id:
  140. query_filter["app_id"] = {"$eq": app_id}
  141. query_vector = self.embedder.embedding_fn([input_query])[0]
  142. data = self.pinecone_index.query(
  143. vector=query_vector,
  144. filter=query_filter,
  145. top_k=n_results,
  146. include_metadata=True,
  147. **kwargs,
  148. )
  149. return [
  150. (metadata.get("text"), {**metadata, "score": doc.get("score")}) if citations else metadata.get("text")
  151. for doc in data.get("matches", [])
  152. for metadata in [doc.get("metadata", {})]
  153. ]
  154. def set_collection_name(self, name: str):
  155. """
  156. Set the name of the collection. A collection is an isolated space for vectors.
  157. :param name: Name of the collection.
  158. :type name: str
  159. """
  160. if not isinstance(name, str):
  161. raise TypeError("Collection name must be a string")
  162. self.config.collection_name = name
  163. def count(self) -> int:
  164. """
  165. Count number of documents/chunks embedded in the database.
  166. :return: number of documents
  167. :rtype: int
  168. """
  169. data = self.pinecone_index.describe_index_stats()
  170. return data["total_vector_count"]
  171. def _get_or_create_db(self):
  172. """Called during initialization"""
  173. return self.client
  174. def reset(self):
  175. """
  176. Resets the database. Deletes all embeddings irreversibly.
  177. """
  178. # Delete all data from the database
  179. self.client.delete_index(self.config.index_name)
  180. self._setup_pinecone_index()
  181. @staticmethod
  182. def _generate_filter(where: dict):
  183. query = {}
  184. for k, v in where.items():
  185. query[k] = {"$eq": v}
  186. return query
  187. def delete(self, where: dict):
  188. """Delete from database.
  189. :param ids: list of ids to delete
  190. :type ids: list[str]
  191. """
  192. # Deleting with filters is not supported for `starter` index type.
  193. # Follow `https://docs.pinecone.io/docs/metadata-filtering#deleting-vectors-by-metadata-filter` for more details
  194. db_filter = self._generate_filter(where)
  195. try:
  196. self.pinecone_index.delete(filter=db_filter)
  197. except Exception as e:
  198. print(f"Failed to delete from Pinecone: {e}")
  199. return