postgres.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import hashlib
  2. import logging
  3. from typing import Any, Optional
  4. from embedchain.loaders.base_loader import BaseLoader
  5. class PostgresLoader(BaseLoader):
  6. def __init__(self, config: Optional[dict[str, Any]] = None):
  7. super().__init__()
  8. if not config:
  9. raise ValueError(f"Must provide the valid config. Received: {config}")
  10. self.connection = None
  11. self.cursor = None
  12. self._setup_loader(config=config)
  13. def _setup_loader(self, config: dict[str, Any]):
  14. try:
  15. import psycopg
  16. except ImportError as e:
  17. raise ImportError(
  18. "Unable to import required packages. \
  19. Run `pip install --upgrade 'embedchain[postgres]'`"
  20. ) from e
  21. if "url" in config:
  22. config_info = config.get("url")
  23. else:
  24. conn_params = []
  25. for key, value in config.items():
  26. conn_params.append(f"{key}={value}")
  27. config_info = " ".join(conn_params)
  28. logging.info(f"Connecting to postrgres sql: {config_info}")
  29. self.connection = psycopg.connect(conninfo=config_info)
  30. self.cursor = self.connection.cursor()
  31. @staticmethod
  32. def _check_query(query):
  33. if not isinstance(query, str):
  34. raise ValueError(
  35. f"Invalid postgres query: {query}. Provide the valid source to add from postgres, make sure you are following `https://docs.embedchain.ai/data-sources/postgres`", # noqa:E501
  36. )
  37. def load_data(self, query):
  38. self._check_query(query)
  39. try:
  40. data = []
  41. data_content = []
  42. self.cursor.execute(query)
  43. results = self.cursor.fetchall()
  44. for result in results:
  45. doc_content = str(result)
  46. data.append({"content": doc_content, "meta_data": {"url": query}})
  47. data_content.append(doc_content)
  48. doc_id = hashlib.sha256((query + ", ".join(data_content)).encode()).hexdigest()
  49. return {
  50. "doc_id": doc_id,
  51. "data": data,
  52. }
  53. except Exception as e:
  54. raise ValueError(f"Failed to load data using query={query} with: {e}")
  55. def close_connection(self):
  56. if self.cursor:
  57. self.cursor.close()
  58. self.cursor = None
  59. if self.connection:
  60. self.connection.close()
  61. self.connection = None