base.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. from embedchain.config.vectordb.base import BaseVectorDbConfig
  2. from embedchain.embedder.base import BaseEmbedder
  3. from embedchain.helpers.json_serializable import JSONSerializable
  4. class BaseVectorDB(JSONSerializable):
  5. """Base class for vector database."""
  6. def __init__(self, config: BaseVectorDbConfig):
  7. """Initialize the database. Save the config and client as an attribute.
  8. :param config: Database configuration class instance.
  9. :type config: BaseVectorDbConfig
  10. """
  11. self.client = self._get_or_create_db()
  12. self.config: BaseVectorDbConfig = config
  13. def _initialize(self):
  14. """
  15. This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
  16. So it's can't be done in __init__ in one step.
  17. """
  18. raise NotImplementedError
  19. def _get_or_create_db(self):
  20. """Get or create the database."""
  21. raise NotImplementedError
  22. def _get_or_create_collection(self):
  23. """Get or create a named collection."""
  24. raise NotImplementedError
  25. def _set_embedder(self, embedder: BaseEmbedder):
  26. """
  27. The database needs to access the embedder sometimes, with this method you can persistently set it.
  28. :param embedder: Embedder to be set as the embedder for this database.
  29. :type embedder: BaseEmbedder
  30. """
  31. self.embedder = embedder
  32. def get(self):
  33. """Get database embeddings by id."""
  34. raise NotImplementedError
  35. def add(self):
  36. """Add to database"""
  37. raise NotImplementedError
  38. def query(self):
  39. """Query contents from vector database based on vector similarity"""
  40. raise NotImplementedError
  41. def count(self) -> int:
  42. """
  43. Count number of documents/chunks embedded in the database.
  44. :return: number of documents
  45. :rtype: int
  46. """
  47. raise NotImplementedError
  48. def reset(self):
  49. """
  50. Resets the database. Deletes all embeddings irreversibly.
  51. """
  52. raise NotImplementedError
  53. def set_collection_name(self, name: str):
  54. """
  55. Set the name of the collection. A collection is an isolated space for vectors.
  56. :param name: Name of the collection.
  57. :type name: str
  58. """
  59. raise NotImplementedError
  60. def delete(self):
  61. """Delete from database."""
  62. raise NotImplementedError