import os
import queue
import re
import tempfile
import threading

import streamlit as st
from embedchain import App
from embedchain.config import BaseLlmConfig
from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
                                          generate)


def embedchain_bot(db_path, api_key):
    # Build an Embedchain app: OpenAI for the LLM and embeddings, Chroma for storage.
    return App.from_config(
        config={
            "llm": {
                "provider": "openai",
                "config": {
                    "model": "gpt-3.5-turbo-1106",
                    "temperature": 0.5,
                    "max_tokens": 1000,
                    "top_p": 1,
                    "stream": True,
                    "api_key": api_key,
                },
            },
            "vectordb": {
                "provider": "chroma",
                "config": {"collection_name": "chat-pdf", "dir": db_path, "allow_reset": True},
            },
            "embedder": {"provider": "openai", "config": {"api_key": api_key}},
            "chunker": {"chunk_size": 2000, "chunk_overlap": 0, "length_function": "len"},
        }
    )


def get_db_path():
    # Each session gets a fresh temporary directory for the Chroma database.
    tmpdirname = tempfile.mkdtemp()
    return tmpdirname


def get_ec_app(api_key):
    # Reuse the app cached in the Streamlit session state, creating it on first use.
    if "app" in st.session_state:
        print("Found app in session state")
        app = st.session_state.app
    else:
        print("Creating app")
        db_path = get_db_path()
        app = embedchain_bot(db_path, api_key)
        st.session_state.app = app
    return app


with st.sidebar:
    openai_access_token = st.text_input("OpenAI API Key", key="api_key", type="password")
    # Bare string literals are rendered by Streamlit's "magic" output.
    "WE DO NOT STORE YOUR OPENAI KEY."
    "Just paste your OpenAI API key here and we'll use it to power the chatbot. [Get your OpenAI API key](https://platform.openai.com/api-keys)"  # noqa: E501

    if st.session_state.api_key:
        app = get_ec_app(st.session_state.api_key)

    pdf_files = st.file_uploader("Upload your PDF files", accept_multiple_files=True, type="pdf")
    add_pdf_files = st.session_state.get("add_pdf_files", [])

    for pdf_file in pdf_files:
        file_name = pdf_file.name
        if file_name in add_pdf_files:
            continue
        try:
            if not st.session_state.api_key:
                st.error("Please enter your OpenAI API Key")
                st.stop()
            # Write the upload to a temporary file so Embedchain can ingest it by path.
            temp_file_name = None
            with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=file_name, suffix=".pdf") as f:
                f.write(pdf_file.getvalue())
                temp_file_name = f.name
            if temp_file_name:
                st.markdown(f"Adding {file_name} to knowledge base...")
                app.add(temp_file_name, data_type="pdf_file")
                st.markdown("")
                add_pdf_files.append(file_name)
                os.remove(temp_file_name)
            st.session_state.messages.append({"role": "assistant", "content": f"Added {file_name} to knowledge base!"})
        except Exception as e:
            st.error(f"Error adding {file_name} to knowledge base: {e}")
            st.stop()
    st.session_state["add_pdf_files"] = add_pdf_files

st.title("📄 Embedchain - Chat with PDF")
styled_caption = '<p style="font-size: 17px; color: #aaa;">🚀 An Embedchain app powered by OpenAI!</p>'  # noqa: E501
st.markdown(styled_caption, unsafe_allow_html=True)

if "messages" not in st.session_state:
    st.session_state.messages = [
        {
            "role": "assistant",
            "content": """
                Hi! I'm a chatbot powered by Embedchain, which can answer questions about your PDF documents.\n
                Upload your PDF documents here and I'll answer your questions about them!
            """,
        }
    ]

for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("Ask me anything!"):
    if not st.session_state.api_key:
        st.error("Please enter your OpenAI API Key", icon="🤖")
        st.stop()

    app = get_ec_app(st.session_state.api_key)

    with st.chat_message("user"):
        st.session_state.messages.append({"role": "user", "content": prompt})
        st.markdown(prompt)

    with st.chat_message("assistant"):
        msg_placeholder = st.empty()
        msg_placeholder.markdown("Thinking...")
        full_response = ""

        q = queue.Queue()

        def app_response(result):
            # Run the chat on a worker thread; the streaming callback pushes
            # tokens into the queue as they are generated.
            llm_config = app.llm.config.as_dict()
            llm_config["callbacks"] = [StreamingStdOutCallbackHandlerYield(q=q)]
            config = BaseLlmConfig(**llm_config)
            answer, citations = app.chat(prompt, config=config, citations=True)
            result["answer"] = answer
            result["citations"] = citations

        results = {}
        thread = threading.Thread(target=app_response, args=(results,))
        thread.start()

        # Stream tokens into the placeholder as they arrive from the worker thread.
        for answer_chunk in generate(q):
            full_response += answer_chunk
            msg_placeholder.markdown(full_response)

        thread.join()
        answer, citations = results["answer"], results["citations"]

        if citations:
            full_response += "\n\n**Sources**:\n"
            sources = []
            for citation in citations:
                source = citation[1]["url"]
                # The temp file name looks like "<name>.pdf<random>.pdf";
                # recover the original file name for display.
                pattern = re.compile(r"([^/]+)\.[^\.]+\.pdf$")
                match = pattern.search(source)
                if match:
                    source = match.group(1) + ".pdf"
                sources.append(source)
            sources = list(set(sources))
            for source in sources:
                full_response += f"- {source}\n"

        msg_placeholder.markdown(full_response)
        print("Answer: ", full_response)
        st.session_state.messages.append({"role": "assistant", "content": full_response})
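
# To try the app locally (a sketch, assuming this file is saved as app.py and the
# dependencies are installed, e.g. `pip install streamlit embedchain`):
#
#   streamlit run app.py
#
# Paste an OpenAI API key in the sidebar, upload one or more PDFs, and start chatting.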