pineconedb.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. import copy
  2. import os
  3. from typing import Dict, List, Optional
  4. try:
  5. import pinecone
  6. except ImportError:
  7. raise ImportError(
  8. "Pinecone requires extra dependencies. Install with `pip install --upgrade embedchain[pinecone]`"
  9. ) from None
  10. from embedchain.config.vectordb.pinecone import PineconeDbConfig
  11. from embedchain.helper.json_serializable import register_deserializable
  12. from embedchain.vectordb.base import BaseVectorDB
  13. @register_deserializable
  14. class PineconeDb(BaseVectorDB):
  15. BATCH_SIZE = 100
  16. """
  17. Pinecone as vector database
  18. """
  19. def __init__(
  20. self,
  21. config: Optional[PineconeDbConfig] = None,
  22. ):
  23. """Pinecone as vector database.
  24. :param config: Pinecone database config, defaults to None
  25. :type config: PineconeDbConfig, optional
  26. :raises ValueError: No config provided
  27. """
  28. if config is None:
  29. self.config = PineconeDbConfig()
  30. else:
  31. if not isinstance(config, PineconeDbConfig):
  32. raise TypeError(
  33. "config is not a `PineconeDbConfig` instance. "
  34. "Please make sure the type is right and that you are passing an instance."
  35. )
  36. self.config = config
  37. self.client = self._setup_pinecone_index()
  38. # Call parent init here because embedder is needed
  39. super().__init__(config=self.config)
  40. def _initialize(self):
  41. """
  42. This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
  43. """
  44. if not self.embedder:
  45. raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
  46. # Loads the Pinecone index or creates it if not present.
  47. def _setup_pinecone_index(self):
  48. pinecone.init(
  49. api_key=os.environ.get("PINECONE_API_KEY"),
  50. environment=os.environ.get("PINECONE_ENV"),
  51. )
  52. self.index_name = self._get_index_name()
  53. indexes = pinecone.list_indexes()
  54. if indexes is None or self.index_name not in indexes:
  55. pinecone.create_index(name=self.index_name, metric=self.config.metric, dimension=self.config.dimension)
  56. return pinecone.Index(self.index_name)
  57. def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
  58. """
  59. Get existing doc ids present in vector database
  60. :param ids: _list of doc ids to check for existence
  61. :type ids: List[str]
  62. :param where: to filter data
  63. :type where: Dict[str, any]
  64. :return: ids
  65. :rtype: Set[str]
  66. """
  67. existing_ids = list()
  68. if ids is not None:
  69. for i in range(0, len(ids), 1000):
  70. result = self.client.fetch(ids=ids[i : i + 1000])
  71. batch_existing_ids = list(result.get("vectors").keys())
  72. existing_ids.extend(batch_existing_ids)
  73. return {"ids": existing_ids}
  74. def add(
  75. self,
  76. embeddings: List[List[float]],
  77. documents: List[str],
  78. metadatas: List[object],
  79. ids: List[str],
  80. skip_embedding: bool,
  81. ):
  82. """add data in vector database
  83. :param documents: list of texts to add
  84. :type documents: List[str]
  85. :param metadatas: list of metadata associated with docs
  86. :type metadatas: List[object]
  87. :param ids: ids of docs
  88. :type ids: List[str]
  89. """
  90. docs = []
  91. if embeddings is None:
  92. embeddings = self.embedder.embedding_fn(documents)
  93. for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
  94. metadata["text"] = text
  95. docs.append(
  96. {
  97. "id": id,
  98. "values": embedding,
  99. "metadata": copy.deepcopy(metadata),
  100. }
  101. )
  102. for i in range(0, len(docs), self.BATCH_SIZE):
  103. self.client.upsert(docs[i : i + self.BATCH_SIZE])
  104. def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
  105. """
  106. query contents from vector database based on vector similarity
  107. :param input_query: list of query string
  108. :type input_query: List[str]
  109. :param n_results: no of similar documents to fetch from database
  110. :type n_results: int
  111. :param where: Optional. to filter data
  112. :type where: Dict[str, any]
  113. :return: Database contents that are the result of the query
  114. :rtype: List[str]
  115. """
  116. if not skip_embedding:
  117. query_vector = self.embedder.embedding_fn([input_query])[0]
  118. else:
  119. query_vector = input_query
  120. contents = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
  121. embeddings = list(map(lambda content: content["metadata"]["text"], contents["matches"]))
  122. return embeddings
  123. def set_collection_name(self, name: str):
  124. """
  125. Set the name of the collection. A collection is an isolated space for vectors.
  126. :param name: Name of the collection.
  127. :type name: str
  128. """
  129. if not isinstance(name, str):
  130. raise TypeError("Collection name must be a string")
  131. self.config.collection_name = name
  132. def count(self) -> int:
  133. """
  134. Count number of documents/chunks embedded in the database.
  135. :return: number of documents
  136. :rtype: int
  137. """
  138. return self.client.describe_index_stats()["total_vector_count"]
  139. def _get_or_create_db(self):
  140. """Called during initialization"""
  141. return self.client
  142. def reset(self):
  143. """
  144. Resets the database. Deletes all embeddings irreversibly.
  145. """
  146. # Delete all data from the database
  147. pinecone.delete_index(self.index_name)
  148. self._setup_pinecone_index()
  149. # Pinecone only allows alphanumeric characters and "-" in the index name
  150. def _get_index_name(self) -> str:
  151. """Get the Pinecone index for a collection
  152. :return: Pinecone index
  153. :rtype: str
  154. """
  155. return f"{self.config.collection_name}-{self.config.dimension}".lower().replace("_", "-")