# -*- coding: utf-8 -*-
# @Author: privacy
# @Date: 2024-06-11 13:43:14
# @Last Modified by: privacy
# @Last Modified time: 2024-12-06 17:00:50
"""Per-file vector-store helpers backed by a Chroma HTTP server.

Each (project, file) pair maps to one Chroma collection whose name is the
MD5 hex digest of ``"<project>_<file>"`` (Chroma restricts collection-name
characters, so the digest keeps arbitrary Chinese names safe).
"""
import hashlib
from typing import List, Optional

import chromadb

# from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
#
# class MyEmbeddingFunction(EmbeddingFunction):
#     def __call__(self, texts: Documents) -> Embeddings:
#         return [model.encode(x) for x in texts]

# Chroma server endpoint shared by every call.
CHROMA_HOST = 'localhost'
CHROMA_PORT = 50000


class VectorClient:
    """Stateless facade over a Chroma HTTP server (all methods are classmethods)."""

    @staticmethod
    def _collection_name(project_name: str, file_name: str) -> str:
        """Return the Chroma collection name for a (project, file) pair.

        The MD5 digest of ``"<project>_<file>"`` is used because Chroma
        collection names must be ASCII-safe while project/file names are not.
        """
        col_name = project_name + '_' + file_name
        return hashlib.md5(col_name.encode(encoding="utf-8")).hexdigest()

    @classmethod
    def add(cls, pages_content: List[dict], project_name: str, file_name: str) -> None:
        """Index every page of a file's body text into the vector store.

        Args:
            pages_content: One dict per page; each must provide ``'text'``
                (page body) and ``'page_number'``.
            project_name: Project the file belongs to (collection-name part).
            file_name: File name (collection-name part).

        Returns:
            None
        """
        name = cls._collection_name(project_name, file_name)

        # Connect to the Chroma server and create the collection if missing.
        # client = await chromadb.AsyncHttpClient(host=CHROMA_HOST, port=CHROMA_PORT)
        client = chromadb.HttpClient(host=CHROMA_HOST, port=CHROMA_PORT)
        collection = client.get_or_create_collection(
            name=name,
            # embedding_function=emb_fn
        )

        # Batch all pages into a single add() call instead of one HTTP
        # round-trip per page; ids stay the stringified page index.
        collection.add(
            documents=[item.get('text') for item in pages_content],
            metadatas=[{"page_number": item.get('page_number')} for item in pages_content],
            ids=[str(index) for index in range(len(pages_content))],
        )

    @classmethod
    def query(cls,
              project_name: str,
              file_name: str,
              query: Optional[List[str]] = None,
              contains: Optional[str] = None):
        """Semantic search over one file's indexed pages.

        Args:
            project_name: Project the file belongs to (collection-name part).
            file_name: File name (collection-name part).
            query: Query texts to embed and search with; defaults to
                ``["资质条件"]`` when omitted.
            contains: Optional substring filter — only documents containing
                this text are returned (Chroma ``$contains`` filter).

        Returns:
            The Chroma query result dict (ids, documents, metadatas,
            distances for up to 5 hits per query text).

        Raises:
            chromadb.errors.* if the collection does not exist — call
            :meth:`add` for this (project, file) pair first.
        """
        # Avoid a mutable default argument; keep the original default value.
        if query is None:
            query = ["资质条件"]

        name = cls._collection_name(project_name, file_name)

        # client = await chromadb.AsyncHttpClient(host=CHROMA_HOST, port=CHROMA_PORT)
        client = chromadb.HttpClient(host=CHROMA_HOST, port=CHROMA_PORT)
        collection = client.get_collection(
            name=name,
            # embedding_function=emb_fn
        )

        if contains:
            results = collection.query(
                query_texts=query,
                n_results=5,
                # where={"metadata_field": "is_equal_to_this"},
                where_document={"$contains": contains},
            )
        else:
            results = collection.query(
                query_texts=query,
                n_results=5,
            )
        return results


if __name__ == '__main__':
    import os
    import json
    from glob import glob
    from pprint import pprint

    base_dir = 'D:\\desktop\\三峡水利\\data\\0预审查初审详审测试数据'
    proj_name = '【231100334287】500kV交流GIS GIL固体绝缘沿面失效机理研究及综合诊断装备研制采购程序文件'

    # Walk <base_dir>/<project>/<supplier>/*-content.json; the supplier is
    # the parent directory of each content file.
    for supplier_dir in glob(os.path.join(base_dir, proj_name, '*')):
        for file in glob(os.path.join(supplier_dir, '*-content.json')):
            supplier = os.path.basename(os.path.dirname(file))

            # Example usage (requires a running Chroma server on
            # CHROMA_HOST:CHROMA_PORT):
            #
            # with open(file, 'r', encoding='utf-8') as fp:
            #     content = json.load(fp)
            # VectorClient.add(
            #     pages_content=content,
            #     project_name=proj_name,
            #     file_name=supplier,
            # )
            # r = VectorClient.query(
            #     project_name=proj_name,
            #     file_name=supplier,
            #     query=['证书'],
            #     contains='二级',
            # )
            # pprint(r)