# -*- coding: utf-8 -*-
# @Author: privacy
# @Date: 2024-11-21 13:13:03
# @Last Modified by: privacy
# @Last Modified time: 2024-12-03 13:03:40
import os
import json
from pathlib import Path
from typing import Iterator, Optional, Union

from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseLoader
from langchain_community.document_loaders import TextLoader
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.document_loaders import UnstructuredHTMLLoader
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_community.document_loaders import UnstructuredExcelLoader
from langchain_community.document_loaders import UnstructuredPowerPointLoader
from langchain_community.document_loaders import UnstructuredWordDocumentLoader
from langchain_community.document_loaders import UnstructuredImageLoader
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain_community.document_transformers import LongContextReorder
from langchain.text_splitter import CharacterTextSplitter
from langchain_text_splitters import Language
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_text_splitters import RecursiveJsonSplitter
from langchain_text_splitters import HTMLHeaderTextSplitter
from langchain_text_splitters import MarkdownHeaderTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_community.vectorstores import FAISS
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.document_compressors import LLMChainFilter
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain.storage import InMemoryStore
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import ParentDocumentRetriever
# from langchain.retrievers import KNNRetriever
from langchain_community.retrievers import KNNRetriever
# from langchain.retrievers import TFIDFRetriever
from langchain_community.retrievers import TFIDFRetriever
from langchain_core.messages import SystemMessage
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import HumanMessagePromptTemplate
from langchain_core.prompts import ChatMessagePromptTemplate
from langchain_core.prompts import MessagesPlaceholder
from langchain_core.prompts import FewShotChatMessagePromptTemplate
from langchain_core.prompts import FewShotPromptTemplate
from langchain_core.example_selectors import LengthBasedExampleSelector
from langchain_core.example_selectors import MaxMarginalRelevanceExampleSelector
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_core.output_parsers import JsonOutputParser
from langchain.chains import LLMChain
from langchain.output_parsers import CommaSeparatedListOutputParser
from langchain.output_parsers import DatetimeOutputParser
from langchain_community.agent_toolkits.load_tools import load_tools
from langchain_community.utilities import TextRequestsWrapper
from langchain_experimental.text_splitter import SemanticChunker
from langchain_experimental.utilities import PythonREPL
from langchain.agents import AgentType, initialize_agent


class TextMindLoader(BaseLoader):
    def __init__(
        self,
        file_path: Union[str, Path]
    ):
        self.file_path = file_path

    def lazy_load(self) -> Iterator[Document]:
        with open(self.file_path, 'r', encoding='utf-8') as jsonfile:
            raw_json = json.load(jsonfile)
            for page in raw_json['pages']:
                page_content = page['text']
                metadata = {'page_num': page['page_num'], 'has_table': page['tables'] is not None}
                yield Document(page_content=page_content, metadata=metadata)
            # data = get_ocr_new(raw=raw_json, pretty=True)
            # for title in data['title']:
            #     page_content = title['text']
            #     metadata = title
            #     yield Document(page_content=page_content, metadata=metadata)
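
# Usage sketch for TextMindLoader (the path is hypothetical; it assumes the TextMind export
# is a JSON file with a top-level "pages" list carrying "text", "page_num" and "tables" keys,
# as read in lazy_load above):
# loader = TextMindLoader(file_path='data/sample_textmind.json')
# for doc in loader.lazy_load():
#     print(doc.metadata['page_num'], doc.page_content[:50])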


class DocsLoader():
    @classmethod
    def txt_loader(cls, filepath):
        """
        Load a txt file.
        :param filepath:
        :return:
        """
        loader = TextLoader(filepath, encoding='utf8')
        docs = loader.load()
        return docs

    @classmethod
    def csv_loader(cls, filepath):
        """
        https://python.langchain.com/docs/modules/data_connection/document_loaders/csv/
        Parameter reference: https://blog.csdn.net/zjkpy_5/article/details/137727850?spm=1001.2014.3001.5501
        Load a csv file.
        :param filepath:
        :return:
        """
        loader = CSVLoader(file_path=filepath, encoding='utf8')
        docs = loader.load()
        return docs

    @classmethod
    def json_loader(cls, filepath):
        """
        https://python.langchain.com/docs/modules/data_connection/document_loaders/json/
        The jq-based loader from the docs does not work on Windows, so plain json is used here.
        Load a json file.
        :param filepath:
        :return:
        """
        docs = json.loads(Path(filepath).read_text(encoding='utf8'))
        return docs

    @classmethod
    def file_directory_loader(cls, filepath, glob="**/[!.]*", loader_cls=TextLoader, silent_errors=False, show_progress=True, use_multithreading=True, max_concurrency=4, exclude=[], recursive=True):
        """
        https://python.langchain.com/docs/modules/data_connection/document_loaders/file_directory/
        Load every file under a directory (it will not load .rst or .html files).
        :param filepath:
        :param glob: defaults to all non-hidden files; "*.txt" loads only txt files
        :param loader_cls: loader class, UnstructuredFileLoader by default; pass TextLoader to avoid encoding errors
            (autodetect_encoding is enabled through loader_kwargs to detect the file encoding automatically)
        :param silent_errors: skip files that fail to load and continue
        :param show_progress: show a progress bar
        :param use_multithreading: load with multiple threads
        :param max_concurrency: number of threads
        :param exclude: file patterns to exclude, as a list
        :param recursive: recurse into sub-directories and load their files too
        :return:
        """
        text_loader_kwargs = {'autodetect_encoding': True}
        loader = DirectoryLoader(filepath, glob=glob, loader_cls=loader_cls, silent_errors=silent_errors,
                                 loader_kwargs=text_loader_kwargs, show_progress=show_progress,
                                 use_multithreading=use_multithreading, max_concurrency=max_concurrency,
                                 exclude=exclude, recursive=recursive)
        docs = loader.load()
        return docs

    @classmethod
    def html_loader(cls, filepath):
        """
        https://python.langchain.com/docs/modules/data_connection/document_loaders/html/
        Load an html file.
        BSHTMLLoader from the docs raises encoding errors; the other loaders are crawler-based,
        third-party, and need an API key.
        :param filepath:
        :return: the text of the page
        """
        loader = UnstructuredHTMLLoader(filepath)
        data = loader.load()
        return data

    @classmethod
    def markdown_loader(cls, filepath, mode='single'):
        """
        https://python.langchain.com/docs/modules/data_connection/document_loaders/markdown/
        Load a markdown file.
        :param filepath:
        :param mode: split mode; 'single' keeps everything together, 'elements' keeps each block separate
        :return:
        """
        loader = UnstructuredMarkdownLoader(filepath, mode=mode)
        data = loader.load()
        return data

    @classmethod
    def pdf_loader(cls, filepath, extract_images=True, is_directory=False):
        """
        https://python.langchain.com/docs/modules/data_connection/document_loaders/pdf/
        Load a pdf; 'page' in the metadata is the page number, though a few extra pages may appear.
        :param filepath:
        :param extract_images: also extract text from images inside the pdf (default True)
        :param is_directory: pass a directory path to load every pdf under it (image text is not recognised)
        :return:
        """
        if is_directory:
            filepath = is_directory
            loader = PyPDFDirectoryLoader(filepath)
            docs = loader.load()
            return docs
        else:
            if extract_images:
                loader = PyPDFLoader(filepath, extract_images=extract_images)
            else:
                loader = PyMuPDFLoader(filepath)  # fastest pdf option, but cannot extract text from images
            pages = loader.load_and_split()
            return pages

    @classmethod
    def excel_loader(cls, filepath, mode='single'):
        """
        https://python.langchain.com/docs/integrations/document_loaders/microsoft_excel/
        Load .xlsx and .xls files.
        :param filepath:
        :param mode: with 'elements' an HTML representation of the sheet is available in the document metadata under text_as_html
        :return:
        """
        loader = UnstructuredExcelLoader(filepath, mode=mode)
        docs = loader.load()
        return docs

    @classmethod
    def ppt_loader(cls, filepath, mode='single'):
        """
        https://python.langchain.com/docs/integrations/document_loaders/microsoft_powerpoint/
        Load a ppt; text inside images is not extracted.
        :param filepath:
        :param mode: 'single' keeps everything together, 'elements' keeps each slide's text boxes, tables, etc. separate
        :return:
        """
        loader = UnstructuredPowerPointLoader(filepath, mode=mode)
        data = loader.load()
        return data

    @classmethod
    def word_loader(cls, filepath, mode='single'):
        """
        https://python.langchain.com/docs/integrations/document_loaders/microsoft_word/
        :param filepath:
        :param mode: 'single' keeps everything together, 'elements' splits per page; image text is not recognised
        :return:
        """
        loader = UnstructuredWordDocumentLoader(filepath, mode=mode)
        data = loader.load()
        return data

    @classmethod
    def img_loader(cls, filepath, mode='single'):
        """
        https://python.langchain.com/docs/integrations/document_loaders/image/
        Load an image and recognise the text on it (accuracy is not guaranteed).
        "no module pdfminer.utils" error: https://github.com/langchain-ai/langchain/issues/14326
        :param filepath:
        :param mode: 'single' merges all text, 'elements' keeps each piece of text as a separate chunk
        :return:
        """
        loader = UnstructuredImageLoader(filepath, mode=mode)
        data = loader.load()
        return data

    @classmethod
    def textmind_loader(cls, filepath):
        loader = TextMindLoader(file_path=filepath)
        data = loader.load()
        return data
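
# Usage sketch for DocsLoader (file paths are hypothetical):
# pdf_pages = DocsLoader.pdf_loader('data/sample.pdf', extract_images=False)
# md_docs = DocsLoader.markdown_loader('data/sample.md', mode='elements')
# dir_docs = DocsLoader.file_directory_loader('data/', glob='*.txt', loader_cls=TextLoader)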


class TextSpliter():
    @classmethod
    def text_split_by_char(cls, docs, separator='\n', chunk_size=100, chunk_overlap=20, length_function=len, is_separator_regex=False):
        """
        https://python.langchain.com/docs/modules/data_connection/document_transformers/character_text_splitter/
        Split on a single separator character; when the separator applies, chunk_size is effectively ignored.
        :param docs: must be a str; documents loaded with langchain need to be converted first
        :param separator: separator character
        :param chunk_size: chunk size
        :param chunk_overlap: allowed character overlap between chunks
        :param length_function:
        :param is_separator_regex:
        :return:
        """
        text_splitter = CharacterTextSplitter(
            separator=separator,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=length_function,
            is_separator_regex=is_separator_regex,
        )
        docs = docs[0].page_content  # convert a langchain-loaded txt document to str
        text_split = text_splitter.create_documents([docs])
        return text_split

    @classmethod
    def text_split_by_manychar_or_charnum(cls, docs, separator=["\n\n", "\n", " ", ""], chunk_size=100, chunk_overlap=20, length_function=len, is_separator_regex=True):
        """
        https://python.langchain.com/docs/modules/data_connection/document_transformers/recursive_text_splitter/
        Split by chunk_size characters (keep the default separator), or on multiple characters:
        anything matching an entry in the separator list is a split point.
        :param docs: must be a str; documents loaded with langchain need to be converted first
        :param separator: characters to split on, defaults to ["\n\n", "\n", " ", ""]
        :param chunk_size: chunk size
        :param chunk_overlap: allowed character overlap between chunks
        :param length_function:
        :param is_separator_regex:
        :return:
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,  # chunk size
            chunk_overlap=chunk_overlap,  # number of characters chunks may overlap
            length_function=length_function,
            is_separator_regex=is_separator_regex,
            separators=separator  # characters to split on; if omitted, splits every chunk_size +- chunk_overlap (100+-20) characters
        )
        docs = docs[0].page_content  # convert a langchain-loaded txt document to str
        split_text = text_splitter.create_documents([docs])
        return split_text

    @classmethod
    def json_split(cls, json_data, min_chunk_size=50, max_chunk_size=300):
        """
        https://python.langchain.com/docs/modules/data_connection/document_transformers/recursive_json_splitter/
        Split json; each chunk is a complete dict.
        :param json_data:
        :param min_chunk_size:
        :param max_chunk_size:
        :return:
        """
        splitter = RecursiveJsonSplitter(min_chunk_size=min_chunk_size, max_chunk_size=max_chunk_size)
        json_chunks = splitter.split_json(json_data=json_data)
        return json_chunks

    @classmethod
    def html_split(cls, html_string='', url='', chunk_size=500, chunk_overlap=30):
        """
        https://python.langchain.com/docs/modules/data_connection/document_transformers/HTML_header_metadata/
        Split html, either from a string or from a url.
        :param html_string: html as a string
        :param url: url whose html should be fetched and split
        :return:
        """
        # split on header tags; the metadata records which header each chunk belongs to
        headers_to_split_on = [
            ("h1", "Header 1"),
            ("h2", "Header 2"),
            ("h3", "Header 3"),
            ("h4", "Header 4"),
            ("h5", "Header 5"),
            ("h6", "Header 6"),
        ]
        html_splitter = HTMLHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
        if html_string:
            splits = html_splitter.split_text(html_string)
        else:
            html_header_splits = html_splitter.split_text_from_url(url)
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size, chunk_overlap=chunk_overlap
            )
            splits = text_splitter.split_documents(html_header_splits)
        return splits

    @classmethod
    def code_split(cls, code, language=Language.PYTHON, chunk_size=50, chunk_overlap=0):
        """
        https://python.langchain.com/docs/modules/data_connection/document_transformers/code_splitter/
        Split source code.
        # Full list of supported languages: [e.value for e in Language]
        :param code:
        :param language: defaults to python
        :param chunk_size:
        :param chunk_overlap:
        :return:
        """
        python_splitter = RecursiveCharacterTextSplitter.from_language(
            language=language, chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        docs = python_splitter.create_documents([code])
        return docs

    @classmethod
    def markdown_split(cls, markdown_string, char_level_splits=False, strip_headers=False, chunk_size=250, chunk_overlap=30):
        """
        https://python.langchain.com/docs/modules/data_connection/document_transformers/markdown_header_metadata/
        Split markdown.
        :param markdown_string: markdown string
        :param char_level_splits: whether to split again by character count after the header split
        :param strip_headers: by default the split headers are removed from the chunk content; set strip_headers=False to keep them
        :return:
        """
        headers_to_split_on = [
            ("#", "Header 1"),
            ("##", "Header 2"),
            ("###", "Header 3"),
            ("####", "Header 4"),
            ("#####", "Header 5"),
            ("######", "Header 6"),
        ]
        markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on,
                                                       strip_headers=strip_headers)
        md_header_splits = markdown_splitter.split_text(markdown_string)
        splits = md_header_splits
        if char_level_splits:
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size, chunk_overlap=chunk_overlap
            )
            splits = text_splitter.split_documents(md_header_splits)
        return splits

    @classmethod
    def semantic_chunker_split(cls, txt, embedding_model, breakpoint_threshold_type="percentile"):
        """
        https://python.langchain.com/docs/modules/data_connection/document_transformers/semantic-chunker/
        Semantic chunking.
        :param txt: txt string
        :param embedding_model:
        :param breakpoint_threshold_type: breakpoint rule
            percentile: default; all sentence-to-sentence differences are computed and any difference above the Xth percentile is split
            standard_deviation: any difference larger than X standard deviations is split
            interquartile: split using the interquartile distance
        :return:
        """
        text_splitter = SemanticChunker(embedding_model, breakpoint_threshold_type=breakpoint_threshold_type)
        docs = text_splitter.create_documents([txt])
        return docs
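
# Usage sketch for TextSpliter (assumes `docs` comes from DocsLoader.txt_loader, so that
# docs[0].page_content holds the raw string the splitters expect; paths are hypothetical):
# docs = DocsLoader.txt_loader('data/sample.txt')
# chunks = TextSpliter.text_split_by_manychar_or_charnum(docs, chunk_size=200, chunk_overlap=20)
# code_chunks = TextSpliter.code_split(open('chain.py', encoding='utf-8').read(), language=Language.PYTHON)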


class EmbeddingVectorDB():
    @classmethod
    def load_local_embedding_model(cls, embedding_model_path, device='cpu'):
        """Load a local embedding model."""
        embedding_model = HuggingFaceEmbeddings(model_name=embedding_model_path, model_kwargs={'device': device})
        return embedding_model

    @classmethod
    def faiss_vector_db(cls, split_docs, vector_db_path, embedding_model):
        """
        https://python.langchain.com/docs/modules/data_connection/vectorstores/
        Create (or load) a FAISS vector store.
        :param split_docs: split document chunks
        :param vector_db_path: path where the vector store is persisted
        :param embedding_model: embedding model
        :return:
        """
        if os.path.exists(vector_db_path):
            print('Loading vector store from =>', vector_db_path)
            db = FAISS.load_local(vector_db_path, embedding_model, allow_dangerous_deserialization=True)
        else:
            print('Creating vector store at =>', vector_db_path)
            db = FAISS.from_documents(split_docs, embedding_model)
            db.save_local(vector_db_path)
        return db

    @classmethod
    async def faiss_vector_db_await(cls, split_docs, vector_db_path, embedding_model):
        """
        https://python.langchain.com/docs/integrations/vectorstores/faiss_async/#similarity-search-with-score
        :param split_docs: split document chunks
        :param vector_db_path: path where the vector store is persisted
        :param embedding_model: embedding model
        :return:
        """
        if os.path.exists(vector_db_path):
            print('Loading vector store from =>', vector_db_path)
            db = FAISS.load_local(vector_db_path, embedding_model, allow_dangerous_deserialization=True)
        else:
            print('Creating vector store at =>', vector_db_path)
            db = await FAISS.afrom_documents(split_docs, embedding_model)
            db.save_local(vector_db_path)
        return db

    @classmethod
    def chroma_vector_db(cls, split_docs, vector_db_path, embedding_model):
        """
        https://python.langchain.com/docs/modules/data_connection/vectorstores/
        Create (or load) a Chroma vector store.
        :param split_docs: split document chunks
        :param vector_db_path: path where the vector store is persisted
        :param embedding_model: embedding model
        :return:
        """
        if os.path.exists(vector_db_path):
            print('Loading vector store from =>', vector_db_path)
            db = Chroma(persist_directory=vector_db_path, embedding_function=embedding_model)
        else:
            print('Creating vector store at =>', vector_db_path)
            db = Chroma.from_documents(split_docs, embedding_model, persist_directory=vector_db_path)
            # db.persist()
        return db
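
# Usage sketch for EmbeddingVectorDB (model name and path are hypothetical; the store is
# created on the first run and loaded from disk afterwards):
# embedding = EmbeddingVectorDB.load_local_embedding_model('BAAI/bge-small-zh-v1.5')
# db = EmbeddingVectorDB.faiss_vector_db(split_docs=chunks, vector_db_path='./faissdb/demo/', embedding_model=embedding)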


class Retriever():
    @classmethod
    def similarity(cls, db, query, topk=5, long_context=False):
        """
        https://python.langchain.com/docs/modules/data_connection/retrievers/vectorstore/
        https://python.langchain.com/docs/modules/data_connection/retrievers/long_context_reorder/
        Similarity search without scores; returns the most similar documents, including duplicates
        if the corpus contains them.
        :param db:
        :param query:
        :param long_context: reorder for long contexts
        :return:
        """
        retriever = db.as_retriever(search_kwargs={'k': topk})
        retriever_docs = retriever.get_relevant_documents(query)
        if long_context:
            reordering = LongContextReorder()
            retriever_docs = reordering.transform_documents(retriever_docs)
        return retriever_docs

    @classmethod
    def similarity_with_score(cls, db, query, topk=5, long_context=False):
        """
        https://python.langchain.com/docs/integrations/vectorstores/usearch/#similarity-search-with-score
        https://python.langchain.com/docs/modules/data_connection/retrievers/long_context_reorder/
        Similarity search with scores; the score is the L2 distance, so lower is better.
        :param db:
        :param query:
        :param long_context: reorder for long contexts
        :return:
        """
        retriever_docs = db.similarity_search_with_score(query, k=topk)
        if long_context:
            reordering = LongContextReorder()
            retriever_docs = reordering.transform_documents(retriever_docs)
        return retriever_docs

    @classmethod
    def mmr(cls, db, query, topk=5, fetch_k=50, long_context=False):
        """
        https://python.langchain.com/docs/modules/data_connection/retrievers/vectorstore/
        https://python.langchain.com/docs/modules/data_connection/retrievers/long_context_reorder/
        MMR deduplicates while returning the most similar documents.
        :param db:
        :param query:
        :param topk: how many of the most similar documents to return, never more than fetch_k
        :param fetch_k: maximum number of documents handed to MMR
        :param long_context: reorder for long contexts
        :return:
        """
        retriever = db.as_retriever(search_type="mmr", search_kwargs={'k': topk, 'fetch_k': fetch_k})
        retriever_docs = retriever.get_relevant_documents(query)
        if long_context:
            reordering = LongContextReorder()
            retriever_docs = reordering.transform_documents(retriever_docs)
        return retriever_docs

    @classmethod
    def similarity_score_threshold(cls, db, query, topk=5, score_threshold=0.8, long_context=False):
        """
        https://python.langchain.com/docs/modules/data_connection/retrievers/long_context_reorder/
        Filter results by similarity score.
        :param db:
        :param query:
        :param topk:
        :param score_threshold: similarity score threshold
        :param long_context: reorder for long contexts
        :return:
        """
        retriever = db.as_retriever(search_type="similarity_score_threshold",
                                    search_kwargs={'k': topk, "score_threshold": score_threshold})
        retriever_docs = retriever.get_relevant_documents(query)
        if long_context:
            reordering = LongContextReorder()
            retriever_docs = reordering.transform_documents(retriever_docs)
        return retriever_docs

    @classmethod
    def multi_query_retriever(cls, db, query, model, topk=5, long_context=False):
        """
        https://python.langchain.com/docs/modules/data_connection/retrievers/MultiQueryRetriever/
        https://python.langchain.com/docs/modules/data_connection/retrievers/long_context_reorder/
        Multi-query retriever.
        Distance-based retrieval can change with tiny wording differences, or when the embedding does not
        capture the semantics well. An LLM automatically generates several rewrites of the query from
        different angles, retrieval runs for each rewrite, and the union of the results is returned.
        The advantage is that the generated queries cover the semantics and information need more fully.
        Note: setting topk does not appear to take effect here; the reason is unclear.
        :param db:
        :param query:
        :param long_context: reorder for long contexts
        :return:
        """
        retriever = db.as_retriever(search_kwargs={'k': topk})
        retriever = MultiQueryRetriever.from_llm(retriever=retriever, llm=model)
        retriever_docs = retriever.get_relevant_documents(query=query)
        if long_context:
            reordering = LongContextReorder()
            retriever_docs = reordering.transform_documents(retriever_docs)
        return retriever_docs

    @classmethod
    def contextual_compression_by_llm(cls, db, query, model, topk=5, long_context=False):
        """
        https://python.langchain.com/docs/modules/data_connection/retrievers/contextual_compression/
        https://python.langchain.com/docs/modules/data_connection/retrievers/long_context_reorder/
        Contextual compression retriever backed by an LLM; deduplicates results.
        Uses the query as context to compress the retrieved output so only the relevant parts are returned
        instead of the raw documents; essentially the LLM extracts the core of each hit.
        Note: topk does not appear to take effect here.
        :param db:
        :param query:
        :param model:
        :param topk:
        :param long_context: reorder for long contexts
        :return:
        """
        _filter = LLMChainFilter.from_llm(model)
        retriever = db.as_retriever(search_kwargs={'k': topk})
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=_filter, base_retriever=retriever
        )
        retriever_docs = compression_retriever.get_relevant_documents(query)
        if long_context:
            reordering = LongContextReorder()
            retriever_docs = reordering.transform_documents(retriever_docs)
        return retriever_docs

    @classmethod
    def contextual_compression_by_embedding(cls, db, query, embedding_model, topk=5, similarity_threshold=0.76, long_context=False):
        """
        https://python.langchain.com/docs/modules/data_connection/retrievers/contextual_compression/
        https://python.langchain.com/docs/modules/data_connection/retrievers/long_context_reorder/
        Contextual compression retriever backed by an embedding model; deduplicates results.
        Uses the query as context to compress the retrieved output so only the relevant parts are returned
        instead of the raw documents; relevance is computed with embeddings.
        :param db:
        :param query:
        :param embedding_model:
        :param topk:
        :param long_context: reorder for long contexts
        :return:
        """
        retriever = db.as_retriever(search_kwargs={'k': topk})
        embeddings_filter = EmbeddingsFilter(embeddings=embedding_model, similarity_threshold=similarity_threshold)
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=embeddings_filter, base_retriever=retriever
        )
        retriever_docs = compression_retriever.get_relevant_documents(query)
        if long_context:
            reordering = LongContextReorder()
            retriever_docs = reordering.transform_documents(retriever_docs)
        return retriever_docs

    @classmethod
    def contextual_compression_by_embedding_split(cls, db, query, embedding_model, topk=5, similarity_threshold=0.76, chunk_size=100, chunk_overlap=0, separator=". ", long_context=False):
        """
        https://python.langchain.com/docs/modules/data_connection/retrievers/contextual_compression/
        https://python.langchain.com/docs/modules/data_connection/retrievers/long_context_reorder/
        Contextual compression retriever backed by an embedding model; deduplicates results and first
        splits the documents into smaller pieces.
        Uses the query as context to compress the retrieved output so only the relevant parts are returned
        instead of the raw documents; relevance is computed with embeddings.
        :param db:
        :param query:
        :param embedding_model:
        :param topk: does not take effect; the default is 4
        :param long_context: reorder for long contexts
        :return:
        """
        retriever = db.as_retriever(search_kwargs={'k': topk})
        splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap, separator=separator)
        redundant_filter = EmbeddingsRedundantFilter(embeddings=embedding_model)
        relevant_filter = EmbeddingsFilter(embeddings=embedding_model, similarity_threshold=similarity_threshold)
        pipeline_compressor = DocumentCompressorPipeline(
            transformers=[splitter, redundant_filter, relevant_filter]
        )
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=pipeline_compressor, base_retriever=retriever
        )
        retriever_docs = compression_retriever.get_relevant_documents(query)
        if long_context:
            reordering = LongContextReorder()
            retriever_docs = reordering.transform_documents(retriever_docs)
        return retriever_docs

    @classmethod
    def ensemble(cls, query, text_split_docs, embedding_model, bm25_topk=5, topk=5, long_context=False):
        """
        https://python.langchain.com/docs/modules/data_connection/retrievers/ensemble/
        https://python.langchain.com/docs/modules/data_connection/retrievers/long_context_reorder/
        Hybrid retrieval.
        The most common pattern combines a sparse retriever (such as BM25) with a dense retriever
        (embedding similarity), because their strengths are complementary; this is also called "hybrid search".
        Sparse retrievers are good at finding documents by keywords, dense retrievers by semantic similarity.
        :param query:
        :param text_split_docs: document objects split by langchain
        :param long_context: reorder for long contexts
        :param bm25_topk: BM25 topk
        :param topk: similarity topk
        :return: the union of both result sets, which may be smaller than bm25_topk + topk
        """
        text_split_docs = [text.page_content for text in text_split_docs]
        bm25_retriever = BM25Retriever.from_texts(
            text_split_docs, metadatas=[{"source": 1}] * len(text_split_docs)
        )
        bm25_retriever.k = bm25_topk
        faiss_vectorstore = FAISS.from_texts(
            text_split_docs, embedding_model, metadatas=[{"source": 2}] * len(text_split_docs)
        )
        faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": topk})
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
        )
        retriever_docs = ensemble_retriever.invoke(query)
        if long_context:
            reordering = LongContextReorder()
            retriever_docs = reordering.transform_documents(retriever_docs)
        return retriever_docs

    @classmethod
    def bm25(cls, query, text_split_docs, topk=5, long_context=False):
        """
        https://python.langchain.com/docs/modules/data_connection/retrievers/long_context_reorder/
        Sparse retriever, good at finding documents by keywords.
        :param query:
        :param text_split_docs: document objects split by langchain
        :param topk:
        :param long_context: reorder for long contexts
        """
        text_split_docs = [text.page_content for text in text_split_docs]
        bm25_retriever = BM25Retriever.from_texts(
            text_split_docs, metadatas=[{"source": 1}] * len(text_split_docs)
        )
        bm25_retriever.k = topk
        retriever_docs = bm25_retriever.get_relevant_documents(query)
        if long_context:
            reordering = LongContextReorder()
            retriever_docs = reordering.transform_documents(retriever_docs)
        return retriever_docs

    @classmethod
    def parent_document_retriever(cls, docs, query, embedding_model):
        """
        https://python.langchain.com/docs/modules/data_connection/retrievers/parent_document_retriever/
        https://python.langchain.com/docs/modules/data_connection/retrievers/long_context_reorder/
        Parent-document retrieval; only works with the Chroma store, FAISS is not supported.
        Useful when several documents are loaded: it retrieves the matching small chunks together with the
        corresponding full txt, so the full txt can then be searched more precisely with other methods.
        :param docs: example
            loaders = [
                TextLoader("data/专业描述.txt", encoding="utf-8"),
                TextLoader("data/专业描述_copy.txt", encoding="utf-8"),
            ]
            docs = []
            for loader in loaders:
                docs.extend(loader.load())
        :return:
        """
        child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
        vectorstore = Chroma(
            collection_name="full_documents", embedding_function=embedding_model
        )
        store = InMemoryStore()
        retriever = ParentDocumentRetriever(
            vectorstore=vectorstore,
            docstore=store,
            child_splitter=child_splitter,
        )
        retriever.add_documents(docs, ids=None)
        sub_docs = vectorstore.similarity_search(query)
        parent_docs = retriever.get_relevant_documents(query)
        return sub_docs, parent_docs

    @classmethod
    def tfidf(cls, query, docs_lst, long_context=False):
        """
        https://python.langchain.com/docs/modules/data_connection/retrievers/long_context_reorder/
        TF-IDF keyword retrieval.
        :param query:
        :param docs_lst: ['xxx', 'dsfsdg'.....]
        :param long_context: reorder for long contexts
        :return:
        """
        retriever = TFIDFRetriever.from_texts(docs_lst)
        retriever_docs = retriever.get_relevant_documents(query)
        if long_context:
            reordering = LongContextReorder()
            retriever_docs = reordering.transform_documents(retriever_docs)
        return retriever_docs

    @classmethod
    def knn(cls, query, docs_lst, embedding_model, long_context=False):
        """
        https://python.langchain.com/docs/modules/data_connection/retrievers/long_context_reorder/
        KNN retrieval.
        :param query:
        :param docs_lst: ['xxx', 'dsfsdg'.....]
        :param long_context:
        :return:
        """
        retriever = KNNRetriever.from_texts(docs_lst, embedding_model)
        retriever_docs = retriever.get_relevant_documents(query)
        if long_context:
            reordering = LongContextReorder()
            retriever_docs = reordering.transform_documents(retriever_docs)
        return retriever_docs
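
# Usage sketch for Retriever (assumes `db`, `chunks` and `embedding` from the sketches above;
# `llm` stands for any LangChain chat model and is not defined in this file):
# hits = Retriever.similarity(db, '净利润', topk=5, long_context=True)
# hybrid_hits = Retriever.ensemble('净利润', chunks, embedding, bm25_topk=3, topk=3)
# rewritten_hits = Retriever.multi_query_retriever(db, '净利润', llm)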


class Prompt():
    @classmethod
    def prompt_template(cls, prompt_string, **kwargs):
        """
        https://python.langchain.com/docs/modules/model_io/prompts/quick_start/#prompttemplate
        Basic prompt with variables (variables are optional).
        :param prompt_string: prompt string with variables wrapped in {}
        :param kwargs: values for the variables, in order
            e.g. prompt_string="Could you tell me about `{fruit}` and `{fruit2}`?",
                 fruit='apple', fruit2='banana'
                 -> Could you tell me about `apple` and `banana`?
        :return:
        """
        prompt_template = PromptTemplate.from_template(prompt_string)
        prompt = prompt_template.format(**kwargs)
        return prompt

    @classmethod
    def chat_prompt_template(cls, text):
        """
        https://python.langchain.com/docs/modules/model_io/prompts/quick_start/#chatprompttemplate
        Chat-style template.
        The system content can be set up in advance; only the human prompt text is passed in each call.
        :param text:
        :return:
        """
        chat_template = ChatPromptTemplate.from_messages(
            [
                SystemMessage(
                    content=(
                        "You are a helpful assistant that re-writes the user's text to "
                        "sound more upbeat."
                    )
                ),
                HumanMessagePromptTemplate.from_template("{text}"),
            ]
        )
        messages = chat_template.format_messages(text=text)
        return messages

    @classmethod
    def chat_message_prompt_template(cls, prompt_string, role='human', **kwargs):
        """
        https://python.langchain.com/docs/modules/model_io/prompts/quick_start/#message-prompts
        Chat models accept messages with arbitrary roles; ChatMessagePromptTemplate lets the caller
        specify the role name.
        :param prompt_string:
        :param role: the role to use
        :param kwargs:
        :return:
        """
        chat_message_prompt = ChatMessagePromptTemplate.from_template(
            role=role, template=prompt_string
        )
        message = chat_message_prompt.format(**kwargs)
        return message

    @classmethod
    def messages_placeholder(cls, human_prompt, **kwargs):
        """
        https://python.langchain.com/docs/modules/model_io/prompts/quick_start/#messagesplaceholder
        Gives full control over which messages are rendered during formatting; useful when you are unsure
        which role a message prompt template should use, or when you want to insert a list of messages.
        The content below can be customised.
        :param human_prompt:
        :param kwargs: prompt variables
        :return:
        """
        human_message_template = HumanMessagePromptTemplate.from_template(human_prompt)
        chat_prompt = ChatPromptTemplate.from_messages(
            [MessagesPlaceholder(variable_name="conversation"), human_message_template]
        )
        human_message = HumanMessage(content="What is the best way to learn programming?")
        ai_message = AIMessage(
            content="""\
1. Choose a programming language: Decide on a programming language that you want to learn.
2. Start with the basics: Familiarize yourself with the basic programming concepts such as variables, data types and control structures.
3. Practice, practice, practice: The best way to learn programming is through hands-on experience\
"""
        )
        message = chat_prompt.format_prompt(
            conversation=[human_message, ai_message], **kwargs
        ).to_messages()
        return message

    @classmethod
    def example_selectors_length_based(cls, examples, string, max_length=25):
        """
        https://python.langchain.com/docs/modules/model_io/prompts/example_selectors/length_based/
        Select examples based on prompt length.
        The prefix shown at the start can be customised.
        :param examples: list of examples
            e.g. examples = [
                {"input": "happy", "output": "sad"},
                {"input": "tall", "output": "short"},
                {"input": "energetic", "output": "lethargic"},
                {"input": "sunny", "output": "gloomy"},
                {"input": "windy", "output": "calm"},
            ]
        :param string: the incoming prompt
        :param max_length: if the incoming prompt is shorter than this, all examples are used;
            otherwise the number of examples is chosen automatically by length
        :return:
        """
        examples = [
            {"input": "happy", "output": "sad"},
            {"input": "tall", "output": "short"},
            {"input": "energetic", "output": "lethargic"},
            {"input": "sunny", "output": "gloomy"},
            {"input": "windy", "output": "calm"},
        ]
        example_prompt = PromptTemplate(
            input_variables=["input", "output"],
            template="Input: {input}\nOutput: {output}",
        )
        example_selector = LengthBasedExampleSelector(
            # The examples it has available to choose from.
            examples=examples,
            # The PromptTemplate being used to format the examples.
            example_prompt=example_prompt,
            max_length=max_length,
        )
        dynamic_prompt = FewShotPromptTemplate(
            # We provide an ExampleSelector instead of examples.
            example_selector=example_selector,
            example_prompt=example_prompt,
            prefix="Give the antonym of every input",
            suffix="Input: {adjective}\nOutput:",
            input_variables=["adjective"],
        )
        example_prompt = dynamic_prompt.format(adjective=string)
        return example_prompt

    @classmethod
    def example_selectors_by_mmr(cls, examples, string, embedding_model, k=2):
        """
        https://python.langchain.com/docs/modules/model_io/prompts/example_selectors/mmr/
        Selects examples that are similar to the input while also optimising for diversity:
        it picks the examples with the highest cosine similarity to the input, then adds further ones
        iteratively while penalising closeness to the examples already chosen, so the selected examples
        are as dissimilar from each other as possible and do not repeat.
        The prefix shown at the start can be customised.
        :param examples: list of examples
            e.g. examples = [
                {"input": "happy", "output": "sad"},
                {"input": "tall", "output": "short"},
                {"input": "energetic", "output": "lethargic"},
                {"input": "sunny", "output": "gloomy"},
                {"input": "windy", "output": "calm"},
            ]
        :param string: prompt string
        :param embedding_model:
        :param k: how many examples to select
        :return:
        """
        examples = [
            {"input": "happy", "output": "sad"},
            {"input": "tall", "output": "short"},
            {"input": "energetic", "output": "lethargic"},
            {"input": "sunny", "output": "gloomy"},
            {"input": "windy", "output": "calm"},
        ]
        example_prompt = PromptTemplate(
            input_variables=["input", "output"],
            template="Input: {input}\nOutput: {output}",
        )
        example_selector = MaxMarginalRelevanceExampleSelector.from_examples(
            examples, embedding_model, FAISS, k=k
        )
        mmr_prompt = FewShotPromptTemplate(
            # We provide an ExampleSelector instead of examples.
            example_selector=example_selector,
            example_prompt=example_prompt,
            prefix="Give the antonym of every input",
            suffix="Input: {adjective}\nOutput:",
            input_variables=["adjective"],
        )
        mmr_prompt = mmr_prompt.format(adjective=string)
        return mmr_prompt

    @classmethod
    def example_selectors_similarity(cls, examples, string, embedding_model, k=1):
        """
        https://python.langchain.com/docs/modules/model_io/prompts/example_selectors/similarity/
        Selects examples by similarity to the input, using maximum cosine similarity between embeddings,
        i.e. it picks the examples most similar to string.
        :param examples: list of examples
            e.g. examples = [
                {"input": "happy", "output": "sad"},
                {"input": "tall", "output": "short"},
                {"input": "energetic", "output": "lethargic"},
                {"input": "sunny", "output": "gloomy"},
                {"input": "windy", "output": "calm"},
            ]
        :param string:
        :param embedding_model:
        :param k: how many examples to select
        :return:
        """
        examples = [
            {"input": "happy", "output": "sad"},
            {"input": "tall", "output": "short"},
            {"input": "energetic", "output": "lethargic"},
            {"input": "sunny", "output": "gloomy"},
            {"input": "windy", "output": "calm"},
        ]
        example_prompt = PromptTemplate(
            input_variables=["input", "output"],
            template="Input: {input}\nOutput: {output}",
        )
        example_selector = SemanticSimilarityExampleSelector.from_examples(
            examples, embedding_model, Chroma, k=k,
        )
        similar_prompt = FewShotPromptTemplate(
            # We provide an ExampleSelector instead of examples.
            example_selector=example_selector,
            example_prompt=example_prompt,
            prefix="Give the antonym of every input",
            suffix="Input: {adjective}\nOutput:",
            input_variables=["adjective"],
        )
        similar_prompt = similar_prompt.format(adjective=string)
        return similar_prompt

    @classmethod
    def few_shot_examples_chat(cls, examples, string, model):
        """
        https://python.langchain.com/docs/modules/model_io/prompts/few_shot_examples_chat/#fixed-examples
        For chat models.
        The system message can be customised.
        :param examples:
            e.g. examples = [
                {"input": "2+2", "output": "4"},
                {"input": "2+3", "output": "5"},
            ]
        :param string:
        :param model: llm
        :return:
        """
        examples = [
            {"input": "2+2", "output": "4"},
            {"input": "2+3", "output": "5"},
        ]
        example_prompt = ChatPromptTemplate.from_messages(
            [
                ("human", "{input}"),
                ("ai", "{output}"),
            ]
        )
        few_shot_prompt = FewShotChatMessagePromptTemplate(
            example_prompt=example_prompt,
            examples=examples,
        )
        final_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", "You are a wondrous wizard of math."),
                few_shot_prompt,
                ("human", "{input}"),
            ]
        )
        chain = final_prompt | model
        res = chain.invoke({"input": string})
        return res

    @classmethod
    def few_shot_examples(cls, examples, string, embedding_model, k=1):
        """
        https://python.langchain.com/docs/modules/model_io/prompts/few_shot_examples/#create-the-example-set
        Selects few-shot examples by similarity to the input, using an embedding model to compute the
        similarity and a vector store for nearest-neighbour search.
        :param examples: list, see few_shot_examples_chat for the shape
        :param string:
        :param embedding_model:
        :return:
        """
        examples = [
            {
                "question": "Who lived longer, Muhammad Ali or Alan Turing?",
                "answer": """
Are follow up questions needed here: Yes.
Follow up: How old was Muhammad Ali when he died?
Intermediate answer: Muhammad Ali was 74 years old when he died.
Follow up: How old was Alan Turing when he died?
Intermediate answer: Alan Turing was 41 years old when he died.
So the final answer is: Muhammad Ali
""",
            }
        ]
        example_prompt = PromptTemplate(
            input_variables=["question", "answer"], template="Question: {question}\n{answer}"
        )
        example_selector = SemanticSimilarityExampleSelector.from_examples(
            examples, embedding_model, Chroma, k=k,
        )
        prompt = FewShotPromptTemplate(
            example_selector=example_selector,
            example_prompt=example_prompt,
            suffix="Question: {input}",
            input_variables=["input"],
        )
        prompt = prompt.format(input=string)
        return prompt
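
# Usage sketch for Prompt (assumes `embedding` is loaded as above; note that the example
# selector methods currently override their `examples` argument with the hard-coded list inside):
# p = Prompt.prompt_template('Introduce `{fruit}` in one sentence.', fruit='apple')
# few = Prompt.example_selectors_by_mmr(None, 'worried', embedding, k=2)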


class Chain():
    @classmethod
    def base_llm_chain(cls, model, prompt, **kwargs):
        """
        https://python.langchain.com/docs/modules/model_io/prompts/composition/#string-prompt-composition
        Basic chain: a prompt with variables plus a model.
        :param model: llm
        :param prompt: prompt whose variables are wrapped in {}
        :param kwargs: variables for the prompt
        :return:
        """
        prompt = PromptTemplate.from_template(prompt)
        chain = LLMChain(llm=model, prompt=prompt)
        res = chain.run(kwargs)
        return res

    @classmethod
    def batch_base_llm_chain(cls, model, prompt, max_concurrency=5, **kwargs):
        """
        https://python.langchain.com/docs/modules/model_io/prompts/composition/#string-prompt-composition
        Basic chain (prompt with variables plus a model), invoked in batches.
        :param model: llm
        :param prompt: prompt whose variables are wrapped in {}
        :param kwargs: variables for the prompt
        :param max_concurrency: number of concurrent requests
            e.g.:
                prompt = 'tell me a joke about {topic1} and {topic2}'
                topic1 = ['bear', 'dog']
                topic2 = ['cat', 'monkey']
            incoming kwargs: kwargs = {'topic1': ['bear', 'dog'], 'topic2': ['cat', 'monkey']}
            after pairing: batch_list = [{"topic1": "bear", "topic2": "cat"}, {"topic1": "dog", "topic2": "monkey"}]
        :return:
        """
        prompt = PromptTemplate.from_template(prompt)
        chain = LLMChain(llm=model, prompt=prompt)
        # make sure all value lists have the same length, then build the batch list
        keys = list(kwargs.keys())
        first_list_length = len(kwargs[keys[0]])
        if all(len(kwargs[key]) == first_list_length for key in keys):
            # pair the values with zip
            paired_values = zip(*[kwargs[key] for key in keys])
            # build a dict per pair
            batch_list = [dict(zip(keys, values)) for values in paired_values]
        else:
            print("Value lists have different lengths; cannot build the batch list.")
            return None
        res = chain.batch(batch_list, config={"max_concurrency": max_concurrency})
        return res

    @classmethod
    def base_chat_llm_chain(cls, model, inputs, **kwargs):
        """
        https://python.langchain.com/docs/modules/model_io/prompts/composition/#string-prompt-composition
        Basic chain for chat models: prompt plus model.
        :param model:
        :param inputs: the input
        :param kwargs: optional variables
        :return:
        """
        prompt = SystemMessage(content="你是个智能助手,能回答各种各样的问题。")
        new_prompt = (
            prompt + HumanMessage(content="hi") + AIMessage(content="what?") + "{input}"
        )
        new_prompt.format_messages(input="i said hi")
        chain = LLMChain(llm=model, prompt=new_prompt)
        res = chain.run(inputs)
        return res

    @classmethod
    def csv_parser_chain(cls, prompt_string, model, **kwargs):
        """
        https://python.langchain.com/docs/modules/model_io/output_parsers/types/csv/
        Output parsed as a list.
        :param prompt_string: prompt string with variables wrapped in {}
        :param model: llm
        :param kwargs: variable dict
        :return:
        """
        output_parser = CommaSeparatedListOutputParser()
        format_instructions = output_parser.get_format_instructions()
        kwargs['format_instructions'] = format_instructions  # output format instructions
        prompt = PromptTemplate(
            template=prompt_string + "\n{format_instructions}",
            input_variables=[],
            partial_variables=kwargs,  # bind the variables
        )
        chain = prompt | model | output_parser
        res = chain.invoke({})
        return res

    @classmethod
    def datetime_parser_chain(cls, prompt_string, model, **kwargs):
        """
        https://python.langchain.com/docs/modules/model_io/output_parsers/types/datetime/
        Output parsed as a datetime, e.g. 2009-01-03 18:15:05.
        :param prompt_string: prompt string with variables wrapped in {}
        :param model: llm
        :param kwargs: variable dict
        :return:
        """
        output_parser = DatetimeOutputParser()
        template = prompt_string + """{format_instructions}"""
        kwargs['format_instructions'] = output_parser.get_format_instructions()  # output format instructions
        prompt = PromptTemplate.from_template(
            template,
            partial_variables=kwargs,  # bind all variables
        )
        chain = prompt | model | output_parser
        output = chain.invoke({})
        return output

    @classmethod
    def json_parser_chain(cls, prompt_string, model, json_class=None, **kwargs):
        """
        https://python.langchain.com/docs/modules/model_io/output_parsers/types/json/
        Output parsed as json.
        :param prompt_string: prompt string whose variables are already filled in
        :param model: llm
        :param json_class: pydantic class specifying the output keys; optional, without it a single default key is used
            from langchain_core.pydantic_v1 import BaseModel, Field
            e.g. class Joke(BaseModel):
                     setup: str = Field(description="question to set up a joke")
                     punchline: str = Field(description="answer to resolve the joke")
        :param kwargs: variable dict
        :return:
        """
        parser = JsonOutputParser(pydantic_object=json_class)
        format_instructions = parser.get_format_instructions()
        kwargs['format_instructions'] = format_instructions  # output format instructions
        kwargs['prompt_string'] = prompt_string
        prompt = PromptTemplate(
            template="Answer the user query.\n{format_instructions}\n{prompt_string}\n",
            input_variables=[],
            partial_variables=kwargs,  # bind all variables
        )
        chain = prompt | model | parser
        res = chain.invoke({})
        return res
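
# Usage sketch for Chain (`llm` stands for any LangChain LLM and is not defined in this file):
# joke = Chain.base_llm_chain(llm, 'tell me a joke about {topic}', topic='bears')
# jokes = Chain.batch_base_llm_chain(llm, 'tell me a joke about {topic1} and {topic2}',
#                                    topic1=['bear', 'dog'], topic2=['cat', 'monkey'])
# date = Chain.datetime_parser_chain('When was the first iPhone released?', llm)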


class Tools():
    @classmethod
    def python_repl_tool(cls, code):
        """
        https://python.langchain.com/docs/integrations/tools/python/
        Executes python code; mind the indentation.
        :param code:
        :return:
        """
        python_repl = PythonREPL()
        res = python_repl.run(code)
        return res
        # # You can create the tool to pass to an agent
        # repl_tool = Tool(
        #     name="python_repl",
        #     description="A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`.",
        #     func=python_repl.run,
        # )

    @classmethod
    def requests_get_tool(cls, url):
        """
        https://python.langchain.com/docs/integrations/tools/requests/
        May return garbled text; there does not seem to be a parameter to set the encoding.
        requests_tools contains the following wrappers, each built on a TextRequestsWrapper:
            requests_get: GET a url and return the text response
            requests_post / requests_patch / requests_put: take a json string with "url" and "data" keys and return the text response
            requests_delete: DELETE a url and return the text response
        :param url:
        :return:
        """
        # requests_tools = load_tools(["requests_all"])
        # # Each tool wraps a requests wrapper
        # requests_tools[0].requests_wrapper
        requests = TextRequestsWrapper()
        res = requests.get(url)
        return res
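
# Usage sketch for Tools:
# out = Tools.python_repl_tool("print(1 + 1)")
# html = Tools.requests_get_tool("https://www.example.com")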


if __name__ == '__main__':
    from pprint import pprint

    os.environ['TRANSFORMERS_OFFLINE'] = "1"

    # file_name = '浙江国迈建设集团有限公司技术文件'
    file_name = '北京华科同安监控技术有限公司_textmind'
    file_name = '中科时代(北京)科技有限公司_textmind'
    file_name = '中能拾贝(广州)科技有限公司_textmind'
    file_name = '安徽德通智联科技有限公司_textmind'
    file_name = '旷智中科(北京)技术有限公司_textmind'
    file_name = '武汉大学_textmind'
    file_name = '武汉理工大学_textmind'
    file_name = '河海大学_textmind'

    # file_type = 'md'
    file_type = 'txt'
    # file_type = 'json'

    filepath = f'D:\\desktop\\三峡水利\\data\\0预审查初审详审测试数据\\textmind_result\\{file_name}.{file_type}'

    # documents = DocsLoader.markdown_loader(filepath=filepath, mode='elements')
    documents = DocsLoader.textmind_loader(filepath=filepath)
    # raw_doc = open(filepath, 'r', encoding='utf-8').read()
    # documents = TextSpliter.markdown_split(markdown_string=raw_doc, char_level_splits=True)
    # # print(documents)

    # embedding = EmbeddingVectorDB.load_local_embedding_model(embedding_model_path='BAAI/bge-small-zh-v1.5')
    embedding = EmbeddingVectorDB.load_local_embedding_model(embedding_model_path='GanymedeNil/text2vec-base-chinese')

    db = EmbeddingVectorDB.chroma_vector_db(split_docs=documents, vector_db_path=f'./chromadb/{file_name}/', embedding_model=embedding)
    # db = EmbeddingVectorDB.faiss_vector_db(split_docs=documents, vector_db_path=f'./faissdb/{file_name}/', embedding_model=embedding)
    # db = EmbeddingVectorDB.chroma_vector_db(split_docs=None, vector_db_path=f'./chromadb/{file_name}/', embedding_model=embedding)
    # db = EmbeddingVectorDB.faiss_vector_db(split_docs=None, vector_db_path=f'./faissdb/{file_name}/', embedding_model=embedding)

    query = '净利润|利润总额'
    # query = '类似项目业绩|项目合同|项目规模|项目名称'
    # query = '报价表|报价清单|分项报价表'

    # similarity search
    docs = Retriever.similarity(db, query, topk=3, long_context=True)
    # # search with a precomputed embedding vector
    # embedding_vector = embedding.embed_query(query)
    # docs = db.similarity_search_by_vector(embedding_vector, k=3)
    # docs = Retriever.mmr(db=db, query=query)
    # docs = Retriever.similarity_with_score(db=db, query=query, topk=3, long_context=False)
    # docs = Retriever.similarity_score_threshold(db=db, query=query)

    for doc in docs:
        pprint(doc)