pinecone.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
import os
from typing import Any, Dict, List, Optional, Tuple

try:
    import pinecone
except ImportError:
    raise ImportError(
        "Pinecone requires extra dependencies. Install with `pip install --upgrade 'embedchain[pinecone]'`"
    ) from None

from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.helper.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB
  12. @register_deserializable
  13. class PineconeDB(BaseVectorDB):
  14. """
  15. Pinecone as vector database
  16. """
  17. BATCH_SIZE = 100
  18. def __init__(
  19. self,
  20. config: Optional[PineconeDBConfig] = None,
  21. ):
  22. """Pinecone as vector database.
  23. :param config: Pinecone database config, defaults to None
  24. :type config: PineconeDBConfig, optional
  25. :raises ValueError: No config provided
  26. """
  27. if config is None:
  28. self.config = PineconeDBConfig()
  29. else:
  30. if not isinstance(config, PineconeDBConfig):
  31. raise TypeError(
  32. "config is not a `PineconeDBConfig` instance. "
  33. "Please make sure the type is right and that you are passing an instance."
  34. )
  35. self.config = config
  36. self.client = self._setup_pinecone_index()
  37. # Call parent init here because embedder is needed
  38. super().__init__(config=self.config)
  39. def _initialize(self):
  40. """
  41. This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
  42. """
  43. if not self.embedder:
  44. raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
  45. # Loads the Pinecone index or creates it if not present.
  46. def _setup_pinecone_index(self):
  47. pinecone.init(
  48. api_key=os.environ.get("PINECONE_API_KEY"),
  49. environment=os.environ.get("PINECONE_ENV"),
  50. **self.config.extra_params,
  51. )
  52. self.index_name = self._get_index_name()
  53. indexes = pinecone.list_indexes()
  54. if indexes is None or self.index_name not in indexes:
  55. pinecone.create_index(
  56. name=self.index_name, metric=self.config.metric, dimension=self.config.vector_dimension
  57. )
  58. return pinecone.Index(self.index_name)
  59. def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
  60. """
  61. Get existing doc ids present in vector database
  62. :param ids: _list of doc ids to check for existence
  63. :type ids: List[str]
  64. :param where: to filter data
  65. :type where: Dict[str, any]
  66. :return: ids
  67. :rtype: Set[str]
  68. """
  69. existing_ids = list()
  70. if ids is not None:
  71. for i in range(0, len(ids), 1000):
  72. result = self.client.fetch(ids=ids[i : i + 1000])
  73. batch_existing_ids = list(result.get("vectors").keys())
  74. existing_ids.extend(batch_existing_ids)
  75. return {"ids": existing_ids}
  76. def add(
  77. self,
  78. embeddings: List[List[float]],
  79. documents: List[str],
  80. metadatas: List[object],
  81. ids: List[str],
  82. skip_embedding: bool,
  83. ):
  84. """add data in vector database
  85. :param documents: list of texts to add
  86. :type documents: List[str]
  87. :param metadatas: list of metadata associated with docs
  88. :type metadatas: List[object]
  89. :param ids: ids of docs
  90. :type ids: List[str]
  91. """
  92. docs = []
  93. print("Adding documents to Pinecone...")
  94. embeddings = self.embedder.embedding_fn(documents)
  95. for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
  96. docs.append(
  97. {
  98. "id": id,
  99. "values": embedding,
  100. "metadata": {**metadata, "text": text},
  101. }
  102. )
  103. for i in range(0, len(docs), self.BATCH_SIZE):
  104. self.client.upsert(docs[i : i + self.BATCH_SIZE])
  105. def query(
  106. self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
  107. ) -> List[Tuple[str, str, str]]:
  108. """
  109. query contents from vector database based on vector similarity
  110. :param input_query: list of query string
  111. :type input_query: List[str]
  112. :param n_results: no of similar documents to fetch from database
  113. :type n_results: int
  114. :param where: Optional. to filter data
  115. :type where: Dict[str, any]
  116. :param skip_embedding: Optional. if True, input_query is already embedded
  117. :type skip_embedding: bool
  118. :return: The content of the document that matched your query, url of the source, doc_id
  119. :rtype: List[Tuple[str,str,str]]
  120. """
  121. if not skip_embedding:
  122. query_vector = self.embedder.embedding_fn([input_query])[0]
  123. else:
  124. query_vector = input_query
  125. data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
  126. contents = []
  127. for doc in data["matches"]:
  128. metadata = doc["metadata"]
  129. context = metadata["text"]
  130. source = metadata["url"]
  131. doc_id = metadata["doc_id"]
  132. contents.append(tuple((context, source, doc_id)))
  133. return contents
  134. def set_collection_name(self, name: str):
  135. """
  136. Set the name of the collection. A collection is an isolated space for vectors.
  137. :param name: Name of the collection.
  138. :type name: str
  139. """
  140. if not isinstance(name, str):
  141. raise TypeError("Collection name must be a string")
  142. self.config.collection_name = name
  143. def count(self) -> int:
  144. """
  145. Count number of documents/chunks embedded in the database.
  146. :return: number of documents
  147. :rtype: int
  148. """
  149. return self.client.describe_index_stats()["total_vector_count"]
  150. def _get_or_create_db(self):
  151. """Called during initialization"""
  152. return self.client
  153. def reset(self):
  154. """
  155. Resets the database. Deletes all embeddings irreversibly.
  156. """
  157. # Delete all data from the database
  158. pinecone.delete_index(self.index_name)
  159. self._setup_pinecone_index()
  160. # Pinecone only allows alphanumeric characters and "-" in the index name
  161. def _get_index_name(self) -> str:
  162. """Get the Pinecone index for a collection
  163. :return: Pinecone index
  164. :rtype: str
  165. """
  166. return f"{self.config.collection_name}-{self.config.vector_dimension}".lower().replace("_", "-")