ソースを参照

merge zzh branch

sprivacy 1 年間 前
コミット
48b9a3e9b3
6 ファイル変更540 行追加0 行削除
  1. 158 0
      extract_financial_report.py
  2. 70 0
      extract_price.py
  3. 106 0
      instance_locate.py
  4. 42 0
      ocr_api.py
  5. 28 0
      scan_dir.py
  6. 136 0
      text_extractor.py

+ 158 - 0
extract_financial_report.py

@@ -0,0 +1,158 @@
+import os
+import re
+import time
+from re import match
+
+from tqdm import tqdm
+from scan_dir import scan_dir
+from instance_locate import get_instances_by_title
+from ocr_api import OcrAgent, find_current_row
+import datetime
+
+
def is_price(word: str) -> bool:
    """Return True if *word* looks like a monetary amount.

    Two forms are accepted:
    1. a currency code/symbol (USD, CNY, $, etc.) followed by a formatted
       number, matched by the regex below;
    2. a bare number made only of digits, commas and dots.

    Bug fixed: the original fallback also accepted the empty string and
    punctuation-only tokens such as "," or "." as prices; at least one
    digit is now required.
    """
    pattern = (
        r"(?:\b(?:[BS]/\.|R(?:D?\$|p))|\b(?:[TN]T|[CJZ])\$|Дин\.|\b(?:Bs|Ft|Gs|K[Mč]|Lek|B[Zr]|k[nr]|[PQLSR]|лв"
        r"|ден|RM|MT|lei|zł|USD|GBP|EUR|JPY|CHF|SEK|DKK|NOK|SGD|HKD|AUD|TWD|NZD|CNY|KRW|INR|CAD|VEF|EGP|THB|IDR"
        r"|PKR|MYR|PHP|MXN|VND|CZK|HUF|PLN|TRY|ZAR|ILS|ARS|CLP|BRL|RUB|QAR|AED|COP|PEN|CNH|KWD|SAR)|\$[Ub]|"
        r"[^\w\s])\s?(?:\d{1,3}(?:,\d{3})*|\d+)(?:\.\d{1,2})?(?!\.?\d)"
    )
    if re.fullmatch(pattern, word):
        return True
    # Bare-number fallback: every character must be a digit or a separator,
    # and at least one real digit must be present (rejects "", ",", ".").
    numeric_chars = set('1234567890,.')
    return bool(word) and all(c in numeric_chars for c in word) and any(c.isdigit() for c in word)
+
+
def extract_financial_report(path: str, year: int = None):
    """Extract net-profit figures from the financial chapters of a bid PDF.

    Looks up the chapters titled 财务状况 (financial status) and the audit
    reports of the two years preceding *year*.  Chapters with parsed tables
    contribute their profit rows directly; otherwise the chapter's page
    images are sent to a remote OCR service and rows containing 净利润
    (net profit) are reconstructed from the OCR box geometry.

    :param path: path to the PDF file.
    :param year: reference year; defaults to the current calendar year.
    :return: list of dicts with chapter/title/page(s)/result entries.
    """
    if year is None:
        # Fix: the original default of None crashed on `year - 1`; fall back
        # to the current calendar year.
        year = datetime.datetime.now().year
    instances = get_instances_by_title(
        path, ['财务状况', '{}年审计报告'.format(year - 1), '{}年审计报告'.format(year - 2)])
    results = []
    ocr_agent = OcrAgent("http://120.48.103.13:18000/ctr_ocr")
    for item in instances:
        if item['tables']:
            # Structured tables are available: keep every row that mentions
            # 利润 (profit).
            table_name = [t['table_name'] for t in item['tables']]
            profits = []
            for table in item['tables']:
                profit = [row for row in table['table']
                          if any(match(r'.*利润.*', cell) for cell in row)]
                profits.append(profit)
            results.append({
                'title': table_name,
                'result': profits,
                'pages': [i['page_numbers'] for i in item['tables']],
                'chapter': item['title']
            })
        elif item['page_number'] >= item['end_page']:
            print('Wrong titles extracted at {}'.format(item['title']))
        else:
            # No parsed tables: fall back to OCR over the chapter's exported
            # page images (file names encode the page number as field 2).
            images = list(filter(
                lambda x: (item['page_number'] <= int(x.split('_')[2]) <= item['end_page'])
                          and (x.endswith('.jpg') or x.endswith('.png'))
                          and os.path.isfile(os.path.join(item['image_loc'], x)),
                os.listdir(item['image_loc']))
            )
            print('未找到表格 图片识别中')
            ocr_results = [ocr_agent.get_content(os.path.join(item['image_loc'], i))['rawjson']['ret']
                           for i in tqdm(images)]
            candidate = []
            rows = []
            print('结果分析中')
            for i, ret in tqdm(enumerate(ocr_results)):
                for res in ret:
                    if re.match(r'.*(净利润).*', res['word']) is not None:
                        # OCR rectangles grow downwards: bottom = top - height.
                        top = res['rect']['top']
                        bottom = res['rect']['top'] - res['rect']['height']
                        candidate.append(
                            {
                                'page': images[i],
                                'text': res['word'],
                                'top': top,
                                'bottom': bottom,
                            }
                        )
                        rows.append(find_current_row(ret, top, bottom))
            for it in candidate:
                print('定位:\t{}\t定位词:\t{}'.format(it['page'], it['text']))

            # Split each located row into label cells and price cells; keep
            # only rows that contain both.
            for i, row in enumerate(rows):
                title = []
                profits = []
                for w in row:
                    if is_price(w['word']):
                        profits.append(w['word'])
                    else:
                        title.append(w['word'])
                if title and profits:
                    results.append({
                        'chapter': item['title'],
                        'page': candidate[i]['page'],
                        'title': title,
                        'result': profits
                    })

    return results
+
+
if __name__ == '__main__':
    # Manual smoke test: run the extractor against a local bid PDF.
    # print(extract_financial_report('./投标文件-修改版9-5-1-1.pdf'))

    # Forbid model downloads at runtime -- presumably for the embedding-based
    # title matcher used downstream (TODO confirm against matcher.Matcher).
    os.environ["TRANSFORMERS_OFFLINE"] = '1'

    # NOTE(review): `y` is computed but the call below still hard-codes 2022.
    y = datetime.datetime.now().year
    print(extract_financial_report(
        '/home/zzh/ocr/pdf/美华建设有限公司/投标文件111.pdf',
        # '/home/zzh/ocr/pdf/南方电网数字电网研究院有限公司/南方电网数字研究院有限公司.pdf',
        # '/home/zzh/pdf_title_image/投标文件-修改版9-5-1-1.pdf',
        2022
    ))
    # Batch variant kept for reference: scan a directory tree for PDFs and
    # run the extractor over each one, timing the whole run.
    # start = time.time()
    # fs = scan_dir('/home/zzh/ocr/pdf/', 'pdf')
    #
    # for f in fs:
    #     try:
    #         print(f)
    #         print(extract_financial_report(f, 2022))
    #         print('\n*********Runtime {} s*********\n'.format(time.time() - start))
    #     except:
    #         print('Something wrong')
    #
    # print('\n\n{}'.format(time.time() - start))

+ 70 - 0
extract_price.py

@@ -0,0 +1,70 @@
+from re import findall
+from typing import List
+
+from text_extractor import get_instance
+
+
def rmb_to_digit(rmb_str):
    """Convert an RMB amount written in Chinese financial numerals
    (e.g. 壹仟肆佰贰拾万...) into a decimal string with two fraction digits.

    Walks the string once, pairing each numeral with the unit that follows
    it; a unit seen with no pending numeral (the section markers 万 / 亿)
    scales the running total instead.  Characters outside both maps (such
    as 整) carry the current numeral over unchanged, matching the original
    implementation.
    """
    numerals = {'零': 0, '壹': 1, '贰': 2, '叁': 3, '肆': 4, '伍': 5, '陆': 6, '柒': 7, '捌': 8, '玖': 9}
    units = {'分': 0.01, '角': 0.1, '元': 1, '拾': 10, '佰': 100, '仟': 1000, '万': 10000, '亿': 100000000}

    total = 0
    current = 0   # last numeral seen, not yet applied to a unit
    carried = 0   # value carried over past an unrecognised character
    for ch in rmb_str:
        if ch in numerals:
            current = numerals[ch]
        elif ch in units:
            if current + carried:
                total += (carried + current) * units[ch]
                carried = current = 0
            else:
                # Section marker with no pending digit scales the whole total.
                total *= units[ch]
        else:
            carried = current
    return '{:.2f}'.format(total + carried + current)
+
+
def match_price_zhs(text: str) -> List[str]:
    """Find amounts written in Chinese financial numerals (大写金额).

    Bug fixed: the original character classes were written with literal
    commas between the characters (``[壹,贰,...]``), so a plain "," inside
    the text was matched as part of an amount; the commas are not part of
    the numeral alphabet and have been removed.  A spurious trailing 元
    after 整/角/分/元 (a common OCR artefact) is stripped from each match.
    """
    pattern = (r"[壹贰叁肆伍陆柒捌玖拾佰仟]"
               r"[壹贰叁肆伍陆柒捌玖拾佰仟元角万分百整零]+"
               r"[壹贰叁肆伍陆柒捌玖拾佰仟元角万分百整零]")
    matches = findall(pattern, text)
    for i in range(len(matches)):
        if matches[i].endswith(('整元', '角元', '分元', '元元')):
            matches[i] = matches[i][:-1]
    return matches
+
+
def match_price_num(text: str) -> List[str]:
    """Find numeric amounts prefixed by a currency code or symbol.

    The currency part covers ISO-style codes (USD, CNY, ...), regional
    symbols, and any single non-word non-space character (e.g. ¥, $); the
    amount part accepts thousands separators and up to two decimals.
    """
    currency_part = (
        r"(?:\b(?:[BS]/\.|R(?:D?\$|p))|\b(?:[TN]T|[CJZ])\$|Дин\.|\b(?:Bs|Ft|Gs|K[Mč]|Lek|B[Zr]|k[nr]|[PQLSR]|лв|"
        r"ден|RM|MT|lei|zł|USD|GBP|EUR|JPY|CHF|SEK|DKK|NOK|SGD|HKD|AUD|TWD|NZD|CNY|KRW|INR|CAD|VEF|EGP|THB|IDR|"
        r"PKR|MYR|PHP|MXN|VND|CZK|HUF|PLN|TRY|ZAR|ILS|ARS|CLP|BRL|RUB|QAR|AED|COP|PEN|CNH|KWD|SAR)|\$[Ub]|"
        r"[^\w\s])"
    )
    amount_part = r"\s?(?:\d{1,3}(?:,\d{3})*|\d+)(?:\.\d{1,2})?(?!\.?\d)"
    return findall(currency_part + amount_part, text)
+
+
def match_duration(text: str) -> List[str]:
    r"""Find construction durations such as "90日历天" (calendar days).

    Bug fixed: the original pattern ``[1-9]+[\d]`` required at least two
    digits, so single-digit durations like "5日历天" were never matched.
    The number must not start with 0.
    """
    pattern = r"[1-9]\d*日历天"
    return findall(pattern, text)
+
+
def match_quality(text: str) -> List[str]:
    """Return every "工程质量..." (project quality) clause, each running
    from the keyword to the end of its line (``.`` stops at a newline)."""
    return findall(r"工程质量.+", text)
+
+
if __name__ == '__main__':
    # Manual end-to-end check against a sample bid PDF: pull the total bid
    # price (in Chinese numerals and in digits), the construction duration
    # and the quality requirement from the tender-letter chapters.
    price_zhs = get_instance(['投标函', '开标一览表'], ['人民币投标总报价'],
                             '/Users/zelate/Codes/pvas/pdf_title_image/投标文件-修改版9-5-1-1.pdf',
                             match_price_zhs)
    price_num = get_instance(['投标函', '开标一览表'], ['人民币投标总报价'],
                             '/Users/zelate/Codes/pvas/pdf_title_image/投标文件-修改版9-5-1-1.pdf',
                             match_price_num)
    duration = get_instance(['投标函', '开标一览表'], ['工期日历天'],
                            '/Users/zelate/Codes/pvas/pdf_title_image/投标文件-修改版9-5-1-1.pdf',
                            match_duration)
    quality = get_instance(['投标函', '开标一览表'], ['工程质量'],
                           '/Users/zelate/Codes/pvas/pdf_title_image/投标文件-修改版9-5-1-1.pdf',
                           match_quality)
    # Consistency check: the Chinese-numeral amount converted to digits must
    # equal the numeric amount with its leading currency symbol stripped.
    valid = rmb_to_digit(price_zhs[0][0][0]) == price_num[0][0][0][1:]

+ 106 - 0
instance_locate.py

@@ -0,0 +1,106 @@
+from typing import List
+from pdfminer.high_level import extract_pages
+from pdfminer.layout import LTFigure, LTImage, LTTextBoxHorizontal
+from pprint import pprint
+
+from tqdm import tqdm
+
+from text_extractor import similarity_filter, similar_match, parse_title
+from get_info import PdfExtractAttr, export_image
+import os
+import json
+
+os.environ['TRANSFORMERS_OFFLINE'] = '1'
+
+
def parse_pages(pdf_path: str, text_path: str, image_dir: str, start_page: int, end_page: int, total_page: int) -> None:
    """Walk the pages of *pdf_path* between *start_page* and *end_page*
    (inclusive), export embedded images into *image_dir* and dump every
    single-line horizontal text box to *text_path* as JSON.

    :param pdf_path: PDF to parse.
    :param text_path: output JSON path for the collected text boxes.
    :param image_dir: directory receiving exported page images.
    :param start_page: first page index to process (0-based).
    :param end_page: last page index to process; may be float('inf').
    :param total_page: total page count, used only for the progress bar.
    """
    texts = []
    for page_number, page_layout in tqdm(enumerate(extract_pages(pdf_path)), total=total_page):
        if page_number < start_page:
            continue
        if page_number > end_page:
            # Fix: the original kept laying out every remaining page of the
            # document; extract_pages yields pages in order, so stop here.
            break
        title_index = 0
        image_index = 0
        for element in page_layout:
            if isinstance(element, LTFigure):
                for e_obj in element._objs:
                    if isinstance(e_obj, LTImage):
                        # Export the embedded image; export_image appears to
                        # return the final file name -- unused here.
                        image_file = os.path.join(image_dir, f'image_page_{page_number}_{image_index}')
                        export_image(e_obj, image_file)
                        image_index += 1
            elif isinstance(element, LTTextBoxHorizontal) and len(element._objs) == 1:
                # Single-line horizontal text boxes are candidate titles/rows.
                text = element.get_text().strip()
                texts.append({'index': title_index, 'page_number': page_number, 'bbox': element.bbox, 'text': text})
                title_index += 1
    # Persist the collected text boxes as a JSON file.
    with open(text_path, 'w', encoding='utf-8') as fp:
        json.dump(texts, fp, indent=4, ensure_ascii=False)
+
+
def get_instances_by_title(path: str, instances: List[str]) -> List[dict]:
    """Collect tables and page images for every chapter whose title matches
    one of *instances*.

    Candidate titles extracted from the PDF are matched against *instances*
    by embedding similarity (threshold 0.5).  Each accepted chapter is taken
    to run until the page before the next candidate title; its pages are
    parsed for images/texts, and its tables are both attached to the result
    and dumped to JSON files next to the PDF.

    :param path: PDF file path.
    :param instances: chapter titles to look for.
    :return: matched title dicts enriched with 'end_page', 'tables',
             'table_loc' and 'image_loc'.
    """

    # path = './投标文件-修改版9-5-1-1.pdf'
    # instances = ['近年财务状况表']
    file = PdfExtractAttr(file_path=path)
    print('解析PDF文字中')
    file.parse_text()
    # title = file.parse_outline()
    print('解析PDF标题中')
    all_title = parse_title(path)
    # all_text = file.parse_text()  # remain for external parse

    print('分析标题中')
    title_sims = similarity_filter(similar_match(all_title, instances, key='title'), 0.5)
    title_f = [i for i in title_sims]
    results = []
    for i in title_f:
        try:
            # A chapter ends on the page before the next detected title.
            i['end_page'] = all_title[i['seq_num'] + 1]['page_number'] - 1
            if i['end_page'] <= i['page_number']:
                # Next title starts on the same (or an earlier) page: the
                # chapter has no pages of its own -- skip it.
                continue
            # i['end_page'] = all_title[i['seq_num']]['page_number'] + 5  # for debug
        except IndexError:
            # Last title in the document: the chapter runs to the end.
            i['end_page'] = float('inf')

        image_loc = os.path.join(os.path.dirname(path), 'images')
        if not os.path.exists(image_loc):
            os.makedirs(image_loc)
        print('解析标题:\t{}'.format(i['title']))
        print('解析图片中')
        parse_pages(path, os.path.join(os.path.dirname(path),
                                       '{}_texts_{}_{}.json'.format(i['title'], i['page_number'], i['index'])),
                    image_loc, i['page_number'], i['end_page'], file.total_page)

        table_loc = os.path.join(os.path.dirname(path),
                                 '{}_tables_{}_{}.json'.format(i['title'], i['page_number'], i['index']))
        print('解析表格中')
        tables = file.parse_table(start=i['page_number'], end=i['end_page'])
        i['tables'] = tables
        with open(table_loc, 'w', encoding='utf-8') as fp:
            json.dump(tables, fp, indent=4, ensure_ascii=False)
        i.update({'table_loc': table_loc, 'image_loc': image_loc})
        results.append(i)

    return results
+
+
+'''
+大标题 outlines
+小标题 text
+表/图
+1. 文字 + 表格(取第一行为标题)
+2. 文字 + 图片(取第一行为标题)
+3. 纯图片、表格(向上合并)
+'''

+ 42 - 0
ocr_api.py

@@ -0,0 +1,42 @@
+# ocr外部接口
+import os
+from typing import List
+
+from requests import post
+
+
class OcrAgent:
    """Thin HTTP client for the remote CTR-OCR service."""

    def __init__(self, url):
        # Full URL of the OCR endpoint, e.g. "http://host:port/ctr_ocr".
        self.url = url

    def get_content(self, image_path, timeout=None):
        """POST the image at *image_path* to the OCR service and return the
        decoded JSON response.

        :param image_path: path of the image file to recognise.
        :param timeout: optional requests timeout in seconds; None keeps the
            original wait-forever behaviour.
        :raises ValueError: if the response body is not valid JSON
            (historically attributed to a corrupted input image).
        """
        try:
            with open(image_path, 'rb') as image_file:
                # The service only inspects the payload, so a fixed field
                # name and content type are used regardless of real format.
                files = {"file": ("image.jpg", image_file, "image/jpeg")}
                # files = {"file": ("image.png", image_file, "image/png")}
                response = post(self.url, files=files, timeout=timeout)
            return response.json()
        except ValueError as err:
            # Fix: chain the original decoding error for easier debugging.
            raise ValueError(f"传入图像{image_path}已损坏") from err
+
+
def find_current_row(ocr_result: List[dict], top: int, bottom: int, float_range: int = 5):
    """Return every OCR box that lies on the same text row as the vertical
    band [bottom, top], widened by *float_range* units on each side.

    Boxes are rectangles where the bottom edge is ``top - height``.

    :param ocr_result: OCR boxes, each with rect['top'] and rect['height'].
    :param top: top edge of the anchor word.
    :param bottom: bottom edge of the anchor word.
    :param float_range: tolerance added above and below the band.
    :raises ValueError: if *float_range* is negative.
    """
    if float_range < 0:
        # Fix: was `assert float_range >= 0`, which is stripped under -O.
        raise ValueError('float_range must be non-negative')
    top += float_range
    bottom -= float_range
    results = []
    for box in ocr_result:
        box_top = box['rect']['top']
        box_bottom = box_top - box['rect']['height']
        # Keep boxes fully contained in the widened band.
        if top >= box_top > box_bottom >= bottom:
            results.append(box)
    return results
+
+
+
+
if __name__ == '__main__':
    # Manual smoke test against the deployed OCR service; requires network
    # access and the referenced image to exist on disk.
    agent = OcrAgent("http://120.48.103.13:18000/ctr_ocr")
    res = agent.get_content(
        os.path.join('/home/zzh/ocr/pdf/南方电网数字电网研究院有限公司/images', 'image_page_1131_0.png'))
    pass

+ 28 - 0
scan_dir.py

@@ -0,0 +1,28 @@
+import os
+from typing import List
+
+
def scan_dir(path, suffix: str = None):
    """Recursively list files under *path*, optionally filtered by suffix.

    Rewritten over hand-rolled recursion to use os.walk; followlinks=True
    mirrors the original behaviour, which recursed into symlinked
    directories via os.path.isdir().

    :param path: root directory to scan.
    :param suffix: if given, only file names ending in this suffix are
        returned (e.g. 'pdf'); None returns every file.
    :return: list of full file paths.
    """
    results = []
    for root, _dirs, files in os.walk(path, followlinks=True):
        for name in files:
            if suffix is None or name.endswith(suffix):
                results.append(os.path.join(root, name))
    return results
+
+
def batch_ln(files: List[str], target: str):
    """Symlink every regular file in *files* into directory *target*.

    Fixes over the original ``os.system('ln -s {} {}')`` implementation:
    paths containing spaces or shell metacharacters no longer break the
    command (or inject shell code), and no subprocess is spawned per file.
    Failures (e.g. the link already exists) are ignored, matching the old
    best-effort behaviour where `ln`'s exit status was discarded.
    """
    for f in files:
        if os.path.isfile(f):
            link_path = os.path.join(target, os.path.basename(f))
            try:
                os.symlink(f, link_path)
            except OSError:
                # Best-effort: skip entries that cannot be linked.
                pass
+
+
if __name__ == '__main__':
    # Collect every PDF under the corpus root and symlink them into one
    # flat directory for batch processing.
    fs = scan_dir('/home/zzh/ocr/pdf', 'pdf')
    batch_ln(fs, './all_pdf')

+ 136 - 0
text_extractor.py

@@ -0,0 +1,136 @@
+from pdfminer.high_level import extract_pages
+from pdfminer.layout import LTTextBoxHorizontal
+from pdfminer.pdfinterp import resolve1
+from pdfminer.pdfdocument import PDFDocument
+from pdfminer.pdfparser import PDFParser
+from matcher import Matcher
+from get_info import PdfExtractAttr, is_title
+from typing import Callable, Union, List, Tuple, Dict
+from re import fullmatch
+from tqdm import tqdm
+import pandas as pd
+
+
def absolute_not_title(line: str) -> bool:
    """Return True when *line* is a purely numeric token (digits, optional
    dots, optional trailing %) and therefore can never be a chapter title.

    Note the pattern requires at least two digits, so a lone digit is not
    rejected by this check.
    """
    return fullmatch(r'^\d(\d*\.?\d*)+\d(%)?', line) is not None
+
+
def parse_title(pdf_path: str) -> List[dict]:
    """Scan a PDF for candidate chapter titles.

    A layout element counts as a candidate title when it is a single-line
    horizontal text box whose text passes `is_title` (or is rendered taller
    than 15 units) and is not rejected by `absolute_not_title`.

    Fixes: the file handle opened for the page count is now closed (it was
    leaked), and the result list is built in a single pass instead of
    copying an intermediate list.

    :param pdf_path: PDF file to scan.
    :return: dicts with keys 'title', 'index' (per page), 'page_number'
        and 'seq_num' (global order of appearance).
    """
    with open(pdf_path, 'rb') as fp:
        # Total page count, used only to size the progress bar.
        page_count = resolve1(PDFDocument(PDFParser(fp)).catalog['Pages'])['Count']
    results = []
    for page_number, page_layout in tqdm(enumerate(extract_pages(pdf_path)), total=page_count):
        title_index = 0
        for element in page_layout:
            if isinstance(element, LTTextBoxHorizontal) and len(element._objs) == 1:
                text = element.get_text().strip()
                if text and (is_title(text) or element.height > 15) and (not absolute_not_title(text)):
                    results.append({'title': text,
                                    'index': title_index,
                                    'page_number': page_number,
                                    'seq_num': len(results)})
                    title_index += 1
    return results
+
+
def pagination_texts(contents: List[dict], start: int, end: int = None) -> Tuple[Dict, List[str]]:
    """Select text records whose page_number lies in [start, end).

    Bug fixed: the per-page dicts were built with
    ``results.get(page, {}).update(...)``, which updates a throwaway dict,
    so the returned mapping was always empty; ``setdefault`` actually
    stores the per-page dict.

    :param contents: records carrying 'page_number', 'index', 'text',
        'lines' and 'is_table_name'.
    :param start: first page (inclusive).
    :param end: one past the last page; defaults to start + 1.
    :return: (mapping page_number -> {index -> record}, flat list of the
        selected texts).
    """
    if end is None:
        end = start + 1
    results = {}
    texts = []
    pages = set(range(start, end))
    for page in contents:
        if page['page_number'] in pages:
            results.setdefault(int(page['page_number']), {})[page['index']] = {
                'page_number': page['page_number'],
                'index': page['index'],
                'text': page['text'],
                'lines': page['lines'],
                'is_table_name': page['is_table_name']
            }
            texts.append(page['text'])
    return results, texts
+
+
def similarity_filter(data: List[dict], expect_similarity: float = None):
    """Lazily filter records whose '相似度' (similarity) strictly exceeds a
    threshold.  Any non-float *expect_similarity* (including None) falls
    back to the default threshold of 0.5.
    """
    threshold = expect_similarity if isinstance(expect_similarity, float) else 0.5
    return filter(lambda record: record['相似度'] > threshold, data)
+
+
def extract_from_texts(text: List[str], extractor: Union[Callable[[str, float], List[str]], Callable[[str], List[str]]],
                       instances: List[str], similarity: float = None) -> Tuple[List[str], List[int]]:
    """Run *extractor* over the text fragments that best match *instances*.

    The input lines are stripped of spaces, concatenated, split on the
    Chinese full stop 。 and re-joined/re-split on commas, producing
    sentence-like fragments; each fragment is similarity-matched against
    *instances* and the extractor is applied to the best matches.

    :param text: raw text lines (e.g. one chapter's page texts).
    :param extractor: matcher such as match_price_zhs; when *similarity* is
        given it is forwarded as the extractor's second argument.
    :param instances: key phrases to locate, e.g. ['人民币投标总报价'].
    :param similarity: optional threshold forwarded to the extractor.
    :return: (non-empty extractor results, similarity scores).
    """
    # Normalise: drop spaces, split on 。, then split on commas to get
    # fragments suitable for embedding comparison.
    texts = ','.join(filter(lambda x: x != '',
                            ''.join([''.join(filter(lambda x: x != ' ', list(i.strip()))) for i in text]).split(
                                '。'))).split(',')
    sims = similar_match([{'text': i} for i in texts], instances, 'text')
    s_texts = [i['text'] for i in sims]
    # NOTE(review): the '相似度' scores look like floats from the embedding
    # matcher although the annotation above says List[int] -- confirm.
    similarities = [i['相似度'] for i in sims]
    if similarity is None:
        return list(filter(lambda x: x != [], [extractor(i) for i in s_texts])), similarities
    else:
        return list(filter(lambda x: x != [], [extractor(i, similarity) for i in s_texts])), similarities
+
+
def similar_match(data: List[dict], instances: List[str], key: str) -> List[dict]:
    """For each record in *data*, find the most similar phrase in *instances*
    by embedding similarity, then keep only the best-scoring record per
    matched phrase.

    :param data: records each carrying a text field named *key*.
    :param instances: candidate key phrases to match against.
    :param key: name of the text field to compare.
    :return: one record per matched phrase, augmented with '因素' (the
        matched phrase) and '相似度' (its similarity score).
    """
    matcher = Matcher()
    df = pd.DataFrame(data)
    keyword_embeddings = matcher.get_embeddings(instances)
    tqdm.pandas(desc='标题相似度匹配')
    # NOTE(review): tqdm.pandas() is initialised but plain .apply() (not
    # .progress_apply) is used below, so no progress bar is shown -- confirm
    # whether that is intended.
    # TopK1 presumably yields the (best phrase, score) pair per text, which
    # is unpacked into the two columns below -- verify against Matcher.
    result = df[key].apply(lambda x: matcher.TopK1(x, instances, matcher.get_embedding(x), keyword_embeddings))
    result.columns = ['因素', '相似度']

    df['因素'] = result['因素']
    df['相似度'] = result['相似度']

    # Keep, for every matched phrase, only the row with the highest score.
    max_sim_idx = df.groupby('因素')['相似度'].idxmax()
    max_sim_rows = df.loc[max_sim_idx]
    return max_sim_rows.to_dict(orient='records')
+
+
def get_instance(title_instances: List[str], content_instances: List[str], pdf: str,
                 extractor: Union[Callable[[str, float], List[str]], Callable[[str], List[str]]],
                 page_bias: int = 1, similarity: float = None):
    """Locate chapters whose titles match *title_instances* and run
    *extractor* over their text looking for *content_instances*.

    :param title_instances: chapter titles to search for, e.g. ['投标函'].
    :param content_instances: key phrases expected inside those chapters.
    :param pdf: PDF file path.
    :param extractor: fragment matcher (see extract_from_texts).
    :param page_bias: number of pages after the title page to include.
    :param similarity: optional title-similarity threshold; None falls back
        to similarity_filter's default of 0.5.
    :return: flat list of results; note that ``extend`` flattens each
        (results, scores) tuple from extract_from_texts into two entries.
    """
    file = PdfExtractAttr(file_path=pdf)
    # titles = file.parse_outline()
    titles = parse_title(pdf)
    texts = file.parse_text()

    title_sims = similarity_filter(similar_match(titles, title_instances, key='title'), similarity)
    results = []
    for i in title_sims:
        current_page = i['page_number']
        _, text = pagination_texts(texts, current_page, current_page + page_bias)
        results.extend(extract_from_texts(text, extractor, content_instances))
    return results
+
+
if __name__ == '__main__':
    # Usage examples kept for reference only; the live versions of these
    # calls reside in extract_price.py.
    # price_zhs = get_instance(['投标函', '开标一览表'], ['人民币投标总报价'],
    #                          '/Users/zelate/Codes/pvas/pdf_title_image/投标文件-修改版9-5-1-1.pdf',
    #                          match_price_zhs)
    # price_num = get_instance(['投标函', '开标一览表'], ['人民币投标总报价'],
    #                          '/Users/zelate/Codes/pvas/pdf_title_image/投标文件-修改版9-5-1-1.pdf',
    #                          match_price_num)
    # duration = get_instance(['投标函', '开标一览表'], ['工期日历天'],
    #                         '/Users/zelate/Codes/pvas/pdf_title_image/投标文件-修改版9-5-1-1.pdf',
    #                         match_duration)
    # quality = get_instance(['投标函', '开标一览表'], ['工程质量'],
    #                        '/Users/zelate/Codes/pvas/pdf_title_image/投标文件-修改版9-5-1-1.pdf',
    #                        match_quality)
    # valid = rmb_to_digit(price_zhs[0][0][0]) == price_num[0][0][0][1:]
    # test = rmb_to_digit('壹仟肆佰贰拾万捌仟玖佰陆拾柒元叁角陆分元')
    # valid = (rmb_to_digit('壹仟肆佰贰拾万捌仟玖佰陆拾柒元叁角陆分元')) == '14208967.36'
    pass