mysql.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import hashlib
  2. import logging
  3. from typing import Any, Dict, Optional
  4. from embedchain.loaders.base_loader import BaseLoader
  5. from embedchain.utils import clean_string
  6. class MySQLLoader(BaseLoader):
  7. def __init__(self, config: Optional[Dict[str, Any]]):
  8. super().__init__()
  9. if not config:
  10. raise ValueError(
  11. f"Invalid sql config: {config}.",
  12. "Provide the correct config, refer `https://docs.embedchain.ai/data-sources/mysql`.",
  13. )
  14. self.config = config
  15. self.connection = None
  16. self.cursor = None
  17. self._setup_loader(config=config)
  18. def _setup_loader(self, config: Dict[str, Any]):
  19. try:
  20. import mysql.connector as sqlconnector
  21. except ImportError as e:
  22. raise ImportError(
  23. "Unable to import required packages for MySQL loader. Run `pip install --upgrade 'embedchain[mysql]'`." # noqa: E501
  24. ) from e
  25. try:
  26. self.connection = sqlconnector.connection.MySQLConnection(**config)
  27. self.cursor = self.connection.cursor()
  28. except (sqlconnector.Error, IOError) as err:
  29. logging.info(f"Connection failed: {err}")
  30. raise ValueError(
  31. f"Unable to connect with the given config: {config}.",
  32. "Please provide the correct configuration to load data from you MySQL DB. \
  33. Refer `https://docs.embedchain.ai/data-sources/mysql`.",
  34. )
  35. def _check_query(self, query):
  36. if not isinstance(query, str):
  37. raise ValueError(
  38. f"Invalid mysql query: {query}",
  39. "Provide the valid query to add from mysql, \
  40. make sure you are following `https://docs.embedchain.ai/data-sources/mysql`",
  41. )
  42. def load_data(self, query):
  43. self._check_query(query=query)
  44. data = []
  45. data_content = []
  46. self.cursor.execute(query)
  47. rows = self.cursor.fetchall()
  48. for row in rows:
  49. doc_content = clean_string(str(row))
  50. data.append({"content": doc_content, "meta_data": {"url": query}})
  51. data_content.append(doc_content)
  52. doc_id = hashlib.sha256((query + ", ".join(data_content)).encode()).hexdigest()
  53. return {
  54. "doc_id": doc_id,
  55. "data": data,
  56. }