notion.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import hashlib
  2. import logging
  3. import os
  4. from typing import Any, Optional
  5. import requests
  6. from embedchain.helpers.json_serializable import register_deserializable
  7. from embedchain.loaders.base_loader import BaseLoader
  8. from embedchain.utils.misc import clean_string
  9. logger = logging.getLogger(__name__)
  10. class NotionDocument:
  11. """
  12. A simple Document class to hold the text and additional information of a page.
  13. """
  14. def __init__(self, text: str, extra_info: dict[str, Any]):
  15. self.text = text
  16. self.extra_info = extra_info
  17. class NotionPageLoader:
  18. """
  19. Notion Page Loader.
  20. Reads a set of Notion pages.
  21. """
  22. BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
  23. def __init__(self, integration_token: Optional[str] = None) -> None:
  24. """Initialize with Notion integration token."""
  25. if integration_token is None:
  26. integration_token = os.getenv("NOTION_INTEGRATION_TOKEN")
  27. if integration_token is None:
  28. raise ValueError(
  29. "Must specify `integration_token` or set environment " "variable `NOTION_INTEGRATION_TOKEN`."
  30. )
  31. self.token = integration_token
  32. self.headers = {
  33. "Authorization": "Bearer " + self.token,
  34. "Content-Type": "application/json",
  35. "Notion-Version": "2022-06-28",
  36. }
  37. def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
  38. """Read a block from Notion."""
  39. done = False
  40. result_lines_arr = []
  41. cur_block_id = block_id
  42. while not done:
  43. block_url = self.BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
  44. res = requests.get(block_url, headers=self.headers)
  45. data = res.json()
  46. for result in data["results"]:
  47. result_type = result["type"]
  48. result_obj = result[result_type]
  49. cur_result_text_arr = []
  50. if "rich_text" in result_obj:
  51. for rich_text in result_obj["rich_text"]:
  52. if "text" in rich_text:
  53. text = rich_text["text"]["content"]
  54. prefix = "\t" * num_tabs
  55. cur_result_text_arr.append(prefix + text)
  56. result_block_id = result["id"]
  57. has_children = result["has_children"]
  58. if has_children:
  59. children_text = self._read_block(result_block_id, num_tabs=num_tabs + 1)
  60. cur_result_text_arr.append(children_text)
  61. cur_result_text = "\n".join(cur_result_text_arr)
  62. result_lines_arr.append(cur_result_text)
  63. if data["next_cursor"] is None:
  64. done = True
  65. else:
  66. cur_block_id = data["next_cursor"]
  67. result_lines = "\n".join(result_lines_arr)
  68. return result_lines
  69. def load_data(self, page_ids: list[str]) -> list[NotionDocument]:
  70. """Load data from the given list of page IDs."""
  71. docs = []
  72. for page_id in page_ids:
  73. page_text = self._read_block(page_id)
  74. docs.append(NotionDocument(text=page_text, extra_info={"page_id": page_id}))
  75. return docs
  76. @register_deserializable
  77. class NotionLoader(BaseLoader):
  78. def load_data(self, source):
  79. """Load data from a Notion URL."""
  80. id = source[-32:]
  81. formatted_id = f"{id[:8]}-{id[8:12]}-{id[12:16]}-{id[16:20]}-{id[20:]}"
  82. logger.debug(f"Extracted notion page id as: {formatted_id}")
  83. integration_token = os.getenv("NOTION_INTEGRATION_TOKEN")
  84. reader = NotionPageLoader(integration_token=integration_token)
  85. documents = reader.load_data(page_ids=[formatted_id])
  86. raw_text = documents[0].text
  87. text = clean_string(raw_text)
  88. doc_id = hashlib.sha256((text + source).encode()).hexdigest()
  89. return {
  90. "doc_id": doc_id,
  91. "data": [
  92. {
  93. "content": text,
  94. "meta_data": {"url": f"notion-{formatted_id}"},
  95. }
  96. ],
  97. }