app.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
#!/usr/bin/python
# -*- coding=utf-8 -*-
# @Create Time: 2024-08-05 15:12:31
# @Last Modified time: 2024-12-25 16:17:05
import os
# Force transformers to use only the local model cache (no hub downloads).
# Must be set BEFORE `transformers` is imported, hence the early position.
os.environ['TRANSFORMERS_OFFLINE'] = '1'
import uuid
from typing import List, Literal, Optional, Union
import torch
import uvicorn
import numpy as np
from pydantic import BaseModel, Field
from fastapi import FastAPI, File, UploadFile, Form, BackgroundTasks, HTTPException
# from flask import Flask, jsonify, request
from transformers import AutoTokenizer, AutoModel
# Celery task that performs the actual detail-check work out of process.
from celery_tasks.all_instance import detail_task
app = FastAPI()
# app = Flask(__name__)
  19. class ModelLoader:
  20. _instance = None
  21. def __new__(cls):
  22. if cls._instance is None:
  23. cls._instance = super().__new__(cls)
  24. # 加载模型代码
  25. cls._instance.model = AutoModel.from_pretrained("GanymedeNil/text2vec-base-chinese")
  26. cls._instance.tokenizer = AutoTokenizer.from_pretrained("GanymedeNil/text2vec-base-chinese")
  27. return cls._instance
  28. # model_loader = ModelLoader()
  29. #class EmbeddingInput(BaseModel):
  30. # input: str
  31. # model: Optional[str] = "text2vec-base-chinese"
  32. #class Embeding(BaseModel):
  33. # embedding: Optional[list] = []
  34. # index: Optional[int] = 0
  35. # object: Optional[str] = 'embedding'
class Usage(BaseModel):
    """Token accounting for a request (OpenAI-style `usage` payload)."""
    # tokens consumed by the prompt/input text
    prompt_tokens: int
    # total tokens consumed by the request
    total_tokens: int
  39. #class EmbedingResponse(BaseModel):
  40. # data: List[Embeding]
  41. # model: Optional[str] = 'text2vec-base-chinese'
  42. # object: Optional[str] = 'list'
  43. # usage: Usage
  44. class ResponseModel(BaseModel):
  45. error_code: Optional[int] = 0
  46. error_msg: Optional[str] = ''
  47. log_id: Optional[str] = ''
  48. result: Optional[dict] = {}
  49. # @app.post('/v1/embeddings', response_model=EmbedingResponse)
  50. # async def create_embeding(request: EmbeddingInput):
  51. # encoded_input = model_loader.tokenizer(request.input, return_tensors='pt')
  52. # with torch.no_grad():
  53. # output = model_loader.model(**encoded_input)
  54. # text_embedding = np.mean(output.last_hidden_state.mean(dim=1).numpy(), axis=0).tolist()
  55. # return EmbedingResponse(
  56. # data=[
  57. # Embeding(embedding=text_embedding)
  58. # ],
  59. # usage=Usage(
  60. # prompt_tokens=encoded_input.input_ids.shape[0] * encoded_input.input_ids.shape[1],
  61. # total_tokens=encoded_input.input_ids.shape[0] * encoded_input.input_ids.shape[1]
  62. # )
  63. # )
  64. # @app.route('/v1/embeddings', methods=['POST'])
  65. # def create_embeding(request: EmbeddingInput):
  66. # encoded_input = model_loader.tokenizer(request.input, return_tensors='pt')
  67. # with torch.no_grad():
  68. # output = model_loader.model(**encoded_input)
  69. # text_embedding = np.mean(output.last_hidden_state.mean(dim=1).numpy(), axis=0).tolist()
  70. # return EmbedingResponse(
  71. # data=[
  72. # Embeding(embedding=text_embedding)
  73. # ],
  74. # usage=Usage(
  75. # prompt_tokens=encoded_input.input_ids.shape[0] * encoded_input.input_ids.shape[1],
  76. # total_tokens=encoded_input.input_ids.shape[0] * encoded_input.input_ids.shape[1]
  77. # )
  78. # )
  79. @app.post('/detail_check', response_model=ResponseModel)
  80. async def predict(background_task: BackgroundTasks, projectId: str = Form(), projectName: str = Form(), bidderUnit: str = Form(), zb_filename: str = Form(), tb_filename: str = Form(), files: List[UploadFile] = File(...)):
  81. os.makedirs('./tmp', exist_ok=True)
  82. for file in files:
  83. if file.filename == zb_filename:
  84. print('招标文件')
  85. zb_file = f'./tmp/zb_file-{uuid.uuid4()}.pdf'
  86. zb_res = await file.read()
  87. with open(zb_file, 'wb') as f:
  88. f.write(zb_res)
  89. elif file.filename == tb_filename:
  90. print('投标文件')
  91. tb_file = f'./tmp/tb_file-{uuid.uuid4()}.json'
  92. tb_res = await file.read()
  93. with open(tb_file, 'wb') as f:
  94. f.write(tb_res)
  95. else:
  96. return ResponseModel(error_code=1, error_msg='未识别文件')
  97. background_task.add_task(detail_task, zb_file=zb_file, tb_file=tb_file, tb_filename=tb_filename, projectId=projectId, project=projectName, supplier=bidderUnit)
  98. return ResponseModel(result={"task_id": f"{uuid.uuid4()}"})
  99. # @app.route('/detail_check', methods=['POST'])
  100. # def predict():
  101. # tb_file = request.files['tb']
  102. # zb_file = request.files['zb']
  103. # tb_bytes = tb_file.read()
  104. # zb_bytes = zb_file.read()
  105. # return ResponseModel(result={"task_id": "T000001"})
if __name__ == '__main__':
    # Dev entrypoint: serve the API on all interfaces, port 18883.
    uvicorn.run(app, host='0.0.0.0', port=18883)