pinecone.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
import logging
import os
from typing import Any, Optional, Union

try:
    import pinecone
except ImportError:
    raise ImportError(
        "Pinecone requires extra dependencies. Install with `pip install --upgrade 'embedchain[pinecone]'`"
    ) from None

from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.utils.misc import chunks
from embedchain.vectordb.base import BaseVectorDB
  14. @register_deserializable
  15. class PineconeDB(BaseVectorDB):
  16. """
  17. Pinecone as vector database
  18. """
  19. BATCH_SIZE = 100
  20. def __init__(
  21. self,
  22. config: Optional[PineconeDBConfig] = None,
  23. ):
  24. """Pinecone as vector database.
  25. :param config: Pinecone database config, defaults to None
  26. :type config: PineconeDBConfig, optional
  27. :raises ValueError: No config provided
  28. """
  29. if config is None:
  30. self.config = PineconeDBConfig()
  31. else:
  32. if not isinstance(config, PineconeDBConfig):
  33. raise TypeError(
  34. "config is not a `PineconeDBConfig` instance. "
  35. "Please make sure the type is right and that you are passing an instance."
  36. )
  37. self.config = config
  38. self._setup_pinecone_index()
  39. # Call parent init here because embedder is needed
  40. super().__init__(config=self.config)
  41. def _initialize(self):
  42. """
  43. This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
  44. """
  45. if not self.embedder:
  46. raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
  47. def _setup_pinecone_index(self):
  48. """
  49. Loads the Pinecone index or creates it if not present.
  50. """
  51. api_key = self.config.api_key or os.environ.get("PINECONE_API_KEY")
  52. if not api_key:
  53. raise ValueError("Please set the PINECONE_API_KEY environment variable or pass it in config.")
  54. self.client = pinecone.Pinecone(api_key=api_key, **self.config.extra_params)
  55. indexes = self.client.list_indexes().names()
  56. if indexes is None or self.config.index_name not in indexes:
  57. if self.config.pod_config:
  58. spec = pinecone.PodSpec(**self.config.pod_config)
  59. elif self.config.serverless_config:
  60. spec = pinecone.ServerlessSpec(**self.config.serverless_config)
  61. else:
  62. raise ValueError("No pod_config or serverless_config found.")
  63. self.client.create_index(
  64. name=self.config.index_name,
  65. metric=self.config.metric,
  66. dimension=self.config.vector_dimension,
  67. spec=spec,
  68. )
  69. self.pinecone_index = self.client.Index(self.config.index_name)
  70. def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
  71. """
  72. Get existing doc ids present in vector database
  73. :param ids: _list of doc ids to check for existence
  74. :type ids: list[str]
  75. :param where: to filter data
  76. :type where: dict[str, any]
  77. :return: ids
  78. :rtype: Set[str]
  79. """
  80. existing_ids = list()
  81. metadatas = []
  82. if ids is not None:
  83. for i in range(0, len(ids), 1000):
  84. result = self.pinecone_index.fetch(ids=ids[i : i + 1000])
  85. vectors = result.get("vectors")
  86. batch_existing_ids = list(vectors.keys())
  87. existing_ids.extend(batch_existing_ids)
  88. metadatas.extend([vectors.get(ids).get("metadata") for ids in batch_existing_ids])
  89. if where is not None:
  90. logging.warning("Filtering is not supported by Pinecone")
  91. return {"ids": existing_ids, "metadatas": metadatas}
  92. def add(
  93. self,
  94. documents: list[str],
  95. metadatas: list[object],
  96. ids: list[str],
  97. **kwargs: Optional[dict[str, any]],
  98. ):
  99. """add data in vector database
  100. :param documents: list of texts to add
  101. :type documents: list[str]
  102. :param metadatas: list of metadata associated with docs
  103. :type metadatas: list[object]
  104. :param ids: ids of docs
  105. :type ids: list[str]
  106. """
  107. docs = []
  108. print("Adding documents to Pinecone...")
  109. embeddings = self.embedder.embedding_fn(documents)
  110. for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
  111. docs.append(
  112. {
  113. "id": id,
  114. "values": embedding,
  115. "metadata": {**metadata, "text": text},
  116. }
  117. )
  118. for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches"):
  119. self.pinecone_index.upsert(chunk, **kwargs)
  120. def query(
  121. self,
  122. input_query: list[str],
  123. n_results: int,
  124. where: dict[str, any],
  125. citations: bool = False,
  126. **kwargs: Optional[dict[str, any]],
  127. ) -> Union[list[tuple[str, dict]], list[str]]:
  128. """
  129. query contents from vector database based on vector similarity
  130. :param input_query: list of query string
  131. :type input_query: list[str]
  132. :param n_results: no of similar documents to fetch from database
  133. :type n_results: int
  134. :param where: Optional. to filter data
  135. :type where: dict[str, any]
  136. :param citations: we use citations boolean param to return context along with the answer.
  137. :type citations: bool, default is False.
  138. :return: The content of the document that matched your query,
  139. along with url of the source and doc_id (if citations flag is true)
  140. :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
  141. """
  142. query_vector = self.embedder.embedding_fn([input_query])[0]
  143. query_filter = self._generate_filter(where)
  144. data = self.pinecone_index.query(
  145. vector=query_vector,
  146. filter=query_filter,
  147. top_k=n_results,
  148. include_metadata=True,
  149. **kwargs,
  150. )
  151. contexts = []
  152. for doc in data.get("matches", []):
  153. metadata = doc.get("metadata", {})
  154. context = metadata.get("text")
  155. if citations:
  156. metadata["score"] = doc.get("score")
  157. contexts.append(tuple((context, metadata)))
  158. else:
  159. contexts.append(context)
  160. return contexts
  161. def set_collection_name(self, name: str):
  162. """
  163. Set the name of the collection. A collection is an isolated space for vectors.
  164. :param name: Name of the collection.
  165. :type name: str
  166. """
  167. if not isinstance(name, str):
  168. raise TypeError("Collection name must be a string")
  169. self.config.collection_name = name
  170. def count(self) -> int:
  171. """
  172. Count number of documents/chunks embedded in the database.
  173. :return: number of documents
  174. :rtype: int
  175. """
  176. data = self.pinecone_index.describe_index_stats()
  177. return data["total_vector_count"]
  178. def _get_or_create_db(self):
  179. """Called during initialization"""
  180. return self.client
  181. def reset(self):
  182. """
  183. Resets the database. Deletes all embeddings irreversibly.
  184. """
  185. # Delete all data from the database
  186. self.client.delete_index(self.config.index_name)
  187. self._setup_pinecone_index()
  188. @staticmethod
  189. def _generate_filter(where: dict):
  190. query = {}
  191. for k, v in where.items():
  192. query[k] = {"$eq": v}
  193. return query
  194. def delete(self, where: dict):
  195. """Delete from database.
  196. :param ids: list of ids to delete
  197. :type ids: list[str]
  198. """
  199. # Deleting with filters is not supported for `starter` index type.
  200. # Follow `https://docs.pinecone.io/docs/metadata-filtering#deleting-vectors-by-metadata-filter` for more details
  201. db_filter = self._generate_filter(where)
  202. try:
  203. self.pinecone_index.delete(filter=db_filter)
  204. except Exception as e:
  205. print(f"Failed to delete from Pinecone: {e}")
  206. return