Ver código fonte

[Chat PDF] update chat pdf logic (#1053)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 ano atrás
pai
commit
da388b679f
1 arquivos alterados com 28 adições e 17 exclusões
  1. 28 17
      examples/chat-pdf/app.py

+ 28 - 17
examples/chat-pdf/app.py

@@ -12,8 +12,7 @@ from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
                                           generate)
 
 
-@st.cache_resource
-def embedchain_bot():
+def embedchain_bot(db_path, api_key):
     return App.from_config(
         config={
             "llm": {
@@ -24,31 +23,43 @@ def embedchain_bot():
                     "max_tokens": 1000,
                     "top_p": 1,
                     "stream": True,
+                    "api_key": api_key,
                 },
             },
             "vectordb": {
                 "provider": "chroma",
-                "config": {"collection_name": "chat-pdf", "dir": "db", "allow_reset": True},
+                "config": {"collection_name": "chat-pdf", "dir": db_path, "allow_reset": True},
             },
+            "embedder": {"provider": "openai", "config": {"api_key": api_key}},
             "chunker": {"chunk_size": 2000, "chunk_overlap": 0, "length_function": "len"},
         }
     )
 
 
-@st.cache_data
-def update_openai_key():
-    os.environ["OPENAI_API_KEY"] = st.session_state.chatbot_api_key
+def get_db_path():
+    tmpdirname = tempfile.mkdtemp()
+    return tmpdirname
+
+
+def get_ec_app(api_key):
+    if "app" in st.session_state:
+        print("Found app in session state")
+        app = st.session_state.app
+    else:
+        print("Creating app")
+        db_path = get_db_path()
+        app = embedchain_bot(db_path, api_key)
+        st.session_state.app = app
+    return app
 
 
 with st.sidebar:
-    openai_access_token = st.text_input(
-        "OpenAI API Key", value=os.environ.get("OPENAI_API_KEY"), key="chatbot_api_key", type="password"
-    )  # noqa: E501
+    openai_access_token = st.text_input("OpenAI API Key", key="api_key", type="password")
     "WE DO NOT STORE YOUR OPENAI KEY."
     "Just paste your OpenAI API key here and we'll use it to power the chatbot. [Get your OpenAI API key](https://platform.openai.com/api-keys)"  # noqa: E501
 
-    if openai_access_token:
-        update_openai_key()
+    if st.session_state.api_key:
+        app = get_ec_app(st.session_state.api_key)
 
     pdf_files = st.file_uploader("Upload your PDF files", accept_multiple_files=True, type="pdf")
     add_pdf_files = st.session_state.get("add_pdf_files", [])
@@ -57,10 +68,9 @@ with st.sidebar:
         if file_name in add_pdf_files:
             continue
         try:
-            if not os.environ.get("OPENAI_API_KEY"):
+            if not st.session_state.api_key:
                 st.error("Please enter your OpenAI API Key")
                 st.stop()
-            app = embedchain_bot()
             temp_file_name = None
             with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=file_name, suffix=".pdf") as f:
                 f.write(pdf_file.getvalue())
@@ -97,11 +107,12 @@ for message in st.session_state.messages:
         st.markdown(message["content"])
 
 if prompt := st.chat_input("Ask me anything!"):
-    if not os.environ.get("OPENAI_API_KEY"):
+    if not st.session_state.api_key:
         st.error("Please enter your OpenAI API Key", icon="🤖")
         st.stop()
 
-    app = embedchain_bot()
+    app = get_ec_app(st.session_state.api_key)
+
     with st.chat_message("user"):
         st.session_state.messages.append({"role": "user", "content": prompt})
         st.markdown(prompt)
@@ -146,5 +157,5 @@ if prompt := st.chat_input("Ask me anything!"):
                 full_response += f"- {source}\n"
 
         msg_placeholder.markdown(full_response)
-        print("Answer: ", answer)
-        st.session_state.messages.append({"role": "assistant", "content": answer})
+        print("Answer: ", full_response)
+        st.session_state.messages.append({"role": "assistant", "content": full_response})