# pinecone.py
import logging
import os
from typing import Any, Optional, Union

try:
    import pinecone
except ImportError:
    raise ImportError(
        "Pinecone requires extra dependencies. Install with `pip install --upgrade 'embedchain[pinecone]'`"
    ) from None

from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.utils.misc import chunks
from embedchain.vectordb.base import BaseVectorDB

  14. @register_deserializable
  15. class PineconeDB(BaseVectorDB):
  16. """
  17. Pinecone as vector database
  18. """
  19. BATCH_SIZE = 100
  20. def __init__(
  21. self,
  22. config: Optional[PineconeDBConfig] = None,
  23. ):
  24. """Pinecone as vector database.
  25. :param config: Pinecone database config, defaults to None
  26. :type config: PineconeDBConfig, optional
  27. :raises ValueError: No config provided
  28. """
  29. if config is None:
  30. self.config = PineconeDBConfig()
  31. else:
  32. if not isinstance(config, PineconeDBConfig):
  33. raise TypeError(
  34. "config is not a `PineconeDBConfig` instance. "
  35. "Please make sure the type is right and that you are passing an instance."
  36. )
  37. self.config = config
  38. self.client = self._setup_pinecone_index()
  39. # Call parent init here because embedder is needed
  40. super().__init__(config=self.config)
  41. def _initialize(self):
  42. """
  43. This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
  44. """
  45. if not self.embedder:
  46. raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
  47. # Loads the Pinecone index or creates it if not present.
  48. def _setup_pinecone_index(self):
  49. pinecone.init(
  50. api_key=os.environ.get("PINECONE_API_KEY"),
  51. environment=os.environ.get("PINECONE_ENV"),
  52. **self.config.extra_params,
  53. )
  54. self.index_name = self._get_index_name()
  55. indexes = pinecone.list_indexes()
  56. if indexes is None or self.index_name not in indexes:
  57. pinecone.create_index(
  58. name=self.index_name, metric=self.config.metric, dimension=self.config.vector_dimension
  59. )
  60. return pinecone.Index(self.index_name)
  61. def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
  62. """
  63. Get existing doc ids present in vector database
  64. :param ids: _list of doc ids to check for existence
  65. :type ids: list[str]
  66. :param where: to filter data
  67. :type where: dict[str, any]
  68. :return: ids
  69. :rtype: Set[str]
  70. """
  71. existing_ids = list()
  72. metadatas = []
  73. if ids is not None:
  74. for i in range(0, len(ids), 1000):
  75. result = self.client.fetch(ids=ids[i : i + 1000])
  76. vectors = result.get("vectors")
  77. batch_existing_ids = list(vectors.keys())
  78. existing_ids.extend(batch_existing_ids)
  79. metadatas.extend([vectors.get(ids).get("metadata") for ids in batch_existing_ids])
  80. if where is not None:
  81. logging.warning("Filtering is not supported by Pinecone")
  82. return {"ids": existing_ids, "metadatas": metadatas}
  83. def add(
  84. self,
  85. documents: list[str],
  86. metadatas: list[object],
  87. ids: list[str],
  88. **kwargs: Optional[dict[str, any]],
  89. ):
  90. """add data in vector database
  91. :param documents: list of texts to add
  92. :type documents: list[str]
  93. :param metadatas: list of metadata associated with docs
  94. :type metadatas: list[object]
  95. :param ids: ids of docs
  96. :type ids: list[str]
  97. """
  98. docs = []
  99. print("Adding documents to Pinecone...")
  100. embeddings = self.embedder.embedding_fn(documents)
  101. for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
  102. docs.append(
  103. {
  104. "id": id,
  105. "values": embedding,
  106. "metadata": {**metadata, "text": text},
  107. }
  108. )
  109. for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches"):
  110. self.client.upsert(chunk, **kwargs)
  111. def query(
  112. self,
  113. input_query: list[str],
  114. n_results: int,
  115. where: dict[str, any],
  116. citations: bool = False,
  117. **kwargs: Optional[dict[str, any]],
  118. ) -> Union[list[tuple[str, dict]], list[str]]:
  119. """
  120. query contents from vector database based on vector similarity
  121. :param input_query: list of query string
  122. :type input_query: list[str]
  123. :param n_results: no of similar documents to fetch from database
  124. :type n_results: int
  125. :param where: Optional. to filter data
  126. :type where: dict[str, any]
  127. :param citations: we use citations boolean param to return context along with the answer.
  128. :type citations: bool, default is False.
  129. :return: The content of the document that matched your query,
  130. along with url of the source and doc_id (if citations flag is true)
  131. :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
  132. """
  133. query_vector = self.embedder.embedding_fn([input_query])[0]
  134. query_filter = self._generate_filter(where)
  135. data = self.client.query(
  136. vector=query_vector, filter=query_filter, top_k=n_results, include_metadata=True, **kwargs
  137. )
  138. contexts = []
  139. for doc in data["matches"]:
  140. metadata = doc["metadata"]
  141. context = metadata["text"]
  142. if citations:
  143. metadata["score"] = doc["score"]
  144. contexts.append(tuple((context, metadata)))
  145. else:
  146. contexts.append(context)
  147. return contexts
  148. def set_collection_name(self, name: str):
  149. """
  150. Set the name of the collection. A collection is an isolated space for vectors.
  151. :param name: Name of the collection.
  152. :type name: str
  153. """
  154. if not isinstance(name, str):
  155. raise TypeError("Collection name must be a string")
  156. self.config.collection_name = name
  157. def count(self) -> int:
  158. """
  159. Count number of documents/chunks embedded in the database.
  160. :return: number of documents
  161. :rtype: int
  162. """
  163. return self.client.describe_index_stats()["total_vector_count"]
  164. def _get_or_create_db(self):
  165. """Called during initialization"""
  166. return self.client
  167. def reset(self):
  168. """
  169. Resets the database. Deletes all embeddings irreversibly.
  170. """
  171. # Delete all data from the database
  172. pinecone.delete_index(self.index_name)
  173. self._setup_pinecone_index()
  174. # Pinecone only allows alphanumeric characters and "-" in the index name
  175. def _get_index_name(self) -> str:
  176. """Get the Pinecone index for a collection
  177. :return: Pinecone index
  178. :rtype: str
  179. """
  180. return f"{self.config.collection_name}-{self.config.vector_dimension}".lower().replace("_", "-")
  181. @staticmethod
  182. def _generate_filter(where: dict):
  183. query = {}
  184. for k, v in where.items():
  185. query[k] = {"$eq": v}
  186. return query
  187. def delete(self, where: dict):
  188. """Delete from database.
  189. :param ids: list of ids to delete
  190. :type ids: list[str]
  191. """
  192. # Deleting with filters is not supported for `starter` index type.
  193. # Follow `https://docs.pinecone.io/docs/metadata-filtering#deleting-vectors-by-metadata-filter` for more details
  194. db_filter = self._generate_filter(where)
  195. try:
  196. self.client.delete(filter=db_filter)
  197. except Exception as e:
  198. print(f"Failed to delete from Pinecone: {e}")
  199. return