mysql.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import hashlib
  2. import logging
  3. from typing import Any, Optional
  4. from embedchain.loaders.base_loader import BaseLoader
  5. from embedchain.utils.misc import clean_string
  6. class MySQLLoader(BaseLoader):
  7. def __init__(self, config: Optional[dict[str, Any]]):
  8. super().__init__()
  9. if not config:
  10. raise ValueError(
  11. f"Invalid sql config: {config}.",
  12. "Provide the correct config, refer `https://docs.embedchain.ai/data-sources/mysql`.",
  13. )
  14. self.config = config
  15. self.connection = None
  16. self.cursor = None
  17. self._setup_loader(config=config)
  18. def _setup_loader(self, config: dict[str, Any]):
  19. try:
  20. import mysql.connector as sqlconnector
  21. except ImportError as e:
  22. raise ImportError(
  23. "Unable to import required packages for MySQL loader. Run `pip install --upgrade 'embedchain[mysql]'`." # noqa: E501
  24. ) from e
  25. try:
  26. self.connection = sqlconnector.connection.MySQLConnection(**config)
  27. self.cursor = self.connection.cursor()
  28. except (sqlconnector.Error, IOError) as err:
  29. logging.info(f"Connection failed: {err}")
  30. raise ValueError(
  31. f"Unable to connect with the given config: {config}.",
  32. "Please provide the correct configuration to load data from you MySQL DB. \
  33. Refer `https://docs.embedchain.ai/data-sources/mysql`.",
  34. )
  35. @staticmethod
  36. def _check_query(query):
  37. if not isinstance(query, str):
  38. raise ValueError(
  39. f"Invalid mysql query: {query}",
  40. "Provide the valid query to add from mysql, \
  41. make sure you are following `https://docs.embedchain.ai/data-sources/mysql`",
  42. )
  43. def load_data(self, query):
  44. self._check_query(query=query)
  45. data = []
  46. data_content = []
  47. self.cursor.execute(query)
  48. rows = self.cursor.fetchall()
  49. for row in rows:
  50. doc_content = clean_string(str(row))
  51. data.append({"content": doc_content, "meta_data": {"url": query}})
  52. data_content.append(doc_content)
  53. doc_id = hashlib.sha256((query + ", ".join(data_content)).encode()).hexdigest()
  54. return {
  55. "doc_id": doc_id,
  56. "data": data,
  57. }