# app.py — Embedchain "Chat with PDF" Streamlit app
  1. import os
  2. import queue
  3. import re
  4. import tempfile
  5. import threading
  6. import streamlit as st
  7. from embedchain import Pipeline as App
  8. from embedchain.config import BaseLlmConfig
  9. from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
  10. generate)
  11. @st.cache_resource
  12. def embedchain_bot():
  13. return App.from_config(
  14. config={
  15. "llm": {
  16. "provider": "openai",
  17. "config": {
  18. "model": "gpt-3.5-turbo-1106",
  19. "temperature": 0.5,
  20. "max_tokens": 1000,
  21. "top_p": 1,
  22. "stream": True,
  23. },
  24. },
  25. "vectordb": {
  26. "provider": "chroma",
  27. "config": {"collection_name": "chat-pdf", "dir": "db", "allow_reset": True},
  28. },
  29. "chunker": {"chunk_size": 2000, "chunk_overlap": 0, "length_function": "len"},
  30. }
  31. )
  32. @st.cache_data
  33. def update_openai_key():
  34. os.environ["OPENAI_API_KEY"] = st.session_state.chatbot_api_key
  35. with st.sidebar:
  36. openai_access_token = st.text_input(
  37. "OpenAI API Key", value=os.environ.get("OPENAI_API_KEY"), key="chatbot_api_key", type="password"
  38. ) # noqa: E501
  39. "WE DO NOT STORE YOUR OPENAI KEY."
  40. "Just paste your OpenAI API key here and we'll use it to power the chatbot. [Get your OpenAI API key](https://platform.openai.com/api-keys)" # noqa: E501
  41. if openai_access_token:
  42. update_openai_key()
  43. pdf_files = st.file_uploader("Upload your PDF files", accept_multiple_files=True, type="pdf")
  44. add_pdf_files = st.session_state.get("add_pdf_files", [])
  45. for pdf_file in pdf_files:
  46. file_name = pdf_file.name
  47. if file_name in add_pdf_files:
  48. continue
  49. try:
  50. if not os.environ.get("OPENAI_API_KEY"):
  51. st.error("Please enter your OpenAI API Key")
  52. st.stop()
  53. app = embedchain_bot()
  54. temp_file_name = None
  55. with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=file_name, suffix=".pdf") as f:
  56. f.write(pdf_file.getvalue())
  57. temp_file_name = f.name
  58. if temp_file_name:
  59. st.markdown(f"Adding {file_name} to knowledge base...")
  60. app.add(temp_file_name, data_type="pdf_file")
  61. st.markdown("")
  62. add_pdf_files.append(file_name)
  63. os.remove(temp_file_name)
  64. st.session_state.messages.append({"role": "assistant", "content": f"Added {file_name} to knowledge base!"})
  65. except Exception as e:
  66. st.error(f"Error adding {file_name} to knowledge base: {e}")
  67. st.stop()
  68. st.session_state["add_pdf_files"] = add_pdf_files
  69. st.title("📄 Embedchain - Chat with PDF")
  70. styled_caption = '<p style="font-size: 17px; color: #aaa;">🚀 An <a href="https://github.com/embedchain/embedchain">Embedchain</a> app powered by OpenAI!</p>' # noqa: E501
  71. st.markdown(styled_caption, unsafe_allow_html=True)
  72. if "messages" not in st.session_state:
  73. st.session_state.messages = [
  74. {
  75. "role": "assistant",
  76. "content": """
  77. Hi! I'm chatbot powered by Embedchain, which can answer questions about your pdf documents.\n
  78. Upload your pdf documents here and I'll answer your questions about them!
  79. """,
  80. }
  81. ]
  82. for message in st.session_state.messages:
  83. with st.chat_message(message["role"]):
  84. st.markdown(message["content"])
  85. if prompt := st.chat_input("Ask me anything!"):
  86. if not os.environ.get("OPENAI_API_KEY"):
  87. st.error("Please enter your OpenAI API Key", icon="🤖")
  88. st.stop()
  89. app = embedchain_bot()
  90. with st.chat_message("user"):
  91. st.session_state.messages.append({"role": "user", "content": prompt})
  92. st.markdown(prompt)
  93. with st.chat_message("assistant"):
  94. msg_placeholder = st.empty()
  95. msg_placeholder.markdown("Thinking...")
  96. full_response = ""
  97. q = queue.Queue()
  98. def app_response(result):
  99. llm_config = app.llm.config.as_dict()
  100. llm_config["callbacks"] = [StreamingStdOutCallbackHandlerYield(q=q)]
  101. config = BaseLlmConfig(**llm_config)
  102. answer, citations = app.chat(prompt, config=config, citations=True)
  103. result["answer"] = answer
  104. result["citations"] = citations
  105. results = {}
  106. thread = threading.Thread(target=app_response, args=(results,))
  107. thread.start()
  108. for answer_chunk in generate(q):
  109. full_response += answer_chunk
  110. msg_placeholder.markdown(full_response)
  111. thread.join()
  112. answer, citations = results["answer"], results["citations"]
  113. if citations:
  114. full_response += "\n\n**Sources**:\n"
  115. sources = []
  116. for i, citation in enumerate(citations):
  117. source = citation[1]
  118. pattern = re.compile(r"([^/]+)\.[^\.]+\.pdf$")
  119. match = pattern.search(source)
  120. if match:
  121. source = match.group(1) + ".pdf"
  122. sources.append(source)
  123. sources = list(set(sources))
  124. for source in sources:
  125. full_response += f"- {source}\n"
  126. msg_placeholder.markdown(full_response)
  127. print("Answer: ", answer)
  128. st.session_state.messages.append({"role": "assistant", "content": answer})