pinecone.py

import logging
import os
from typing import Any, Optional, Union

try:
    import pinecone
except ImportError:
    raise ImportError(
        "Pinecone requires extra dependencies. Install with `pip install pinecone-text pinecone-client`"
    ) from None

from pinecone_text.sparse import BM25Encoder

from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.utils.misc import chunks
from embedchain.vectordb.base import BaseVectorDB

logger = logging.getLogger(__name__)

@register_deserializable
class PineconeDB(BaseVectorDB):
    """
    Pinecone as vector database.
    """

    BATCH_SIZE = 100

    def __init__(
        self,
        config: Optional[PineconeDBConfig] = None,
    ):
        """Pinecone as vector database.

        :param config: Pinecone database config, defaults to None
        :type config: PineconeDBConfig, optional
        :raises TypeError: config is not a `PineconeDBConfig` instance
        """
        if config is None:
            self.config = PineconeDBConfig()
        else:
            if not isinstance(config, PineconeDBConfig):
                raise TypeError(
                    "config is not a `PineconeDBConfig` instance. "
                    "Please make sure the type is right and that you are passing an instance."
                )
            self.config = config
        self._setup_pinecone_index()

        # Set up a BM25Encoder if sparse vectors are to be used (hybrid search)
        self.bm25_encoder = None
        if self.config.hybrid_search:
            logger.info("Initializing BM25Encoder for sparse vectors..")
            self.bm25_encoder = self.config.bm25_encoder if self.config.bm25_encoder else BM25Encoder.default()

        # Call parent init here because the embedder is needed
        super().__init__(config=self.config)

    def _initialize(self):
        """
        This method is needed because the `embedder` attribute needs to be set externally before it can be initialized.
        """
        if not self.embedder:
            raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")

    def _setup_pinecone_index(self):
        """
        Loads the Pinecone index or creates it if not present.
        """
        api_key = self.config.api_key or os.environ.get("PINECONE_API_KEY")
        if not api_key:
            raise ValueError("Please set the PINECONE_API_KEY environment variable or pass it in config.")
        self.client = pinecone.Pinecone(api_key=api_key, **self.config.extra_params)
        indexes = self.client.list_indexes().names()
        if indexes is None or self.config.index_name not in indexes:
            if self.config.pod_config:
                spec = pinecone.PodSpec(**self.config.pod_config)
            elif self.config.serverless_config:
                spec = pinecone.ServerlessSpec(**self.config.serverless_config)
            else:
                raise ValueError("No pod_config or serverless_config found.")
            self.client.create_index(
                name=self.config.index_name,
                metric=self.config.metric,
                dimension=self.config.vector_dimension,
                spec=spec,
            )
        self.pinecone_index = self.client.Index(self.config.index_name)

    def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, Any]] = None, limit: Optional[int] = None):
        """
        Get existing doc ids present in the vector database.

        :param ids: list of doc ids to check for existence
        :type ids: list[str]
        :param where: filter criteria (not used by the Pinecone implementation)
        :type where: dict[str, Any]
        :param limit: maximum number of results (not used by the Pinecone implementation)
        :type limit: int
        :return: ids and metadatas of the existing documents
        :rtype: dict[str, list]
        """
        existing_ids = []
        metadatas = []
        batch_size = 100
        if ids is not None:
            for i in range(0, len(ids), batch_size):
                result = self.pinecone_index.fetch(ids=ids[i : i + batch_size])
                vectors = result.get("vectors")
                batch_existing_ids = list(vectors.keys())
                existing_ids.extend(batch_existing_ids)
                metadatas.extend([vectors.get(existing_id).get("metadata") for existing_id in batch_existing_ids])
        return {"ids": existing_ids, "metadatas": metadatas}

    def add(
        self,
        documents: list[str],
        metadatas: list[object],
        ids: list[str],
        **kwargs: Optional[dict[str, Any]],
    ):
        """Add data to the vector database.

        :param documents: list of texts to add
        :type documents: list[str]
        :param metadatas: list of metadata associated with docs
        :type metadatas: list[object]
        :param ids: ids of docs
        :type ids: list[str]
        """
        docs = []
        embeddings = self.embedder.embedding_fn(documents)
        for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
            # Insert sparse vectors as well if the user wants to do hybrid search
            sparse_vector_dict = (
                {"sparse_values": self.bm25_encoder.encode_documents(text)} if self.bm25_encoder else {}
            )
            docs.append(
                {
                    "id": id,
                    "values": embedding,
                    "metadata": {**metadata, "text": text},
                    **sparse_vector_dict,
                },
            )

        for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches"):
            self.pinecone_index.upsert(chunk, **kwargs)

    def query(
        self,
        input_query: str,
        n_results: int,
        where: Optional[dict[str, Any]] = None,
        raw_filter: Optional[dict[str, Any]] = None,
        citations: bool = False,
        app_id: Optional[str] = None,
        **kwargs: Optional[dict[str, Any]],
    ) -> Union[list[tuple[str, dict]], list[str]]:
        """
        Query contents from the vector database based on vector similarity.

        Args:
            input_query (str): query string.
            n_results (int): Number of similar documents to fetch from the database.
            where (dict[str, Any], optional): Filter criteria for the search.
            raw_filter (dict[str, Any], optional): Advanced raw filter criteria for the search.
            citations (bool, optional): Flag to return context along with metadata. Defaults to False.
            app_id (str, optional): Application ID added to the metadata filter.

        Returns:
            Union[list[tuple[str, dict]], list[str]]: List of document contexts, optionally with metadata.
        """
        query_filter = raw_filter if raw_filter is not None else self._generate_filter(where)
        if app_id:
            query_filter["app_id"] = {"$eq": app_id}
        query_vector = self.embedder.embedding_fn([input_query])[0]
        params = {
            "vector": query_vector,
            "filter": query_filter,
            "top_k": n_results,
            "include_metadata": True,
            **kwargs,
        }
        # Add a sparse query vector as well when hybrid search is enabled
        if self.bm25_encoder:
            sparse_query_vector = self.bm25_encoder.encode_queries(input_query)
            params["sparse_vector"] = sparse_query_vector
        data = self.pinecone_index.query(**params)
        return [
            (metadata.get("text"), {**metadata, "score": doc.get("score")}) if citations else metadata.get("text")
            for doc in data.get("matches", [])
            for metadata in [doc.get("metadata", {})]
        ]

    def set_collection_name(self, name: str):
        """
        Set the name of the collection. A collection is an isolated space for vectors.

        :param name: Name of the collection.
        :type name: str
        """
        if not isinstance(name, str):
            raise TypeError("Collection name must be a string")
        self.config.collection_name = name

    def count(self) -> int:
        """
        Count number of documents/chunks embedded in the database.

        :return: number of documents
        :rtype: int
        """
        data = self.pinecone_index.describe_index_stats()
        return data["total_vector_count"]

    def _get_or_create_db(self):
        """Called during initialization"""
        return self.client

    def reset(self):
        """
        Resets the database. Deletes all embeddings irreversibly.
        """
        # Delete all data from the database by dropping and recreating the index
        self.client.delete_index(self.config.index_name)
        self._setup_pinecone_index()

    @staticmethod
    def _generate_filter(where: dict):
        """Convert a simple key/value `where` dict into a Pinecone `$eq` metadata filter."""
        query = {}
        if where is None:
            return query
        for k, v in where.items():
            query[k] = {"$eq": v}
        return query

    def delete(self, where: dict):
        """Delete vectors from the database that match the given metadata filter.

        :param where: metadata filter used to select the vectors to delete
        :type where: dict
        """
        # Deleting with metadata filters is not supported for the `starter` index type.
        # See https://docs.pinecone.io/docs/metadata-filtering#deleting-vectors-by-metadata-filter for details.
        db_filter = self._generate_filter(where)
        try:
            self.pinecone_index.delete(filter=db_filter)
        except Exception as e:
            logger.error(f"Failed to delete from Pinecone: {e}")
            return
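
# --- Usage sketch (illustrative; not part of the original module) -----------
# A minimal, hedged example of driving PineconeDB directly. Assumptions: an
# embedder class at `embedchain.embedder.openai.OpenAIEmbedder`, that
# `PineconeDBConfig` accepts the keyword arguments shown (the attribute names
# match those read by `_setup_pinecone_index` above), and that the
# PINECONE_API_KEY / OPENAI_API_KEY environment variables are set.
if __name__ == "__main__":
    from embedchain.embedder.openai import OpenAIEmbedder  # assumed import path

    demo_config = PineconeDBConfig(
        index_name="embedchain-demo",  # assumed constructor kwargs
        metric="cosine",
        vector_dimension=1536,
        serverless_config={"cloud": "aws", "region": "us-east-1"},
    )
    db = PineconeDB(config=demo_config)
    db.set_embedder(OpenAIEmbedder())  # `set_embedder` is referenced in `_initialize`
    db._initialize()

    db.add(
        documents=["Pinecone supports dense and sparse vectors for hybrid search."],
        metadatas=[{"source": "demo"}],
        ids=["demo-1"],
    )
    print(db.query("What does Pinecone support?", n_results=1))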