mysql.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import hashlib
  2. import logging
  3. from typing import Any, Optional
  4. from embedchain.loaders.base_loader import BaseLoader
  5. from embedchain.utils.misc import clean_string
  6. logger = logging.getLogger(__name__)
  7. class MySQLLoader(BaseLoader):
  8. def __init__(self, config: Optional[dict[str, Any]]):
  9. super().__init__()
  10. if not config:
  11. raise ValueError(
  12. f"Invalid sql config: {config}.",
  13. "Provide the correct config, refer `https://docs.embedchain.ai/data-sources/mysql`.",
  14. )
  15. self.config = config
  16. self.connection = None
  17. self.cursor = None
  18. self._setup_loader(config=config)
  19. def _setup_loader(self, config: dict[str, Any]):
  20. try:
  21. import mysql.connector as sqlconnector
  22. except ImportError as e:
  23. raise ImportError(
  24. "Unable to import required packages for MySQL loader. Run `pip install --upgrade 'embedchain[mysql]'`." # noqa: E501
  25. ) from e
  26. try:
  27. self.connection = sqlconnector.connection.MySQLConnection(**config)
  28. self.cursor = self.connection.cursor()
  29. except (sqlconnector.Error, IOError) as err:
  30. logger.info(f"Connection failed: {err}")
  31. raise ValueError(
  32. f"Unable to connect with the given config: {config}.",
  33. "Please provide the correct configuration to load data from you MySQL DB. \
  34. Refer `https://docs.embedchain.ai/data-sources/mysql`.",
  35. )
  36. @staticmethod
  37. def _check_query(query):
  38. if not isinstance(query, str):
  39. raise ValueError(
  40. f"Invalid mysql query: {query}",
  41. "Provide the valid query to add from mysql, \
  42. make sure you are following `https://docs.embedchain.ai/data-sources/mysql`",
  43. )
  44. def load_data(self, query):
  45. self._check_query(query=query)
  46. data = []
  47. data_content = []
  48. self.cursor.execute(query)
  49. rows = self.cursor.fetchall()
  50. for row in rows:
  51. doc_content = clean_string(str(row))
  52. data.append({"content": doc_content, "meta_data": {"url": query}})
  53. data_content.append(doc_content)
  54. doc_id = hashlib.sha256((query + ", ".join(data_content)).encode()).hexdigest()
  55. return {
  56. "doc_id": doc_id,
  57. "data": data,
  58. }