- # -*- coding: utf-8 -*-
- # @Author: privacy
- # @Date: 2024-06-11 13:43:14
- # @Last Modified by: privacy
- # @Last Modified time: 2024-12-06 17:00:50
- import hashlib
- from typing import List, Optional
- import chromadb
# from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
# class MyEmbeddingFunction(EmbeddingFunction):
#     def __call__(self, texts: Documents) -> Embeddings:
#         embeddings = [model.encode(x) for x in texts]
#         return embeddings
class VectorClient:
    """Store and query per-page document text in a Chroma vector database.

    Each (project_name, file_name) pair maps to one Chroma collection whose
    name is the MD5 hex digest of "<project_name>_<file_name>" — presumably
    so arbitrary (non-ASCII) names become valid collection identifiers.
    """

    # Address of the Chroma HTTP server used by both methods.
    _HOST = 'localhost'
    _PORT = 50000

    @staticmethod
    def _collection_name(project_name: str, file_name: str) -> str:
        """Derive the MD5-hashed collection name for a project/file pair.

        Raises:
            TypeError: if either argument is None (the public methods type
                them Optional for interface compatibility, but None was
                never actually handled — concatenation would fail).
        """
        raw = f"{project_name}_{file_name}"
        return hashlib.md5(raw.encode("utf-8")).hexdigest()

    @classmethod
    def _connect(cls):
        """Open a client connection to the Chroma HTTP server."""
        return chromadb.HttpClient(host=cls._HOST, port=cls._PORT)

    @classmethod
    def add(cls, pages_content: List[dict], project_name: Optional[str], file_name: Optional[str]) -> None:
        """Insert every page's text into the file's collection.

        Args:
            pages_content: one dict per page; keys 'text' and 'page_number'
                are read (missing keys are stored as None).
            project_name: project the file belongs to.
            file_name: name of the file the pages come from.

        Returns:
            None
        """
        collection = cls._connect().get_or_create_collection(
            name=cls._collection_name(project_name, file_name),
            # embedding_function could be supplied here; otherwise the
            # server-side default embedding model is used.
        )
        # One record per page; the page index doubles as the record id.
        for index, item in enumerate(pages_content):
            collection.add(
                documents=[item.get('text')],
                metadatas=[{"page_number": item.get('page_number')}],
                ids=[str(index)],
            )

    @classmethod
    def query(cls, project_name: Optional[str], file_name: Optional[str], query: Optional[List[str]] = None, contains: Optional[str] = None):
        """Search the file's collection for pages similar to the query texts.

        Args:
            project_name: project the file belongs to.
            file_name: name of the file whose collection is searched.
            query: list of query texts; defaults to ["资质条件"]. (The
                original signature used a mutable default list — replaced
                with the None-sentinel idiom; behavior is unchanged.)
            contains: optional literal substring that matching documents
                must contain.

        Returns:
            The chromadb query result for the top 5 matches per query text.
        """
        if query is None:
            query = ["资质条件"]
        collection = cls._connect().get_collection(
            name=cls._collection_name(project_name, file_name),
        )
        kwargs = dict(query_texts=query, n_results=5)
        if contains:
            # Restrict matches to documents containing the substring.
            kwargs["where_document"] = {"$contains": contains}
        return collection.query(**kwargs)
if __name__ == '__main__':
    import os
    import json
    import asyncio
    from glob import glob
    from pprint import pprint

    # Later assignments deliberately override the earlier ones — kept as
    # quick toggles between test data sets.
    base_dir = 'D:\\desktop\\三峡水利\\data\\projects'
    base_dir = 'D:\\desktop\\三峡水利\\data\\0预审查初审详审测试数据'
    proj_name = '三峡左岸及地下电站地坪整治'
    proj_name = '【231100334287】500kV交流GIS GIL固体绝缘沿面失效机理研究及综合诊断装备研制采购程序文件'

    project_dir = os.path.join(base_dir, proj_name)
    # One sub-directory per supplier; each holds *-content.json page dumps.
    for supplier_dir in glob(os.path.join(project_dir, '*')):
        for content_file in glob(os.path.join(supplier_dir, '*-content.json')):
            # Second-to-last Windows path component is the supplier name.
            supplier = content_file.split('\\')[-2]
            # Example usage (disabled):
            # with open(content_file, 'r', encoding='utf-8') as fp:
            #     content = json.load(fp)
            # VectorClient.add(
            #     pages_content=content,
            #     project_name=proj_name,
            #     file_name=supplier,
            # )
            # r = VectorClient.query(
            #     project_name=proj_name,
            #     file_name=supplier,
            #     query=['证书'],
            #     contains='二级',
            # )
            # pprint(r)
|