sprivacy 11 ماه پیش
والد
کامیت
41f52ff7d0
6فایلهای تغییر یافته به همراه104 افزوده شده و 25 حذف شده
  1. 62 0
      api.py
  2. 5 3
      celery_tasks/commonprocess.py
  3. 2 2
      celery_tasks/extract_financial_report.py
  4. 2 2
      celery_tasks/get_info.py
  5. 29 4
      celery_tasks/matcher.py
  6. 4 14
      celery_tasks/technical_part.py

+ 62 - 0
api.py

@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+# @Author: privacy
+# @Date:   2024-09-03 11:24:56
+# @Last Modified by:   privacy
+# @Last Modified time: 2024-09-04 11:07:49
+from fastapi import FastAPI
+from pydantic import BaseModel
+from celery.result import AsyncResult
+
+from celery_tasks import celery_app
+# from celery_tasks.commonprocess import add
+from celery_tasks.commonprocess import bidding_factor, test_all_files
+from celery_tasks.project_loc import extract_project
+
+tags_metadata = [  # tag metadata handed to FastAPI(openapi_tags=...) below
+    {
+        "name": "file",
+        "description": "解析PDF文件"  # "Parse PDF files"
+    },
+    {
+        "name": "factor",
+        "description": "解析详审因素"  # "Parse detailed-review factors"
+    },
+    {
+        "name": "result",
+        "description": "获取异步任务结果"  # "Fetch asynchronous task results"
+    }
+]
+app = FastAPI(openapi_tags=tags_metadata)
+
+
+class RequestModel(BaseModel):  # request body accepted by the '/' and '/factor' endpoints
+    table_list: list  # forwarded unchanged as the task's `table_list` keyword argument
+
+
+@app.post('/')
+def root(request: RequestModel):
+    # Hand the tables to the extract_project task and reply at once with the id of the queued task.
+    task_id = extract_project.apply_async(kwargs={'table_list': request.table_list}).id
+    return {"message": f"Task {task_id} Start!"}
+
+
+@app.post('/factor', tags=['factor'])
+def get_factor(request: RequestModel):
+    task_id = bidding_factor.apply_async(kwargs={'table_list': request.table_list}).id  # enqueue only; the result is not awaited here
+    return {"message": f"Task {task_id} Start!"}
+
+
+@app.get('/result', tags=['result'])
+def back(taskid: str):
+    # Report the outcome of task `taskid`: its result, its failure, or "still processing" otherwise.
+    result = AsyncResult(id=taskid, app=celery_app)
+    if result.failed():  # without this branch a crashed task is reported as "processing" forever
+        return "执行失败:%s" % result.result
+    # successful() is False until the task has finished successfully, so every other state maps to "processing"
+    return ("执行完成,结果:%s" % result.get()) if result.successful() else '正在处理中...'
+
+
+@app.get('/file', tags=['file'])
+def process_file(proj_name: str):
+    task_id = test_all_files.apply_async(kwargs={'proj_name': proj_name}).id  # enqueue only; the result is not awaited here
+    return {"message": f"Task {task_id} Start!"}

+ 5 - 3
celery_tasks/commonprocess.py

@@ -2,7 +2,7 @@
 # @Author: privacy
 # @Date:   2024-08-30 13:13:03
 # @Last Modified by:   privacy
-# @Last Modified time: 2024-09-04 17:33:02
+# @Last Modified time: 2024-09-06 09:25:00
 import os
 from glob import glob
 from typing import List, Optional
@@ -88,8 +88,10 @@ def bidding_factor(table_list: list) -> dict:
     """
     dpr = DocumentPreReview()
     dpr.Bidding_tables = table_list
-
-    return dpr.get_table()
+    try:
+        return dpr.get_table()
+    except Exception:
+        return {}
 
 
 @celery_app.task

+ 2 - 2
celery_tasks/extract_financial_report.py

@@ -2,7 +2,7 @@
 # @Author: privacy
 # @Date:   2024-06-11 13:43:14
 # @Last Modified by:   privacy
-# @Last Modified time: 2024-09-03 10:10:49
+# @Last Modified time: 2024-09-05 15:04:14
 import os
 import re
 import datetime
@@ -85,7 +85,7 @@ def extract_financial_report(title_list: list, table_list: list, image_list: lis
             ]
 
             ocr_results = [
-                pic_ocr.apply_async(kwargs={'image_path': img['image_name']}).get(timeout=30)
+                pic_ocr.apply_async(kwargs={'image_path': img['image_name']}).get(timeout=30)['rawjson']['ret']
                 for img in item.get('images')
             ]
 

+ 2 - 2
celery_tasks/get_info.py

@@ -2,7 +2,7 @@
 # @Author: privacy
 # @Date:   2024-06-11 13:43:14
 # @Last Modified by:   privacy
-# @Last Modified time: 2024-09-04 12:08:01
+# @Last Modified time: 2024-09-05 16:29:06
 
 # 标准包导入
 import os
@@ -598,7 +598,7 @@ class PdfExtractAttr(object):
                         text_type = False
 
                     # 判断是否为表名
-                    if text and text.endswith('表'):
+                    if text and (text.endswith('表') or text.startswith('表') or text.endswith('清单')):
                         is_table_name = True
                     else:
                         is_table_name = False

+ 29 - 4
celery_tasks/matcher.py

@@ -2,10 +2,10 @@
 # @Author: privacy
 # @Date:   2024-06-27 09:33:01
 # @Last Modified by:   privacy
-# @Last Modified time: 2024-09-05 10:38:48
+# @Last Modified time: 2024-09-06 14:12:50
 import os
 os.environ['TRANSFORMERS_OFFLINE'] = '1'
-from typing import List
+from typing import List, Union
 
 import torch
 import numpy as np
@@ -14,7 +14,6 @@ from sklearn.metrics.pairwise import cosine_similarity
 from transformers import AutoTokenizer, AutoModel
 
 
-
 class Matcher:
     def __init__(self):
         # Load model directly
@@ -22,7 +21,7 @@ class Matcher:
         self.tokenizer = AutoTokenizer.from_pretrained("GanymedeNil/text2vec-base-chinese")
         self.model = AutoModel.from_pretrained("GanymedeNil/text2vec-base-chinese")
 
-    def TopK1(self, title: str, keywords: list, query_embedding, option_embeddings: list) -> pd.Series:
+    def TopK1(self, title: str, keywords: list, query_embedding: np.ndarray, option_embeddings: List[np.ndarray]) -> pd.Series:
         """
         获取相似度最高的向量
         Args:
@@ -71,6 +70,32 @@ class Matcher:
             text_embeddings.append(np.mean(output.last_hidden_state.mean(dim=1).numpy(), axis=0))
         return text_embeddings
 
+    @classmethod
+    def mean_pooling(cls, token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+        """Return the attention-mask-weighted mean of ``token_embeddings`` taken over dim 1.
+
+        token_embeddings: first element of the model output; positions whose mask is 0 add nothing (divisor clamped at 1e-9).
+        """
+        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        return (token_embeddings * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
+
+    def sentence_embeddings(self, sentence: Union[str, List[str]]) -> torch.Tensor:
+        """Tokenize ``sentence`` (padded and truncated), run the model without gradients, and mean-pool the token vectors."""
+        batch = self.tokenizer(sentence, padding=True, truncation=True, return_tensors='pt')
+        with torch.no_grad():
+            return self.mean_pooling(self.model(**batch)[0], batch['attention_mask'])
+
+    def similarities(self, sentence: Union[str, List[str]], query: str, topk: int = 1) -> pd.DataFrame:
+        """Score each entry of ``sentence`` against ``query`` by cosine similarity of their embeddings.
+
+        Returns a DataFrame with a single ``similarity`` column, one row per sentence, left unsorted.
+        NOTE(review): ``topk`` is accepted but never applied -- no ranking or threshold happens here.
+        """
+        candidates = self.sentence_embeddings(sentence)
+        target = self.sentence_embeddings(query)
+        scores = cosine_similarity(target, candidates)[0]  # single query -> take its only row
+        return pd.DataFrame(scores, columns=['similarity'])
+
 
 if __name__ == '__main__':
     matcher = Matcher()

+ 4 - 14
celery_tasks/technical_part.py

@@ -2,12 +2,13 @@
 # @Author: privacy
 # @Date:   2024-08-30 11:15:24
 # @Last Modified by:   privacy
-# @Last Modified time: 2024-09-04 14:48:03
+# @Last Modified time: 2024-09-06 09:20:10
 
 """
 技术部分
 """
 from . import celery_app
+from .commonprocess import bidding_document, bidding_factor
 
 
 @celery_app.task
@@ -26,22 +27,11 @@ def main(bidding_file, tender_file):
     result = task.get(timeout=3600)
     # 2、从招标表格中抽取评分因素
     task = bidding_factor.apply_async(kwargs={'table_list': result['tables']})
-    # 3、获取商务部分评分标准
+    # 3、获取技术部分评分标准
     for item in task.get(timeout=1)['技术部分评分标准']:
         print(item['评分因素'], item['评分标准'], item['权重'])
-        if '业绩' in item['评分因素']:
-            pass
-        elif '信用' in item['评分因素']:
-            pass
-        elif '财务' in item['评分因素']:
-            pass
-        elif '报价' in item['评分因素']:
-            pass
-        else:
-            pass
-    # 4、根据商务部分评分标准查找投标文件内容位置
+    # 4、根据技术部分评分标准查找投标文件内容位置
     # 5、返回评标结果
-    pass
 
 
 if __name__ == '__main__':