pinecone.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. import logging
  2. import os
  3. from typing import Optional, Union
  4. try:
  5. import pinecone
  6. except ImportError:
  7. raise ImportError(
  8. "Pinecone requires extra dependencies. Install with `pip install --upgrade 'embedchain[pinecone]'`"
  9. ) from None
  10. from embedchain.config.vectordb.pinecone import PineconeDBConfig
  11. from embedchain.helpers.json_serializable import register_deserializable
  12. from embedchain.utils.misc import chunks
  13. from embedchain.vectordb.base import BaseVectorDB
  14. @register_deserializable
  15. class PineconeDB(BaseVectorDB):
  16. """
  17. Pinecone as vector database
  18. """
  19. BATCH_SIZE = 100
  20. def __init__(
  21. self,
  22. config: Optional[PineconeDBConfig] = None,
  23. ):
  24. """Pinecone as vector database.
  25. :param config: Pinecone database config, defaults to None
  26. :type config: PineconeDBConfig, optional
  27. :raises ValueError: No config provided
  28. """
  29. if config is None:
  30. self.config = PineconeDBConfig()
  31. else:
  32. if not isinstance(config, PineconeDBConfig):
  33. raise TypeError(
  34. "config is not a `PineconeDBConfig` instance. "
  35. "Please make sure the type is right and that you are passing an instance."
  36. )
  37. self.config = config
  38. self.client = self._setup_pinecone_index()
  39. # Call parent init here because embedder is needed
  40. super().__init__(config=self.config)
  41. def _initialize(self):
  42. """
  43. This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
  44. """
  45. if not self.embedder:
  46. raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
  47. def _setup_pinecone_index(self):
  48. """
  49. Loads the Pinecone index or creates it if not present.
  50. """
  51. pinecone.init(
  52. api_key=os.environ.get("PINECONE_API_KEY"),
  53. environment=os.environ.get("PINECONE_ENV"),
  54. **self.config.extra_params,
  55. )
  56. indexes = pinecone.list_indexes()
  57. if indexes is None or self.config.index_name not in indexes:
  58. pinecone.create_index(
  59. name=self.config.index_name, metric=self.config.metric, dimension=self.config.vector_dimension
  60. )
  61. return pinecone.Index(self.config.index_name)
  62. def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
  63. """
  64. Get existing doc ids present in vector database
  65. :param ids: _list of doc ids to check for existence
  66. :type ids: list[str]
  67. :param where: to filter data
  68. :type where: dict[str, any]
  69. :return: ids
  70. :rtype: Set[str]
  71. """
  72. existing_ids = list()
  73. metadatas = []
  74. if ids is not None:
  75. for i in range(0, len(ids), 1000):
  76. result = self.client.fetch(ids=ids[i : i + 1000])
  77. vectors = result.get("vectors")
  78. batch_existing_ids = list(vectors.keys())
  79. existing_ids.extend(batch_existing_ids)
  80. metadatas.extend([vectors.get(ids).get("metadata") for ids in batch_existing_ids])
  81. if where is not None:
  82. logging.warning("Filtering is not supported by Pinecone")
  83. return {"ids": existing_ids, "metadatas": metadatas}
  84. def add(
  85. self,
  86. documents: list[str],
  87. metadatas: list[object],
  88. ids: list[str],
  89. **kwargs: Optional[dict[str, any]],
  90. ):
  91. """add data in vector database
  92. :param documents: list of texts to add
  93. :type documents: list[str]
  94. :param metadatas: list of metadata associated with docs
  95. :type metadatas: list[object]
  96. :param ids: ids of docs
  97. :type ids: list[str]
  98. """
  99. docs = []
  100. print("Adding documents to Pinecone...")
  101. embeddings = self.embedder.embedding_fn(documents)
  102. for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
  103. docs.append(
  104. {
  105. "id": id,
  106. "values": embedding,
  107. "metadata": {**metadata, "text": text},
  108. }
  109. )
  110. for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches"):
  111. self.client.upsert(chunk, **kwargs)
  112. def query(
  113. self,
  114. input_query: list[str],
  115. n_results: int,
  116. where: dict[str, any],
  117. citations: bool = False,
  118. **kwargs: Optional[dict[str, any]],
  119. ) -> Union[list[tuple[str, dict]], list[str]]:
  120. """
  121. query contents from vector database based on vector similarity
  122. :param input_query: list of query string
  123. :type input_query: list[str]
  124. :param n_results: no of similar documents to fetch from database
  125. :type n_results: int
  126. :param where: Optional. to filter data
  127. :type where: dict[str, any]
  128. :param citations: we use citations boolean param to return context along with the answer.
  129. :type citations: bool, default is False.
  130. :return: The content of the document that matched your query,
  131. along with url of the source and doc_id (if citations flag is true)
  132. :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
  133. """
  134. query_vector = self.embedder.embedding_fn([input_query])[0]
  135. query_filter = self._generate_filter(where)
  136. data = self.client.query(
  137. vector=query_vector, filter=query_filter, top_k=n_results, include_metadata=True, **kwargs
  138. )
  139. contexts = []
  140. for doc in data["matches"]:
  141. metadata = doc["metadata"]
  142. context = metadata["text"]
  143. if citations:
  144. metadata["score"] = doc["score"]
  145. contexts.append(tuple((context, metadata)))
  146. else:
  147. contexts.append(context)
  148. return contexts
  149. def set_collection_name(self, name: str):
  150. """
  151. Set the name of the collection. A collection is an isolated space for vectors.
  152. :param name: Name of the collection.
  153. :type name: str
  154. """
  155. if not isinstance(name, str):
  156. raise TypeError("Collection name must be a string")
  157. self.config.collection_name = name
  158. def count(self) -> int:
  159. """
  160. Count number of documents/chunks embedded in the database.
  161. :return: number of documents
  162. :rtype: int
  163. """
  164. return self.client.describe_index_stats()["total_vector_count"]
  165. def _get_or_create_db(self):
  166. """Called during initialization"""
  167. return self.client
  168. def reset(self):
  169. """
  170. Resets the database. Deletes all embeddings irreversibly.
  171. """
  172. # Delete all data from the database
  173. pinecone.delete_index(self.config.index_name)
  174. self._setup_pinecone_index()
  175. @staticmethod
  176. def _generate_filter(where: dict):
  177. query = {}
  178. for k, v in where.items():
  179. query[k] = {"$eq": v}
  180. return query
  181. def delete(self, where: dict):
  182. """Delete from database.
  183. :param ids: list of ids to delete
  184. :type ids: list[str]
  185. """
  186. # Deleting with filters is not supported for `starter` index type.
  187. # Follow `https://docs.pinecone.io/docs/metadata-filtering#deleting-vectors-by-metadata-filter` for more details
  188. db_filter = self._generate_filter(where)
  189. try:
  190. self.client.delete(filter=db_filter)
  191. except Exception as e:
  192. print(f"Failed to delete from Pinecone: {e}")
  193. return