postgres.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. import hashlib
  2. import logging
  3. from typing import Any, Dict, Optional
  4. from embedchain.loaders.base_loader import BaseLoader
  class PostgresLoader(BaseLoader):
      """Load the rows returned by a PostgreSQL query as embedchain documents.

      A psycopg connection and a cursor are opened when the loader is
      constructed and kept on the instance (``self.connection`` /
      ``self.cursor``); call ``close_connection`` to release them.
      """

      def __init__(self, config: Optional[Dict[str, Any]] = None):
          """Validate ``config`` and connect immediately.

          Args:
              config: Either ``{"url": <connection string>}`` or a mapping of
                  connection parameters (e.g. ``host``, ``port``, ``dbname``,
                  ``user``, ``password``) that is turned into a libpq conninfo
                  string. When ``"url"`` is present, every other key is ignored.

          Raises:
              ValueError: If ``config`` is ``None`` or empty.
              ImportError: If ``psycopg`` is not installed.
          """
          super().__init__()
          if not config:
              raise ValueError(f"Must provide the valid config. Received: {config}")
          self.connection = None
          self.cursor = None
          self._setup_loader(config=config)

      def _setup_loader(self, config: Dict[str, Any]):
          """Build a conninfo string from ``config``, connect, and open a cursor."""
          try:
              import psycopg
              from psycopg.conninfo import make_conninfo
          except ImportError as e:
              raise ImportError(
                  "Unable to import required packages. "
                  "Run `pip install --upgrade 'embedchain[postgres]'`"
              ) from e

          if "url" in config:
              config_info = config.get("url")
          else:
              # make_conninfo quotes/escapes values (spaces, quotes, backslashes)
              # that a naive "key=value" join would corrupt.
              config_info = make_conninfo(**config)

          # Never log the raw conninfo: it may carry the password.
          logging.info("Connecting to postgres sql: %s", self._redact_conninfo(config_info))
          self.connection = psycopg.connect(conninfo=config_info)
          self.cursor = self.connection.cursor()

      @staticmethod
      def _redact_conninfo(conninfo: str) -> str:
          """Return ``conninfo`` as ``key=value`` pairs with password-like values masked.

          Falls back to a placeholder when the string cannot be parsed, so that
          logging never raises and never echoes the unparsed (possibly secret) text.
          """
          try:
              from psycopg.conninfo import conninfo_to_dict

              params = conninfo_to_dict(conninfo)
          except Exception:
              return "<unparseable connection string>"
          return " ".join(
              f"{key}={'****' if 'password' in key else value}" for key, value in params.items()
          )

      def _check_query(self, query):
          """Raise ``ValueError`` unless ``query`` is a string."""
          if not isinstance(query, str):
              raise ValueError(
                  f"Invalid postgres query: {query}. "
                  "Provide the valid source to add from postgres, "
                  "make sure you are following `https://docs.embedchain.ai/data-sources/postgres`"
              )

      def load_data(self, query):
          """Execute ``query`` and return one document per result row.

          Args:
              query: SQL text executed verbatim. It is not parameterized, so it
                  must come from a trusted source.

          Returns:
              ``{"doc_id": <sha256 hex>, "data": [...]}`` where each item of
              ``data`` is ``{"content": str(row), "meta_data": {"url": ...}}``.
              ``doc_id`` hashes the query text together with every row's content,
              so it changes whenever the result set changes.

          Raises:
              ValueError: If ``query`` is not a string, or if executing the query
                  or fetching its rows fails (the original error is chained).
          """
          self._check_query(query)
          try:
              self.cursor.execute(query)
              results = self.cursor.fetchall()
          except Exception as e:
              # A failed statement leaves a non-autocommit connection inside an
              # aborted transaction; roll back so later load_data calls still work.
              self._rollback_quietly()
              raise ValueError(f"Failed to load data using query={query} with: {e}") from e

          source_url = f"postgres_query-({query})"
          data_content = [str(row) for row in results]
          data = [{"content": content, "meta_data": {"url": source_url}} for content in data_content]
          doc_id = hashlib.sha256((query + ", ".join(data_content)).encode()).hexdigest()
          return {
              "doc_id": doc_id,
              "data": data,
          }

      def _rollback_quietly(self):
          """Best-effort rollback after a failed query; a failure here is logged, not raised."""
          if self.connection is None:
              return
          try:
              self.connection.rollback()
          except Exception:
              logging.warning("Rollback after failed postgres query also failed", exc_info=True)

      def close_connection(self):
          """Close the cursor and connection (if open) and reset both to ``None``.

          Safe to call more than once. The connection is closed even if closing
          the cursor raises.
          """
          cursor, self.cursor = self.cursor, None
          connection, self.connection = self.connection, None
          try:
              if cursor is not None:
                  cursor.close()
          finally:
              if connection is not None:
                  connection.close()