online_test.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964
  1. ### 解析所有pdf文件并提取信息进行测试的框架
  2. ### PdfExtractAttr作为提取pdf信息的基类
  3. # 子类在其基础上实现匹配功能
  4. # 标准包导入
  5. import os
  6. import re
  7. import json
  8. import re
  9. import shutil
  10. import pandas as pd
  11. import pdb
  12. import base64
  13. from io import BytesIO
  14. from pprint import pprint
  15. # 第三方包导入
  16. import numpy as np
  17. import pandas as pd
  18. import cv2
  19. import torch
  20. import glob
  21. import logging
  22. import requests
  23. import time
  24. import datetime
  25. from tqdm import tqdm
  26. from tools import RefPageNumberResolver
  27. from get_info import PdfExtractAttr
  28. from get_info import is_title, export_image, _save_jpeg, _save_jpeg2000, _save_bmp, main_parse, table_parse, load_json
  29. from PIL import Image
  30. import cn_clip.clip as clip
  31. from cn_clip.clip import load_from_name, available_models
  32. from pdfminer.image import ImageWriter
# global envs
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Checkpoint name handed to cn_clip's load_from_name.
clip_version = "ViT-B-16"
# Loaded once at import time; `model` and `preprocess` are read as module
# globals by PdfMatcher.get_similarity.
model, preprocess = load_from_name(clip_version)
# Put the model in evaluation mode (it is only used for inference here).
model.eval()
# File the module-wide logger writes to (see create_logger).
log_path = "/home/stf/miner_pdf/test.log"
  39. # log
  40. def create_logger(log_path):
  41. """
  42. 将日志输出到日志文件和控制台
  43. """
  44. logger = logging.getLogger()
  45. logger.setLevel(logging.INFO)
  46. formatter = logging.Formatter(
  47. '%(asctime)s - %(levelname)s - %(message)s')
  48. # 创建一个handler,用于写入日志文件
  49. file_handler = logging.FileHandler(
  50. filename=log_path, mode='w')
  51. file_handler.setFormatter(formatter)
  52. file_handler.setLevel(logging.INFO)
  53. logger.addHandler(file_handler)
  54. # 创建一个handler,用于将日志输出到控制台
  55. console = logging.StreamHandler()
  56. console.setLevel(logging.DEBUG)
  57. console.setFormatter(formatter)
  58. logger.addHandler(console)
  59. return logger
# Module-wide logger used by the classes and the script below; configured once
# at import time, which also truncates the file at `log_path`.
logger = create_logger(log_path=log_path)
  61. # ocr外部接口
  62. class OcrAgent():
  63. def __init__(self, url):
  64. self.url = url
  65. self.datetime_re = r'\d{4}年\d{1,2}月\d{1,2}日至(?:\d{4}年\d{1,2}月\d{1,2}日|长期)'
  66. # 不同类型证书资质正则
  67. self.re_dict = {
  68. "business_license" : r'营业执照',
  69. "deposit": r'^(?:开户许可证|[\u4e00-\u9fff]+存款账户[\u4e00-\u9fff]+)$',
  70. "production_license": r'\b[\u4e00-\u9fff]*许可证\b',
  71. "qualtifications" : r'\b[\u4e00-\u9fff]*证书',
  72. "proof": r'\b[\u4e00-\u9fff]*证明',
  73. }
  74. self.sign_threshold = 0.05
  75. def get_content(self, image_path):
  76. try:
  77. with open(image_path, 'rb') as image_file:
  78. files = {"file": ("image.jpg", image_file, "image/jpeg")}
  79. # files = {"file": ("image.png", image_file, "image/png")}
  80. response = requests.post(url, files=files)
  81. return response.json()
  82. except:
  83. raise ValueError(f"传入图像{image_path}已损坏")
  84. def remove_red_seal(self, image_path):
  85. # 读取图像
  86. input_img = cv2.imread(image_path)
  87. # 分离图片的通道
  88. blue_c, green_c, red_c = cv2.split(input_img)
  89. #利用大津法自动选择阈值
  90. thresh, ret = cv2.threshold(red_c, 0, 255, cv2.THRESH_OTSU)
  91. #对阈值进行调整
  92. filter_condition = int(thresh * 1.0)
  93. #移除红色的印章
  94. _, red_thresh = cv2.threshold(red_c, filter_condition, 255, cv2.THRESH_BINARY)
  95. # 把图片转回3通道
  96. result_img = np.expand_dims(red_thresh, axis=2)
  97. result_img = np.concatenate((result_img, result_img, result_img), axis=-1)
  98. return result_img
  99. def judge(self, image_path: str, firm_name: str):
  100. '''使用正则判断是否属于营业执照或资质证书类型'''
  101. # image_prefix = image_path.split('/')[-1][:-4]
  102. image_prefix = image_path.split('/')[-1]
  103. logger.info(f'processing img: {image_prefix}')
  104. # page_number = image_prefix.split('_')[-2]
  105. response_item = {
  106. "qualtified": None, # 是否为证书
  107. "matched": None, # 是否出现匹配的公司名称
  108. "license_name": None, # 证书名
  109. # "page_number": page_number, # 证书所在页
  110. "start_datetime": None, # 有效起始时间
  111. "end_datetime": None # 有效终止时间
  112. }
  113. content = self.get_content(image_path=image_path)
  114. image_info = content["rawjson"]["ret"]
  115. # 必须包含公司名称信息
  116. if not self.search(image_info=image_info, key=firm_name):
  117. return None
  118. else:
  119. response_item['matched'] = True
  120. # 是否匹配营业执照或资质证书
  121. for key, format in self.re_dict.items():
  122. if key == 'business_license':
  123. match_name = self.re_match(image_info=image_info, format=format)
  124. else:
  125. match_name = self.re_search(image_info=image_info, format=format)
  126. if match_name and key == 'business_license':
  127. response_item["qualtified"] = True
  128. response_item["license_name"] = match_name
  129. response_item = self.find_license_datetime(image_info=image_info, response_item=response_item)
  130. return response_item
  131. elif match_name:
  132. response_item["qualtified"] = True
  133. response_item["license_name"] = match_name
  134. response_item = self.find_certificate_datetime(image_info=image_info, response_item=response_item)
  135. return response_item
  136. return response_item
  137. # TODO 资质证书有效期定位
  138. def find_certificate_datetime(self, image_info, response_item):
  139. # keyword
  140. start_keywords = ['颁发日期', '发证日期', '生效日期']
  141. end_keywords = ['终止日期']
  142. priority_keywords = ['有效期', '使用期限', '有效日期']
  143. keywords_list = ['有效期', '使用期限', '有效日期', '终止日期', '颁发日期', '发证日期', '生效日期']
  144. # re format
  145. format = r'(?:[自至])?\d{4}年\d{1,2}月\d{1,2}日(?:至)?(?:\d{4}年\d{1,2}月\d{1,2}日)?'
  146. special_format = r'\d{4}-\d{1,2}-\d{1,2}'
  147. # 判断是否存在日期关键字
  148. flag = False
  149. keyword_dict = {}
  150. for info in image_info:
  151. word = info['word']
  152. left = info['rect']['left']
  153. top = info['rect']['top']
  154. width = info['rect']['width']
  155. height = info['rect']['height']
  156. for keyword in keywords_list:
  157. # 该证书存在日期关键字
  158. if keyword in word:
  159. flag = True
  160. charset_list = info['charset']
  161. for char_dc in charset_list:
  162. if char_dc['word'] == keyword[-1]:
  163. right = char_dc['rect']['left'] + char_dc['rect']['width']
  164. keyword_dict[keyword] = {
  165. "left": left,
  166. "top": top,
  167. "right": right
  168. }
  169. if flag:
  170. for info in image_info:
  171. word = info['word']
  172. if '年' in word or re.search(r'\d', word):
  173. left = info['rect']['left']
  174. top = info['rect']['top']
  175. width = info['rect']['width']
  176. if '年' in word:
  177. find_list = re.findall(pattern=format, string=word)
  178. else:
  179. find_list = re.findall(pattern=special_format, string=word)
  180. logger.info(f'word {word} has find_list{find_list}')
  181. # if self.check:
  182. # pdb.set_trace()
  183. if len(find_list) == 1:
  184. find_string = find_list[0]
  185. if '至' in find_string:
  186. start_prefix = find_string.split('至')[0].replace('自', '')
  187. end_prefix = find_string.split('至')[-1]
  188. if '年' in start_prefix:
  189. response_item['start_datetime'] = start_prefix
  190. if end_prefix != '':
  191. response_item['end_datetime'] = end_prefix
  192. return response_item
  193. # 不存在{至}的情况下通过位置和已有期限关键字来分配日期
  194. else:
  195. for k, k_info in keyword_dict.items():
  196. k_left = k_info['left']
  197. k_right = k_info['right']
  198. k_top = k_info['top']
  199. # 捕获关键字
  200. if left == k_left:
  201. if (k in priority_keywords) or (k in end_keywords) and response_item['end_datetime'] is None:
  202. response_item['end_datetime'] = find_string
  203. elif k in start_keywords and response_item['start_datetime'] is None:
  204. response_item['start_datetime'] = find_string
  205. break
  206. elif left >= k_right and top >= k_top:
  207. if (k in priority_keywords) or (k in end_keywords) and response_item['end_datetime'] is None:
  208. response_item['end_datetime'] = find_string
  209. elif k in start_keywords and response_item['start_datetime'] is None:
  210. response_item['start_datetime'] = find_string
  211. elif len(find_list) == 2:
  212. start_prefix = find_list[0].replace('自', '')
  213. end_prefix = find_list[-1].replace('至', '')
  214. if response_item['start_datetime'] is None:
  215. response_item['start_datetime'] = start_prefix
  216. if response_item['end_datetime'] is None:
  217. response_item['end_datetime'] = end_prefix
  218. else:
  219. logger.info(f'wrong word: {word} ...')
  220. else:
  221. continue
  222. return response_item
  223. # 找到营业执照中id与date信息
  224. def find_license_datetime(self, image_info, response_item):
  225. for info in image_info:
  226. word = info['word']
  227. # id
  228. if (word.startswith('证照编号:') and len(word) == 25) or (word.isdigit() and len(word) == 20):
  229. response_item['id'] = word if word.isdigit() else word[5:]
  230. elif bool(re.match(self.datetime_re, word)):
  231. split = word.split('至')
  232. start_datetime = split[0]
  233. end_datetime = split[-1]
  234. response_item['start_datetime'] = start_datetime
  235. response_item['end_datetime'] = end_datetime
  236. elif word == '长期':
  237. response_item['start_datetime'] = response_item['end_datetime'] = '长期'
  238. return response_item
  239. # 在image_info中搜寻word中包含key的内容
  240. def search(self, image_info, key):
  241. for info in image_info:
  242. word = info['word']
  243. if key in word:
  244. return True
  245. return False
  246. # 在image_info中使用re.search搜寻满足{format}正则的信息
  247. def re_search(self, image_info, format):
  248. for info in image_info:
  249. word = info['word']
  250. match = re.search(format, word)
  251. if match:
  252. return match.group(0)
  253. return False
  254. # 在image_info中使用re.match搜寻满足{format}正则的信息
  255. def re_match(self, image_info, format):
  256. for info in image_info:
  257. word = info['word']
  258. match = re.match(format, word)
  259. if match:
  260. return word
  261. return False
  262. # 用于识别固定位置是否有公司法人签名
  263. def signature_recognition(self, image_path: str):
  264. keywords = ['投标函', '(法定代表人CA电子印章)','(法定代表人CA电子印章或签字)', '(签字)', '法定代表人或其委托代理人:', '法定代表人:']
  265. key_pos = {}
  266. image_prefix = image_path.split('/')[0]
  267. image_name = image_path.split('/')[-1][:-4]
  268. removed_image_name = image_name + '_roi' + image_path.split('/')[-1][-4:]
  269. ink_image_name = image_name + '_ink' + image_path.split('/')[-1][-4:]
  270. removed_image_path = os.path.join(image_prefix, removed_image_name)
  271. ink_image_path = os.path.join(image_prefix, ink_image_name)
  272. if not os.path.exists(removed_image_path):
  273. removed_seal_img = self.remove_red_seal(image_path=image_path)
  274. cv2.imwrite(removed_image_name, removed_seal_img)
  275. else:
  276. removed_seal_img = cv2.imread(removed_image_path)
  277. content = self.get_content(image_path=removed_image_path)
  278. image_info = content["rawjson"]["ret"]
  279. for info in image_info:
  280. word = info['word']
  281. left = info['rect']['left']
  282. top = info['rect']['top']
  283. width = info['rect']['width']
  284. height = info['rect']['height']
  285. right = left + width
  286. bottom = top + height
  287. for keyword in keywords:
  288. if keyword in word:
  289. key_pos[keyword] = {
  290. "word": word,
  291. "left": left,
  292. "right": right,
  293. "top": top,
  294. "bottom": bottom
  295. }
  296. break
  297. # 如果不存在"投标函"、"法定代表人"等关键字,则返回False
  298. if len(key_pos) == 0:
  299. return False
  300. # 定位到法定代表人所在位置
  301. if ((key_pos.get('法定代表人:') is not None) or (key_pos.get('法定代表人或其委托代理人:') is not None)) and \
  302. ((key_pos.get('(法定代表人CA电子印章)') is not None) or (key_pos.get('(法定代表人CA电子印章或签字)') is not None) or (key_pos.get('(签字)') is not None)):
  303. if key_pos.get('法定代表人或其委托代理人:') is not None:
  304. l_info = key_pos['法定代表人或其委托代理人:']
  305. l_cnt = 13
  306. l_string = '法定代表人或其委托代理人:'
  307. else:
  308. l_info = key_pos['法定代表人:']
  309. l_cnt = 6
  310. l_string = '法定代表人:'
  311. if key_pos.get('(法定代表人CA电子印章)') is not None:
  312. r_info = key_pos['(法定代表人CA电子印章)']
  313. r_string = '(法定代表人CA电子印章)'
  314. elif key_pos.get('(法定代表人CA电子印章或签字)') is not None:
  315. r_info = key_pos['(法定代表人CA电子印章或签字)']
  316. r_string = '(法定代表人CA电子印章或签字)'
  317. else:
  318. r_info = key_pos['(签字)']
  319. r_string = '(签字)'
  320. # 此时签名应在两者之间
  321. l = l_info['right']
  322. l_word = l_info['word']
  323. r = r_info['left']
  324. r_word = r_info['word']
  325. t = max(l_info['top'], r_info['top'])
  326. b = min(l_info['bottom'], r_info['bottom']) - 5
  327. if l_word[-l_cnt:] != l_string or r_word != r_string:
  328. logger.info(l_word)
  329. logger.info(r_word)
  330. return True
  331. else:
  332. black_ratio = self.ink_recognition(
  333. input_img=removed_seal_img,
  334. out_path=ink_image_path,
  335. meta={
  336. "left": l,
  337. "right": r,
  338. "top": t,
  339. "bottom": b
  340. }
  341. )
  342. if black_ratio >= self.sign_threshold:
  343. return True
  344. return False
  345. elif (key_pos.get('(法定代表人CA电子印章)') is not None) or (key_pos.get('(法定代表人CA电子印章或签字)') is not None) or (key_pos.get('(签字)') is not None):
  346. # 此时签名应已包含
  347. if key_pos.get('(法定代表人CA电子印章)') is not None:
  348. key = key_pos['(法定代表人CA电子印章)']
  349. elif key_pos.get('(法定代表人CA电子印章或签字)') is not None:
  350. key = key_pos['(法定代表人CA电子印章或签字)']
  351. elif key_pos.get('(签字)') is not None:
  352. key = key_pos['(签字)']
  353. key_word = key['word']
  354. key_word = key_word.replace('(法定代表人CA电子印章)','').replace('(法定代表人CA电子印章或签字)', '').replace('(签字)','').replace('法定代表人或其委托代理人:', '').replace('法定代表人:', '')
  355. if key_word != '':
  356. logger.info(key_word)
  357. return True
  358. return False
  359. elif key_pos.get('法定代表人:') is not None:
  360. left = key_pos['法定代表人:']['left']
  361. right = key_pos['法定代表人:']['right']
  362. top = key_pos['法定代表人:']['top']
  363. bottom = key_pos['法定代表人:']['bottom']
  364. # cv2.rectangle(removed_seal_img, (left, top), (right, bottom), (255, 255, 0), 2) # 绿色框,线宽为2
  365. # 此时签名在右边或已包含
  366. word = key_pos['法定代表人:']['word']
  367. l = key_pos['法定代表人:']['right']
  368. r = l + 100
  369. t = key_pos['法定代表人:']['top']
  370. b = key_pos['法定代表人:']['bottom'] - 5
  371. if word[-6:] != '法定代表人:':
  372. logger.info(word)
  373. return True
  374. else:
  375. black_ratio = self.ink_recognition(
  376. input_img=removed_seal_img,
  377. out_path=ink_image_path,
  378. meta={
  379. "left": l,
  380. "right": r,
  381. "top": t,
  382. "bottom": b
  383. }
  384. )
  385. if black_ratio >= self.sign_threshold:
  386. return True
  387. return False
  388. elif key_pos.get('法定代表人或其委托代理人:') is not None:
  389. # 此时签名在右边或已包含
  390. word = key_pos['法定代表人或其委托代理人:']['word']
  391. l = key_pos['法定代表人或其委托代理人:']['right']
  392. r = l + 100
  393. t = key_pos['法定代表人或其委托代理人:']['top']
  394. b = key_pos['法定代表人或其委托代理人:']['bottom'] - 5
  395. if word[-13:] != '法定代表人或其委托代理人:':
  396. logger.info(word)
  397. return True
  398. else:
  399. black_ratio = self.ink_recognition(
  400. input_img=removed_seal_img,
  401. out_path=ink_image_path,
  402. meta={
  403. "left": l,
  404. "right": r,
  405. "top": t,
  406. "bottom": b
  407. }
  408. )
  409. if black_ratio >= self.sign_threshold:
  410. return True
  411. return False
  412. else:
  413. return False
  414. # 用于判断固定位置的长方形框内是否存在签名字迹
  415. def ink_recognition(self, input_img, out_path, meta: dict):
  416. left = meta["left"]
  417. right = meta["right"]
  418. top = meta["top"]
  419. bottom = meta["bottom"]
  420. crop_img = input_img[top:bottom, left:right, :]
  421. cv2.imwrite(out_path, crop_img)
  422. gray_img = cv2.cvtColor(crop_img, cv2.COLOR_BGR2GRAY)
  423. thresh, ret = cv2.threshold(gray_img, 0, 255, cv2.THRESH_OTSU)
  424. filter_condition = int(thresh * 0.90)
  425. _, black_thresh = cv2.threshold(gray_img, filter_condition, 255, cv2.THRESH_BINARY_INV)
  426. total_pixels = black_thresh.size
  427. black_pixels = np.count_nonzero(black_thresh)
  428. black_ratio = black_pixels / total_pixels
  429. return black_ratio
  430. '''
  431. # 由于水印遮挡导致ocr识别不准确会带来公司名称信息的缺失或冗余
  432. # 因此使用近似匹配,如果两者满足则判断为同一家公司
  433. def correct(self, string: str, firm_name: str):
  434. if '公司' not in string:
  435. return False
  436. # 尝试从string中匹配公司名称的部分
  437. format = r'.+有限(?:责任)?公司$'
  438. match_string = re.search(string, format)
  439. if not match_string:
  440. return False
  441. current_name = match_string.group(0).strip()
  442. invalid_prefixs = ['名称:', '名称:', '名称']
  443. # 匹配current_name中是否包含"名称"字样
  444. for invalid_prefix in invalid_prefixs:
  445. pos = current_name.find(invalid_prefix)
  446. if pos != -1:
  447. current_name = current_name[pos + len(invalid_prefix): ]
  448. break
  449. # 错识别返回false,漏识别返回true
  450. if current_name == firm_name:
  451. return True
  452. return False
  453. if len(current_name) < len(firm_name):
  454. lprefix = 0
  455. rprefix = 0
  456. transform = False
  457. for i in range(len(current_name)):
  458. c = current_name[i]
  459. word = firm_name[i]
  460. if c == word and not transform:
  461. lprefix += 1
  462. continue
  463. elif c == word and transform:
  464. rprefix += 1
  465. continue
  466. elif c != word and not transform:
  467. transform = True
  468. elif c != word and transform and rprefix > 0:
  469. return False
  470. return True
  471. '''
  472. # seal ocr外部接口,用于提取页面中的印章信息
  473. # 集成签名判断函数
  474. class seal_agent():
  475. def __init__(self,
  476. base_url: str,
  477. access_token: str,
  478. headers: dict,
  479. ):
  480. self.base_url = base_url
  481. self.access_token = access_token
  482. self.headers = headers
  483. self.request_url = base_url + access_token
  484. def seal_recognition(self, img_path):
  485. f = open(img_path, 'rb')
  486. img = base64.b64encode(f.read())
  487. params = {"image":img}
  488. response = requests.post(self.request_url, data=params, headers=self.headers)
  489. if response:
  490. data = response.json()
  491. else:
  492. data = {}
  493. return data
  494. class PdfMatcher(PdfExtractAttr):
  495. '''pdf匹配'''
  496. def __init__(self, file_path: str):
  497. super(PdfMatcher, self).__init__(
  498. file_path=file_path
  499. )
  500. # 投标书名称
  501. self.bid_name = file_path.split('/')[-1][:-4]
  502. # 投标书数据文件夹
  503. self.bid_dir = os.path.join(os.path.dirname(file_path), self.bid_name)
  504. # 公司名称
  505. self.firm_name = file_path.split('/')[-2]
  506. # title list
  507. title_path = os.path.join(self.bid_dir, "title.json")
  508. self.title = load_json(title_path)
  509. # outline list
  510. outline_path = os.path.join(self.bid_dir, "outlines.json")
  511. self.outline = self.parse_outline(out_path=outline_path)
  512. # text list
  513. text_path = os.path.join(self.bid_dir, "all_texts.json")
  514. self.details = self.parse_text(out_path=text_path)
  515. # table list
  516. table_path = os.path.join(self.bid_dir, "all_tables.json")
  517. if os.path.exists(table_path):
  518. self.table = load_json(table_path)
  519. else:
  520. self.tables = self.parse_table(out_path=table_path)
  521. # image list
  522. self.image_dir = os.path.join(self.bid_dir, "extracted_images")
  523. # image format
  524. self.image_format = "image_page_{}*"
  525. # image filter threshold
  526. self.start_threshold = 10
  527. self.distance_threshold = 6
  528. self.search_threshold = 20
  529. self.match_threshold = 44.0
  530. self.degrade_threshold = 42.0
  531. def search_interval(self):
  532. '''定位营业执照、资质证书的区间范围'''
  533. # 通过关键字模糊定位
  534. keywords = ['资格审查资料','资格审查材料','其它材料','其他材料','其他资料','附件', '影印件']
  535. search_interval = []
  536. # locate in title.json
  537. left_pos = -1 # 左指针
  538. right_pos = -1 # 右指针
  539. for title_block in self.title:
  540. block_text = title_block['text'].replace(' ', '').strip()
  541. # 先进行左区间判定
  542. if left_pos != -1 and '证书' not in block_text:
  543. right_pos = title_block['page_number']
  544. search_interval.append((left_pos, right_pos))
  545. # 重置
  546. left_pos = -1
  547. for keyword in keywords:
  548. if keyword in block_text:
  549. # 先进行模糊的outline定位
  550. center_page = None
  551. if '.' in block_text:
  552. center_page = block_text.split('.')[-1]
  553. if center_page.isdigit():
  554. center_page = eval(center_page)
  555. left_pos = min(title_block['page_number'], center_page)
  556. else:
  557. left_pos = title_block['page_number']
  558. # 最终判定
  559. if left_pos != -1:
  560. search_interval.append((left_pos, right_pos))
  561. # 重置
  562. left_pos = -1
  563. right_pos = -1
  564. # locate in outlines.json
  565. if len(self.outline) > 0:
  566. for outline_block in self.outline:
  567. if left_pos != -1:
  568. right_pos = outline_block["page_number"]
  569. right_pos = right_pos if right_pos is not None else -1
  570. search_interval.append((left_pos, right_pos))
  571. left_pos = -1
  572. outline_text = outline_block['title'].strip()
  573. for keyword in keywords:
  574. if keyword in outline_text:
  575. if outline_block["page_number"] is not None:
  576. left_pos = outline_block["page_number"]
  577. # 最终判定
  578. if left_pos != -1:
  579. search_interval.append((left_pos, right_pos))
  580. # 搜寻区间合并
  581. search_interval.sort()
  582. merge_interval = []
  583. if len(search_interval) > 0:
  584. left = -1
  585. right = -1
  586. for interval in search_interval:
  587. l, r = interval
  588. if r < l:
  589. continue
  590. if left == -1 and right == -1:
  591. left = l
  592. right = r
  593. elif l <= right:
  594. right = r
  595. else:
  596. merge_interval.append((left, right))
  597. left = l
  598. right = r
  599. merge_interval.append((left, right))
  600. return merge_interval
  601. def find_candidate_images(self):
  602. candidate_images = set()
  603. merge_intervals = self.search_interval()
  604. for interval in merge_intervals:
  605. start_page, end_page = interval
  606. if start_page <= self.start_threshold:
  607. continue
  608. if end_page == -1:
  609. end_page = start_page + 20
  610. candidate_images = self.image_regularization(start_page=max(0, start_page-self.search_threshold), end_page=end_page+self.search_threshold, candidate_images=candidate_images)
  611. candidate_images = list(candidate_images)
  612. return candidate_images
  613. # 定位营业执照图像
  614. def locate_business_license(self):
  615. '''locate business license and return image'''
  616. keywords = ["资格审查资料", "其它资格审查材料", "资格审查材料"]
  617. candidate_pages = []
  618. center_pages = []
  619. candidate_images = set()
  620. # locate in title.json
  621. for title_block in self.title:
  622. block_text = title_block['text'].replace(' ', '').strip()
  623. for keyword in keywords:
  624. if keyword in block_text:
  625. # 先进行模糊的outline定位
  626. center_page = None
  627. if '.' in block_text:
  628. center_page = block_text.split('.')[-1]
  629. if center_page.isdigit():
  630. center_page = eval(center_page)
  631. center_pages.append(center_page)
  632. candidate_pages.append(title_block['page_number'])
  633. # locate in outlines.json
  634. if len(self.outline) > 0:
  635. for outline_block in self.outline:
  636. outline_text = outline_block['title'].strip()
  637. for keyword in keywords:
  638. if keyword in outline_text:
  639. center_pages.append(outline_block["page_number"])
  640. # information match
  641. filter_pages = set()
  642. if len(center_pages) == 0 and len(candidate_pages) == 0:
  643. return None
  644. elif len(center_pages) == 0:
  645. filter_pages.update(candidate_pages)
  646. elif len(candidate_pages) == 0:
  647. filter_pages.update(center_pages)
  648. else:
  649. # center_pages作为锚点,全部加入
  650. filter_pages.update(center_pages)
  651. # candidate_page与center_page进行匹配加入
  652. for candidate_page in candidate_pages:
  653. if candidate_page <= self.start_threshold:
  654. continue
  655. for center_page in center_pages:
  656. distance = abs(candidate_page - center_page)
  657. if distance <= self.distance_threshold:
  658. filter_pages.add(min(candidate_page, center_page) + distance // 2)
  659. # 得到筛选后的图片集存储于self.candidate_images
  660. for filter_page in filter_pages:
  661. # candidate_images = self.image_regularization(candidate_images=candidate_images, start_page=max(filter_page-self.search_threshold, 0), end_page=filter_page+self.search_threshold)
  662. candidate_images = self.image_regularization(start_page=max(filter_page-self.search_threshold, 0), end_page=filter_page+self.search_threshold, candidate_images=candidate_images)
  663. # 获取最终图像的地址
  664. candidate_images = list(candidate_images)
  665. target_list = self.exact_match(candidate_images=candidate_images)
  666. # return target_path list
  667. return target_list
  668. # 定位资质证书
  669. def locate_qualtification_certificate(self):
  670. '''返回资质证书的图像列表'''
  671. # 通过关键字模糊定位
  672. keywords = ['资格审查资料','资格审查材料','其它材料','其他材料','影印件']
  673. search_interval = []
  674. candidate_images = set()
  675. # locate in title.json
  676. left_pos = -1 # 左指针
  677. right_pos = -1 # 右指针
  678. for title_block in self.title:
  679. block_text = title_block['text'].replace(' ', '').strip()
  680. # 先进行左区间判定
  681. if left_pos != -1 and '证书' not in block_text:
  682. right_pos = title_block['page_number']
  683. search_interval.append((left_pos, right_pos))
  684. # 重置
  685. left_pos = -1
  686. for keyword in keywords:
  687. if keyword in block_text:
  688. # 先进行模糊的outline定位
  689. center_page = None
  690. if '.' in block_text:
  691. center_page = block_text.split('.')[-1]
  692. if center_page.isdigit():
  693. center_page = eval(center_page)
  694. left_pos = min(title_block['page_number'], center_page)
  695. else:
  696. left_pos = title_block['page_number']
  697. # 最终判定
  698. if left_pos != -1:
  699. search_interval.append((left_pos, right_pos))
  700. # 重置
  701. left_pos = -1
  702. right_pos = -1
  703. # locate in outlines.json
  704. if len(self.outline) > 0:
  705. for outline_block in self.outline:
  706. if left_pos != -1:
  707. right_pos = outline_block["page_number"]
  708. right_pos = right_pos if right_pos is not None else -1
  709. search_interval.append((left_pos, right_pos))
  710. left_pos = -1
  711. outline_text = outline_block['title'].strip()
  712. for keyword in keywords:
  713. if keyword in outline_text:
  714. if outline_block["page_number"] is not None:
  715. left_pos = outline_block["page_number"]
  716. # 最终判定
  717. if left_pos != -1:
  718. search_interval.append((left_pos, right_pos))
  719. # 搜寻区间合并
  720. search_interval.sort()
  721. merge_interval = []
  722. if len(search_interval) > 0:
  723. left = -1
  724. right = -1
  725. for interval in search_interval:
  726. l, r = interval
  727. if r < l:
  728. continue
  729. if left == -1 and right == -1:
  730. left = l
  731. right = r
  732. elif l <= right:
  733. right = r
  734. else:
  735. merge_interval.append((left, right))
  736. left = l
  737. right = r
  738. merge_interval.append((left, right))
  739. for interval in merge_interval:
  740. start_page, end_page = interval
  741. if end_page == -1:
  742. end_page = start_page + 20
  743. if start_page <= self.start_threshold:
  744. continue
  745. candidate_images = self.image_regularization(start_page=max(0, start_page-self.search_threshold), end_page=end_page+self.search_threshold, candidate_images=candidate_images)
  746. candidate_images = list(candidate_images)
  747. target_list = self.search_qualtification_certificate(candidate_images=candidate_images)
  748. return target_list
  749. # 查询符合格式的图像
  750. def image_regularization(self, start_page: int, end_page:int, candidate_images: set):
  751. for index in range(start_page, end_page + 1):
  752. current_format = self.image_format.format(index)
  753. files = glob.glob(os.path.join(self.image_dir, current_format))
  754. # cut_files = list(map(lambda x: x.split('/')[-1], files))
  755. # filter_files = [file for file in cut_files if not file.endswith('.unk')]
  756. filter_files = [file for file in files if not file.endswith('.unk')]
  757. candidate_images.update(filter_files)
  758. return candidate_images
  759. def exact_match(self, candidate_images: list):
  760. '''精确匹配营业执照位置'''
  761. if len(candidate_images) == 0:
  762. return None
  763. target_list = []
  764. sim_list = []
  765. for image_path in candidate_images:
  766. score = self.get_similarity(image_path=image_path, tamplate=self.bl_tamplate)
  767. sim_list.append(score.cpu().numpy())
  768. # top-k > match_threshold
  769. sim_list = np.array(sim_list).reshape(len(sim_list))
  770. for i, cos_sim in enumerate(sim_list):
  771. if cos_sim > self.match_threshold:
  772. target_list.append(candidate_images[i])
  773. # 未找寻到符合当前阈值要求的图像,降低阈值
  774. if len(target_list) == 0:
  775. for i, cos_sim in enumerate(sim_list):
  776. if cos_sim > self.degrade_threshold:
  777. target_list.append(candidate_images[i])
  778. return target_list
  779. def search_qualtification_certificate(self, candidate_images: list):
  780. '''从candidate images中搜寻是否有符合资质证书的图像'''
  781. if len(candidate_images) == 0:
  782. return None
  783. target_list = []
  784. sim_list = []
  785. for image_path in candidate_images:
  786. score = self.get_similarity(image_path=image_path, tamplate=self.qc_tamplate)
  787. sim_list.append(score.cpu().numpy())
  788. sim_list = np.array(sim_list).reshape(len(sim_list))
  789. for i, cos_sim in enumerate(sim_list):
  790. if cos_sim > self.qc_threshold:
  791. target_list.append(candidate_images[i])
  792. return target_list
  793. def get_similarity(self, image_path, tamplate):
  794. image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
  795. text = clip.tokenize([tamplate]).to(device)
  796. with torch.no_grad():
  797. logits_per_image, logits_per_text = model.get_similarity(image, text)
  798. return logits_per_image
if __name__ == '__main__':
    # Ad-hoc driver: runs signature detection on every file of a test image
    # folder and logs the result.
    start_time = time.time()
    # Endpoint of the OCR service handed to OcrAgent.
    url = "http://120.48.103.13:18000/ctr_ocr"
    # Seal-recognition endpoint; seal_agent appends the access token to it.
    base_url = "https://aip.baidubce.com/rest/2.0/ocr/v1/seal?access_token="
    # NOTE(review): credential hard-coded in source — consider loading it from
    # configuration or the environment instead.
    access_token = "24.6bbe9987c6bd19ba65e4402917811657.2592000.1724573148.282335-86574608"
    headers = {'content-type': 'application/x-www-form-urlencoded'}
    # Paths below are assigned but not used further in this script.
    data_path = "/home/stf/miner_pdf/data/投标公司pdf"
    out_path = "/home/stf/miner_pdf/test.json"
    ground_truth = "/home/stf/miner_pdf/ground_truth.json"
    # Spreadsheet listing the companies and whether their PDF is a scan.
    firm_excel_file = "/home/stf/miner_pdf/data/certificate.xlsx"
    df = pd.read_excel(firm_excel_file)
    ocr = OcrAgent(url=url)
    seal_ocr = seal_agent(base_url=base_url, access_token=access_token, headers=headers)
    # Company names split by the value of the '是否为扫描件' column ('否' / '是').
    unscanned_firm_list = df[(df['是否为扫描件'] == '否')]['公司名称'].tolist()
    scanned_firm_list = df[(df['是否为扫描件'] == '是')]['公司名称'].tolist()
    all_firm_list = unscanned_firm_list + scanned_firm_list
    data = {}
    # Timer restarted here; the value assigned at the top is overwritten.
    start_time = time.time()
    test_img_dir = "/home/stf/miner_pdf/test_img"
    # Every directory entry is processed; no filtering by file type is done.
    for img in tqdm(os.listdir(test_img_dir)):
        logger.info(f'processing {img} ...')
        img_path = os.path.join(test_img_dir, img)
        result = ocr.signature_recognition(img_path)
        logger.info(f'识别结果: {result}')