Browse Source

2024-12-06

sprivacy 8 months ago
parent
commit
a838bb3143
1 changed file with 57 additions and 119 deletions
  1. 57 119
      celery_tasks/trans_file2vec.py

+ 57 - 119
celery_tasks/trans_file2vec.py

@@ -2,11 +2,10 @@
 # @Author: privacy
 # @Author: privacy
 # @Date:   2024-06-11 13:43:14
 # @Date:   2024-06-11 13:43:14
 # @Last Modified by:   privacy
 # @Last Modified by:   privacy
-# @Last Modified time: 2024-12-04 11:22:35
+# @Last Modified time: 2024-12-06 17:00:50
 
 
 import hashlib
 import hashlib
-from abc import ABC, abstractmethod
-from typing import Dict, List, Optional, Union
+from typing import List, Optional
 
 
 import chromadb
 import chromadb
 # from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
 # from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
@@ -17,81 +16,10 @@ import chromadb
 #         embeddings = [model.encode(x) for x in texts]
 #         embeddings = [model.encode(x) for x in texts]
 #         return embeddings
 #         return embeddings
 
 
-class BaseLoader(ABC):
-
-    def load():
-        pass
-
-
-class TextMindLoader(BaseLoader):
-
-    def __init__(self, file_path: Union[str, Path]):
-        self.file_path = file_path
-
-    def lazy_load(self) -> Iterator[Document]:
-        with open(self.file_path, 'r', encoding='utf-8') as jsonfile:
-            raw_json = json.load(jsonfile)
-
-        for page in raw_json['pages']:
-            page_content = page['text']
-            metadata = {'page_num': page['page_num'], 'has_table': page['tables'] is not None}
-            yield Document(page_content=page_content, metadata=metadata)
-
-
-class DocsLoader():
-
-    @classmethod
-    def textmind_loader(cls, filepath):
-        loader = TextMindLoader(file_path=filepath)
-        data = loader.load()
-        return data
-
-
-class Chroma:
-    def __init__(self, persist_directory, embedding_function):
+class VectorClient:
+    def __init__(self):
         super().__init__()
         super().__init__()
 
 
-    @classmethod
-    def from_documents(cls, docs, embedding_model, persist_directory):
-        client = chromadb.HttpClient(host='localhost', port=50000)
-        collection = client.get_or_create_collection(persist_directory, embedding_function=embedding_function)
-        # 向集合添加数据
-        for index, item in enumerate(pages_content):
-            # await collection.add(
-            collection.add(
-                documents=[
-                    item.get('text')
-                ],
-                metadatas=[
-                    {"page_number": item.get('page_number')}
-                ],
-                ids=[
-                    str(index)
-                ],
-                # embeddings=[
-                #     [1.1, 2.3, 3.2]
-                # ]
-            )
-        return collection
-
-
-class EmbeddingVectorDB():
-    @classmethod
-    def load_local_embedding_model(cls, embedding_model_path, device='cpu'):
-        """加载本地向量模型"""
-        embedding_model = HuggingFaceEmbeddings(model_name=embedding_model_path, model_kwargs={'device': device})
-        return embedding_model
-
-    @classmethod
-    def chroma_vector_db(cls, split_docs, vector_db_path, embedding_model):
-        if os.path.exists(vector_db_path):
-            print('加载向量数据库路径 =》', vector_db_path)
-            db = Chroma(persist_directory=vector_db_path, embedding_function=embedding_model)
-        else:
-            print('创建向量数据库路径 =》', vector_db_path)
-            db = Chroma.from_documents(split_docs, embedding_model, persist_directory=vector_db_path)
-        return db
-
     @classmethod
     @classmethod
     def add(cls, pages_content: List[dict], project_name: Optional[str], file_name: Optional[str]) -> None:
     def add(cls, pages_content: List[dict], project_name: Optional[str], file_name: Optional[str]) -> None:
         """
         """
@@ -184,46 +112,56 @@ class EmbeddingVectorDB():
 
 
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
-    # import os
-    # import json
-    # import asyncio
-    # from glob import glob
-    # from pprint import pprint
-
-    # base_dir = 'D:\\desktop\\三峡水利\\data\\projects'
-    # base_dir = 'D:\\desktop\\三峡水利\\data\\0预审查初审详审测试数据'
-
-    # proj_name = '三峡左岸及地下电站地坪整治'
-    # proj_name = '【231100334287】500kV交流GIS GIL固体绝缘沿面失效机理研究及综合诊断装备研制采购程序文件'
-
-    # for supplierdir in glob(os.path.join(os.path.join(base_dir, proj_name), '*')):
-    #     for file in glob(os.path.join(supplierdir, '*-content.json')):
-    #         supplier = file.split('\\')[-2]
-
-    #         # with open(file, 'r', encoding='utf-8') as fp:
-    #         #     content = json.load(fp)
-
-    #         # # asyncio.run(insert(pages_content = content, project_name=proj_name, file_name=supplier))
-
-    #         # # asyncio.run(
-    #         # #     search(
-    #         # #         project_name='三峡左岸及地下电站地坪整治',
-    #         # #         file_name='湖北建新建设工程有限公司',
-    #         # #         query=['财务审计报告'],
-    #         # #         contains='2021年'
-    #         # #     )
-    #         # # )
-
-    #         # VectorClient.add(
-    #         #     pages_content=content,
-    #         #     project_name=proj_name,
-    #         #     file_name=supplier
-    #         # )
-
-    #         r = VectorClient.query(
-    #             project_name=proj_name,
-    #             file_name=supplier,
-    #             query=['查询设备特性及性能组织'],
-    #             contains='设备性能'
-    #         )
-    #         pprint(r)
+    import os
+    import json
+    import asyncio
+    from glob import glob
+    from pprint import pprint
+
+    base_dir = 'D:\\desktop\\三峡水利\\data\\projects'
+    base_dir = 'D:\\desktop\\三峡水利\\data\\0预审查初审详审测试数据'
+
+    proj_name = '三峡左岸及地下电站地坪整治'
+    proj_name = '【231100334287】500kV交流GIS GIL固体绝缘沿面失效机理研究及综合诊断装备研制采购程序文件'
+
+    for supplierdir in glob(os.path.join(os.path.join(base_dir, proj_name), '*')):
+        for file in glob(os.path.join(supplierdir, '*-content.json')):
+            supplier = file.split('\\')[-2]
+
+            # with open(file, 'r', encoding='utf-8') as fp:
+            #     content = json.load(fp)
+
+            # # asyncio.run(insert(pages_content = content, project_name=proj_name, file_name=supplier))
+
+            # asyncio.run(
+            #     search(
+            #         project_name='三峡左岸及地下电站地坪整治',
+            #         file_name='湖北建新建设工程有限公司',
+            #         query=['财务审计报告'],
+            #         contains='2021年'
+            #     )
+            # )
+
+            # VectorClient.add(
+            #     pages_content=content,
+            #     project_name=proj_name,
+            #     file_name=supplier
+            # )
+
+            # r = VectorClient.query(
+            #     project_name=proj_name,
+            #     file_name=supplier,
+            #     # query=['查询设备特性及性能组织'],
+            #     query=['证书'],
+            #     # contains='设备性能'
+            #     contains='二级'
+            # )
+            # pprint(r)
+
+    # r = VectorClient.query(
+    #     project_name=proj_name,
+    #     file_name='supplier',
+    #     query=['证书'],
+    #     contains='营业执照'
+    # )
+    # pprint(r)