embedchain.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import chromadb
  2. import openai
  3. import os
  4. from chromadb.utils import embedding_functions
  5. from dotenv import load_dotenv
  6. from langchain.docstore.document import Document
  7. from langchain.embeddings.openai import OpenAIEmbeddings
  8. from embedchain.loaders.youtube_video import YoutubeVideoLoader
  9. from embedchain.loaders.pdf_file import PdfFileLoader
  10. from embedchain.loaders.website import WebsiteLoader
  11. from embedchain.chunkers.youtube_video import YoutubeVideoChunker
  12. from embedchain.chunkers.pdf_file import PdfFileChunker
  13. from embedchain.chunkers.website import WebsiteChunker
# Load environment variables (e.g. OPENAI_API_KEY) from a local .env file.
load_dotenv()

# LangChain embedding wrapper; reads OPENAI_API_KEY from the environment.
# NOTE(review): not referenced anywhere in this file — presumably imported
# by other modules; verify before removing.
embeddings = OpenAIEmbeddings()

# Persist the Chroma database under ./db relative to the *current working
# directory*. NOTE(review): this moves with wherever the process is started
# from — consider anchoring to the module's own path instead.
ABS_PATH = os.getcwd()
DB_DIR = os.path.join(ABS_PATH, "db")

# Embedding function Chroma applies when adding and querying documents.
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"),
model_name="text-embedding-ada-002"
)
  22. class EmbedChain:
  23. def __init__(self):
  24. self.chromadb_client = self._get_or_create_db()
  25. self.collection = self._get_or_create_collection()
  26. self.user_asks = []
  27. def _get_loader(self, data_type):
  28. loaders = {
  29. 'youtube_video': YoutubeVideoLoader(),
  30. 'pdf_file': PdfFileLoader(),
  31. 'website': WebsiteLoader()
  32. }
  33. if data_type in loaders:
  34. return loaders[data_type]
  35. else:
  36. raise ValueError(f"Unsupported data type: {data_type}")
  37. def _get_chunker(self, data_type):
  38. chunkers = {
  39. 'youtube_video': YoutubeVideoChunker(),
  40. 'pdf_file': PdfFileChunker(),
  41. 'website': WebsiteChunker()
  42. }
  43. if data_type in chunkers:
  44. return chunkers[data_type]
  45. else:
  46. raise ValueError(f"Unsupported data type: {data_type}")
  47. def add(self, data_type, url):
  48. loader = self._get_loader(data_type)
  49. chunker = self._get_chunker(data_type)
  50. self.user_asks.append([data_type, url])
  51. self.load_and_embed(loader, chunker, url)
  52. def _get_or_create_db(self):
  53. client_settings = chromadb.config.Settings(
  54. chroma_db_impl="duckdb+parquet",
  55. persist_directory=DB_DIR,
  56. anonymized_telemetry=False
  57. )
  58. return chromadb.Client(client_settings)
  59. def _get_or_create_collection(self):
  60. return self.chromadb_client.get_or_create_collection(
  61. 'embedchain_store', embedding_function=openai_ef,
  62. )
  63. def load_embeddings_to_db(self, loader, chunker, url):
  64. embeddings_data = chunker.create_chunks(loader, url)
  65. documents = embeddings_data["documents"]
  66. metadatas = embeddings_data["metadatas"]
  67. ids = embeddings_data["ids"]
  68. self.collection.add(
  69. documents=documents,
  70. metadatas=metadatas,
  71. ids=ids
  72. )
  73. print(f"Docs count: {self.collection.count()}")
  74. def load_and_embed(self, loader, chunker, url):
  75. return self.load_embeddings_to_db(loader, chunker, url)
  76. def _format_result(self, results):
  77. return [
  78. (Document(page_content=result[0], metadata=result[1] or {}), result[2])
  79. for result in zip(
  80. results["documents"][0],
  81. results["metadatas"][0],
  82. results["distances"][0],
  83. )
  84. ]
  85. def get_openai_answer(self, prompt):
  86. messages = []
  87. messages.append({
  88. "role": "user", "content": prompt
  89. })
  90. response = openai.ChatCompletion.create(
  91. model="gpt-3.5-turbo-0613",
  92. messages=messages,
  93. temperature=0,
  94. max_tokens=1000,
  95. top_p=1,
  96. )
  97. return response["choices"][0]["message"]["content"]
  98. def get_answer_from_llm(self, query, context):
  99. prompt = f"""Use the following pieces of context to answer the query at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
  100. {context}
  101. Query: {query}
  102. Helpful Answer:
  103. """
  104. answer = self.get_openai_answer(prompt)
  105. return answer
  106. def query(self, input_query):
  107. result = self.collection.query(
  108. query_texts=[input_query,],
  109. n_results=1,
  110. )
  111. result_formatted = self._format_result(result)
  112. answer = self.get_answer_from_llm(input_query, result_formatted[0][0].page_content)
  113. return answer
class App(EmbedChain):
    """Public entry point; a thin subclass of EmbedChain that adds no
    behavior of its own."""
    pass