# trans_file2vec.py
# -*- coding: utf-8 -*-
# @Author: privacy
# @Date: 2024-06-11 13:43:14
# @Last Modified by: privacy
# @Last Modified time: 2024-12-06 17:00:50
import hashlib
from typing import List, Optional
import chromadb
# Example of a custom embedding function (kept for reference, not active):
# from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
# class MyEmbeddingFunction(EmbeddingFunction):
#     def __call__(self, texts: Documents) -> Embeddings:
#         embeddings = [model.encode(x) for x in texts]
#         return embeddings
  14. class VectorClient:
  15. def __init__(self):
  16. super().__init__()
  17. @classmethod
  18. def add(cls, pages_content: List[dict], project_name: Optional[str], file_name: Optional[str]) -> None:
  19. """
  20. 将文件的每页正文转换为向量数据
  21. Args:
  22. pages_content: 每页的正文内容
  23. file_name: 文件名
  24. Returns:
  25. None
  26. """
  27. col_name = project_name + '_' + file_name
  28. name = hashlib.md5(col_name.encode(encoding="utf-8")).hexdigest()
  29. # 初始化 Chroma 客户端,连接到服务器
  30. # client = await chromadb.AsyncHttpClient(host='localhost', port=50000)
  31. client = chromadb.HttpClient(host='localhost', port=50000)
  32. # 使用集合(Collections)创建集合
  33. # collection = await client..get_or_create_collection(
  34. collection = client.get_or_create_collection(
  35. name=name,
  36. # embedding_function=emb_fn
  37. )
  38. # 向集合添加数据
  39. for index, item in enumerate(pages_content):
  40. # await collection.add(
  41. collection.add(
  42. documents=[
  43. item.get('text')
  44. ],
  45. metadatas=[
  46. {"page_number": item.get('page_number')}
  47. ],
  48. ids=[
  49. str(index)
  50. ],
  51. # embeddings=[
  52. # [1.1, 2.3, 3.2]
  53. # ]
  54. )
  55. @classmethod
  56. def query(cls, project_name: Optional[str], file_name: Optional[str], query: List[str] = ["资质条件"], contains: Optional[str] = None):
  57. """
  58. 查询向量数据库
  59. Args:
  60. pages_content: 每页的正文内容
  61. file_name: 文件名
  62. Returns:
  63. None
  64. """
  65. col_name = project_name + '_' + file_name
  66. name = hashlib.md5(col_name.encode(encoding="utf-8")).hexdigest()
  67. # 初始化 Chroma 客户端,连接到服务器
  68. # client = await chromadb.AsyncHttpClient(host='localhost', port=50000)
  69. client = chromadb.HttpClient(host='localhost', port=50000)
  70. # collection = await client.get_collection(
  71. collection = client.get_collection(
  72. name=name,
  73. # embedding_function=emb_fn
  74. )
  75. if contains:
  76. # 查询集合
  77. # results = await collection.query(
  78. results = collection.query(
  79. query_texts=query,
  80. n_results=5,
  81. # include=[],
  82. # where = {"metadata_field": "is_equal_to_this"},
  83. where_document={"$contains": contains}
  84. )
  85. else:
  86. # results = await collection.query(
  87. results = collection.query(
  88. query_texts=query,
  89. n_results=5
  90. )
  91. return results
  92. if __name__ == '__main__':
  93. import os
  94. import json
  95. import asyncio
  96. from glob import glob
  97. from pprint import pprint
  98. base_dir = 'D:\\desktop\\三峡水利\\data\\projects'
  99. base_dir = 'D:\\desktop\\三峡水利\\data\\0预审查初审详审测试数据'
  100. proj_name = '三峡左岸及地下电站地坪整治'
  101. proj_name = '【231100334287】500kV交流GIS GIL固体绝缘沿面失效机理研究及综合诊断装备研制采购程序文件'
  102. for supplierdir in glob(os.path.join(os.path.join(base_dir, proj_name), '*')):
  103. for file in glob(os.path.join(supplierdir, '*-content.json')):
  104. supplier = file.split('\\')[-2]
  105. # with open(file, 'r', encoding='utf-8') as fp:
  106. # content = json.load(fp)
  107. # # asyncio.run(insert(pages_content = content, project_name=proj_name, file_name=supplier))
  108. # asyncio.run(
  109. # search(
  110. # project_name='三峡左岸及地下电站地坪整治',
  111. # file_name='湖北建新建设工程有限公司',
  112. # query=['财务审计报告'],
  113. # contains='2021年'
  114. # )
  115. # )
  116. # VectorClient.add(
  117. # pages_content=content,
  118. # project_name=proj_name,
  119. # file_name=supplier
  120. # )
  121. # r = VectorClient.query(
  122. # project_name=proj_name,
  123. # file_name=supplier,
  124. # # query=['查询设备特性及性能组织'],
  125. # query=['证书'],
  126. # # contains='设备性能'
  127. # contains='二级'
  128. # )
  129. # pprint(r)
  130. # r = VectorClient.query(
  131. # project_name=proj_name,
  132. # file_name='supplier',
  133. # query=['证书'],
  134. # contains='营业执照'
  135. # )
  136. # pprint(r)