postgres.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. import hashlib
  2. import logging
  3. from typing import Any, Optional
  4. from embedchain.loaders.base_loader import BaseLoader
  5. logger = logging.getLogger(__name__)
  6. class PostgresLoader(BaseLoader):
  7. def __init__(self, config: Optional[dict[str, Any]] = None):
  8. super().__init__()
  9. if not config:
  10. raise ValueError(f"Must provide the valid config. Received: {config}")
  11. self.connection = None
  12. self.cursor = None
  13. self._setup_loader(config=config)
  14. def _setup_loader(self, config: dict[str, Any]):
  15. try:
  16. import psycopg
  17. except ImportError as e:
  18. raise ImportError(
  19. "Unable to import required packages. \
  20. Run `pip install --upgrade 'embedchain[postgres]'`"
  21. ) from e
  22. if "url" in config:
  23. config_info = config.get("url")
  24. else:
  25. conn_params = []
  26. for key, value in config.items():
  27. conn_params.append(f"{key}={value}")
  28. config_info = " ".join(conn_params)
  29. logger.info(f"Connecting to postrgres sql: {config_info}")
  30. self.connection = psycopg.connect(conninfo=config_info)
  31. self.cursor = self.connection.cursor()
  32. @staticmethod
  33. def _check_query(query):
  34. if not isinstance(query, str):
  35. raise ValueError(
  36. f"Invalid postgres query: {query}. Provide the valid source to add from postgres, make sure you are following `https://docs.embedchain.ai/data-sources/postgres`", # noqa:E501
  37. )
  38. def load_data(self, query):
  39. self._check_query(query)
  40. try:
  41. data = []
  42. data_content = []
  43. self.cursor.execute(query)
  44. results = self.cursor.fetchall()
  45. for result in results:
  46. doc_content = str(result)
  47. data.append({"content": doc_content, "meta_data": {"url": query}})
  48. data_content.append(doc_content)
  49. doc_id = hashlib.sha256((query + ", ".join(data_content)).encode()).hexdigest()
  50. return {
  51. "doc_id": doc_id,
  52. "data": data,
  53. }
  54. except Exception as e:
  55. raise ValueError(f"Failed to load data using query={query} with: {e}")
  56. def close_connection(self):
  57. if self.cursor:
  58. self.cursor.close()
  59. self.cursor = None
  60. if self.connection:
  61. self.connection.close()
  62. self.connection = None