123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136 |
- #!/usr/bin/python
- # -*- coding=utf-8 -*-
- # @Create Time: 2024-08-05 15:12:31
- # @Last Modified time: 2024-12-25 16:17:05
- import os
- os.environ['TRANSFORMERS_OFFLINE'] = '1'
- import uuid
- from typing import List, Literal, Optional, Union
- import torch
- import uvicorn
- import numpy as np
- from pydantic import BaseModel, Field
- from fastapi import FastAPI, File, UploadFile, Form, BackgroundTasks, HTTPException
- # from flask import Flask, jsonify, request
- from transformers import AutoTokenizer, AutoModel
- from celery_tasks.all_instance import detail_task
- app = FastAPI()
- # app = Flask(__name__)
class ModelLoader:
    """Process-wide singleton holding the text2vec embedding model.

    The first call to ``ModelLoader()`` loads the HuggingFace model and
    tokenizer (offline — TRANSFORMERS_OFFLINE is set at module top);
    every later call returns the same shared instance.
    """

    _instance = None  # cached singleton; None until first successful load

    def __new__(cls):
        # Fast path: singleton already built.
        if cls._instance is not None:
            return cls._instance
        instance = super().__new__(cls)
        # 加载模型代码 — load model weights and tokenizer once per process.
        instance.model = AutoModel.from_pretrained("GanymedeNil/text2vec-base-chinese")
        instance.tokenizer = AutoTokenizer.from_pretrained("GanymedeNil/text2vec-base-chinese")
        # Bug fix: publish the instance only AFTER both loads succeed.
        # Previously cls._instance was set first, so a failed
        # from_pretrained() left a half-initialized object cached and
        # every subsequent ModelLoader() returned it.
        cls._instance = instance
        return cls._instance
- # model_loader = ModelLoader()
- #class EmbeddingInput(BaseModel):
- # input: str
- # model: Optional[str] = "text2vec-base-chinese"
- #class Embeding(BaseModel):
- # embedding: Optional[list] = []
- # index: Optional[int] = 0
- # object: Optional[str] = 'embedding'
class Usage(BaseModel):
    # Token accounting for an embedding request. Currently only referenced
    # by the commented-out /v1/embeddings handlers below, where both fields
    # are set to input_ids.shape[0] * input_ids.shape[1].
    prompt_tokens: int
    total_tokens: int
- #class EmbedingResponse(BaseModel):
- # data: List[Embeding]
- # model: Optional[str] = 'text2vec-base-chinese'
- # object: Optional[str] = 'list'
- # usage: Usage
class ResponseModel(BaseModel):
    """Standard response envelope for this service.

    error_code: 0 on success, non-zero on failure (1 = unrecognized upload).
    error_msg:  human-readable failure description, '' on success.
    log_id:     trace id slot; current handlers leave it ''.
    result:     handler-specific payload, e.g. {"task_id": "..."}.
    """
    error_code: Optional[int] = 0
    error_msg: Optional[str] = ''
    log_id: Optional[str] = ''
    # default_factory gives each response its own dict instead of a shared
    # mutable {} class-level default (standard mutable-default idiom;
    # Field is already imported at the top of the file).
    result: Optional[dict] = Field(default_factory=dict)
- # @app.post('/v1/embeddings', response_model=EmbedingResponse)
- # async def create_embeding(request: EmbeddingInput):
- # encoded_input = model_loader.tokenizer(request.input, return_tensors='pt')
- # with torch.no_grad():
- # output = model_loader.model(**encoded_input)
- # text_embedding = np.mean(output.last_hidden_state.mean(dim=1).numpy(), axis=0).tolist()
- # return EmbedingResponse(
- # data=[
- # Embeding(embedding=text_embedding)
- # ],
- # usage=Usage(
- # prompt_tokens=encoded_input.input_ids.shape[0] * encoded_input.input_ids.shape[1],
- # total_tokens=encoded_input.input_ids.shape[0] * encoded_input.input_ids.shape[1]
- # )
- # )
- # @app.route('/v1/embeddings', methods=['POST'])
- # def create_embeding(request: EmbeddingInput):
- # encoded_input = model_loader.tokenizer(request.input, return_tensors='pt')
- # with torch.no_grad():
- # output = model_loader.model(**encoded_input)
- # text_embedding = np.mean(output.last_hidden_state.mean(dim=1).numpy(), axis=0).tolist()
- # return EmbedingResponse(
- # data=[
- # Embeding(embedding=text_embedding)
- # ],
- # usage=Usage(
- # prompt_tokens=encoded_input.input_ids.shape[0] * encoded_input.input_ids.shape[1],
- # total_tokens=encoded_input.input_ids.shape[0] * encoded_input.input_ids.shape[1]
- # )
- # )
@app.post('/detail_check', response_model=ResponseModel)
async def predict(background_task: BackgroundTasks, projectId: str = Form(), projectName: str = Form(), bidderUnit: str = Form(), zb_filename: str = Form(), tb_filename: str = Form(), files: List[UploadFile] = File(...)):
    """Receive the tender (zb) PDF and bid (tb) JSON uploads, persist them
    under ./tmp with unique names, and schedule the background detail-check.

    Returns:
        ResponseModel with result={"task_id": ...} on success, or
        error_code=1 when an upload is unrecognized or a required file
        is missing from the request.
    """
    os.makedirs('./tmp', exist_ok=True)
    zb_file = None  # path of the saved tender document, once matched
    tb_file = None  # path of the saved bid document, once matched
    for file in files:
        if file.filename == zb_filename:
            print('招标文件')
            zb_file = f'./tmp/zb_file-{uuid.uuid4()}.pdf'
            zb_res = await file.read()
            with open(zb_file, 'wb') as f:
                f.write(zb_res)
        elif file.filename == tb_filename:
            print('投标文件')
            tb_file = f'./tmp/tb_file-{uuid.uuid4()}.json'
            tb_res = await file.read()
            with open(tb_file, 'wb') as f:
                f.write(tb_res)
        else:
            return ResponseModel(error_code=1, error_msg='未识别文件')
    # Bug fix: if the request omitted the zb or tb file (names never
    # matched), zb_file/tb_file were previously unbound locals and the
    # add_task call below raised UnboundLocalError (HTTP 500). Report a
    # normal error envelope instead.
    if zb_file is None or tb_file is None:
        return ResponseModel(error_code=1, error_msg='缺少招标或投标文件')
    background_task.add_task(detail_task, zb_file=zb_file, tb_file=tb_file, tb_filename=tb_filename, projectId=projectId, project=projectName, supplier=bidderUnit)
    # NOTE(review): the returned task_id is a fresh uuid unrelated to the
    # queued task — confirm downstream consumers don't try to look it up.
    return ResponseModel(result={"task_id": f"{uuid.uuid4()}"})
- # @app.route('/detail_check', methods=['POST'])
- # def predict():
- # tb_file = request.files['tb']
- # zb_file = request.files['zb']
- # tb_bytes = tb_file.read()
- # zb_bytes = zb_file.read()
- # return ResponseModel(result={"task_id": "T000001"})
# Development entry point: serve the FastAPI app directly with uvicorn,
# listening on all interfaces at port 18883.
if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=18883)
|