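# Streamlit app: chat with your PDF documents using Embedchain and OpenAI.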
import os
import queue
import re
import tempfile
import threading

import streamlit as st
from embedchain import App
from embedchain.config import BaseLlmConfig
from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
                                          generate)
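

# Configure the Embedchain app: OpenAI provides the LLM and embeddings, Chroma
# (persisted under db_path) is the vector database, and LLM output is streamed.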
def embedchain_bot(db_path, api_key):
    return App.from_config(
        config={
            "llm": {
                "provider": "openai",
                "config": {
                    "model": "gpt-3.5-turbo-1106",
                    "temperature": 0.5,
                    "max_tokens": 1000,
                    "top_p": 1,
                    "stream": True,
                    "api_key": api_key,
                },
            },
            "vectordb": {
                "provider": "chroma",
                "config": {"collection_name": "chat-pdf", "dir": db_path, "allow_reset": True},
            },
            "embedder": {"provider": "openai", "config": {"api_key": api_key}},
            "chunker": {"chunk_size": 2000, "chunk_overlap": 0, "length_function": "len"},
        }
    )
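

# Each session stores its Chroma database in a fresh temporary directory.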
def get_db_path():
    tmpdirname = tempfile.mkdtemp()
    return tmpdirname
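

# Create the app once per session and cache it in st.session_state so that
# Streamlit reruns reuse the same instance.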
def get_ec_app(api_key):
    if "app" in st.session_state:
        print("Found app in session state")
        app = st.session_state.app
    else:
        print("Creating app")
        db_path = get_db_path()
        app = embedchain_bot(db_path, api_key)
        st.session_state.app = app
    return app
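

# Sidebar: collect the OpenAI API key and ingest uploaded PDFs.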
with st.sidebar:
    openai_access_token = st.text_input("OpenAI API Key", key="api_key", type="password")
    "WE DO NOT STORE YOUR OPENAI KEY."
    "Just paste your OpenAI API key here and we'll use it to power the chatbot. [Get your OpenAI API key](https://platform.openai.com/api-keys)"  # noqa: E501

    if st.session_state.api_key:
        app = get_ec_app(st.session_state.api_key)

    pdf_files = st.file_uploader("Upload your PDF files", accept_multiple_files=True, type="pdf")
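    # Track which files were already ingested so reruns don't re-add them.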
    add_pdf_files = st.session_state.get("add_pdf_files", [])
    for pdf_file in pdf_files:
        file_name = pdf_file.name
        if file_name in add_pdf_files:
            continue
        try:
            if not st.session_state.api_key:
                st.error("Please enter your OpenAI API Key")
                st.stop()
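            # Write the upload to a temporary file so Embedchain can read it from disk.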
            temp_file_name = None
            with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=file_name, suffix=".pdf") as f:
                f.write(pdf_file.getvalue())
                temp_file_name = f.name
            if temp_file_name:
                st.markdown(f"Adding {file_name} to knowledge base...")
                app.add(temp_file_name, data_type="pdf_file")
                st.markdown("")
                add_pdf_files.append(file_name)
                os.remove(temp_file_name)
            st.session_state.messages.append({"role": "assistant", "content": f"Added {file_name} to knowledge base!"})
        except Exception as e:
            st.error(f"Error adding {file_name} to knowledge base: {e}")
            st.stop()
    st.session_state["add_pdf_files"] = add_pdf_files
st.title("📄 Embedchain - Chat with PDF")
styled_caption = '<p style="font-size: 17px; color: #aaa;">🚀 An <a href="https://github.com/embedchain/embedchain">Embedchain</a> app powered by OpenAI!</p>'  # noqa: E501
st.markdown(styled_caption, unsafe_allow_html=True)
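
# Seed the chat history with a greeting on first run.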
- if "messages" not in st.session_state:
- st.session_state.messages = [
- {
- "role": "assistant",
- "content": """
- Hi! I'm chatbot powered by Embedchain, which can answer questions about your pdf documents.\n
- Upload your pdf documents here and I'll answer your questions about them!
- """,
- }
- ]
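
# Replay the conversation history on every rerun.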
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
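
# Handle a new question from the chat input box.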
if prompt := st.chat_input("Ask me anything!"):
    if not st.session_state.api_key:
        st.error("Please enter your OpenAI API Key", icon="🤖")
        st.stop()

    app = get_ec_app(st.session_state.api_key)

    with st.chat_message("user"):
        st.session_state.messages.append({"role": "user", "content": prompt})
        st.markdown(prompt)

    with st.chat_message("assistant"):
        msg_placeholder = st.empty()
        msg_placeholder.markdown("Thinking...")
        full_response = ""
        q = queue.Queue()

        def app_response(result):
            llm_config = app.llm.config.as_dict()
            llm_config["callbacks"] = [StreamingStdOutCallbackHandlerYield(q=q)]
            config = BaseLlmConfig(**llm_config)
            answer, citations = app.chat(prompt, config=config, citations=True)
            result["answer"] = answer
            result["citations"] = citations
        results = {}
        thread = threading.Thread(target=app_response, args=(results,))
        thread.start()
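
        # Render partial output as each chunk arrives.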
        for answer_chunk in generate(q):
            full_response += answer_chunk
            msg_placeholder.markdown(full_response)

        thread.join()
        answer, citations = results["answer"], results["citations"]
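        # List the deduplicated source file names below the answer.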
        if citations:
            full_response += "\n\n**Sources**:\n"
            sources = []
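            # citation[1]["url"] holds the temp-file path; the regex recovers the
            # original file name by stripping the random suffix NamedTemporaryFile added.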
            for citation in citations:
                source = citation[1]["url"]
                pattern = re.compile(r"([^/]+)\.[^\.]+\.pdf$")
                match = pattern.search(source)
                if match:
                    source = match.group(1) + ".pdf"
                sources.append(source)
            sources = list(set(sources))
            for source in sources:
                full_response += f"- {source}\n"
        msg_placeholder.markdown(full_response)
        print("Answer: ", full_response)
        st.session_state.messages.append({"role": "assistant", "content": full_response})
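
# To try it locally (assuming this file is saved as app.py): streamlit run app.py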