File size: 4,205 Bytes
fc60fc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42a09ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc60fc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66ebefb
 
 
 
 
 
 
 
 
 
fc60fc6
 
 
 
 
66ebefb
fc60fc6
 
66ebefb
fc60fc6
 
 
 
 
 
 
 
 
66ebefb
fc60fc6
 
66ebefb
fc60fc6
 
 
 
 
 
 
 
66ebefb
fc60fc6
 
66ebefb
fc60fc6
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from typing import Optional

import pandas as pd
from transformers import pipeline

class LayoutLM:
    """Wrapper around a transformers 'document-question-answering' pipeline.

    Asks a fixed set of invoice-style questions about a document image and
    collects the best answers into a pandas DataFrame (see ``inference``).
    Call ``set_model`` before asking questions; until then every question
    yields the default answer.
    """

    # (row label, question) pairs asked by ``inference``, in output-row order.
    # The question strings are prompts fed to the model -- editing them
    # changes the answers, so they are kept verbatim.
    _INFERENCE_FIELDS = (
        ('Merchant ID', 'What is merchant ID?'),
        ('Merchant name', 'What is merchant name?'),
        ('Merchant address', 'What is merchant address?'),
        ('Merchant branch', 'What is branch of merchant?'),
        ('Invoice no.', 'What is invoice number or INV?'),
        ('Products', 'What are buy products?'),
        ('Product codes', 'What are code of buy products?'),
        ('POS no.', 'What is POS number?'),
        ('Net price', 'What is the net-price?'),
        ('Date/Time', 'What date, year and time of the invoice?'),
    )

    def __init__(self, save_pretrained_fpath: Optional[str] = None) -> None:
        """Create the wrapper; no model is loaded until ``set_model`` is called.

        Args:
            save_pretrained_fpath: if given, build a pipeline for the task and
                write it to this path with ``save_pretrained``.
        """
        self.pipeline_category = 'document-question-answering'
        self.tf_pipeline = pipeline
        self.pipeline = None

        if save_pretrained_fpath is not None:
            # NOTE(review): no ``model`` is passed here, so transformers picks
            # its own default for the task, which may differ from
            # ``self.default_model`` -- confirm this is intended.
            pipe = self.tf_pipeline(self.pipeline_category)
            pipe.save_pretrained(save_pretrained_fpath)

        self.default_model = 'impira/layoutlm-invoices'
        # Fallback result used when no answer scores above 0.
        self.default_ex_answer = {'score': 0, 'answer': '-'}

    def set_model(self, model: Optional[str] = None):
        """Build the QA pipeline for *model*.

        Args:
            model: model name/path passed to the pipeline factory; ``None``
                (the default) selects ``self.default_model``.
        """
        if model is None:
            model = self.default_model

        self.pipeline = self.tf_pipeline(self.pipeline_category, model=model)

    def answer_the_question_without_filter(self, img, question: str, is_debug=False, **kwargs):
        """Run the pipeline and return its raw output without picking a best answer.

        Recognised keyword arguments (a value of ``None`` means "use the default"):
            top_k: number of answers to request (default 1).
            max_answer_len: forwarded to the pipeline as ``max_answer_len``
                (default 15).
        Any other keyword arguments are ignored.

        Returns:
            The pipeline output unchanged, or ``None`` if ``set_model`` has not
            been called.
        """
        top_k = kwargs.get('top_k')
        if top_k is None:
            top_k = 1
        max_answer_len = kwargs.get('max_answer_len')
        if max_answer_len is None:
            max_answer_len = 15

        answers = None
        if self.pipeline is not None:
            answers = self.pipeline(img, question,
                top_k=top_k,
                max_answer_len=max_answer_len)

        if is_debug:
            print('--------------------')
            print(answers)

        return answers

    def answer_the_question(self, img, question: str, is_debug=False):
        """Return the text of the highest-scoring answer to *question* about *img*.

        Only answers scoring strictly above the default score (0) are
        considered; on ties the first such answer wins. If nothing qualifies,
        or ``set_model`` has not been called, the default answer ``'-'`` is
        returned.
        """
        score = self.default_ex_answer['score']
        answer = self.default_ex_answer['answer']
        answers = None
        if self.pipeline is not None:
            answers = self.pipeline(img, question)

            # The document-QA pipeline returns a list of answer dicts; also
            # tolerate a single bare dict or an empty/None result.
            if isinstance(answers, dict):
                candidates = [answers]
            else:
                candidates = answers or []

            for candidate in candidates:
                # Same key transformers uses when sorting answers: a missing
                # score counts as 0 (and is therefore never selected).
                candidate_score = candidate.get('score', 0)
                if candidate_score > score:
                    score = candidate_score
                    answer = candidate['answer']

        if is_debug:
            print('--------------------')
            print(f'Q: {question}\nA: {answer} (acc:{score:.2f})\n')
            print(answers)

        return answer

    def inference(self, img, is_debug=False) -> pd.DataFrame:
        """Ask every question in ``_INFERENCE_FIELDS`` about *img*.

        Returns:
            A DataFrame with a 'Data' column (field labels, in table order) and
            a 'Value' column (the corresponding answers as strings).
        """
        labels = [label for label, _ in self._INFERENCE_FIELDS]
        values = [
            self.answer_the_question(img, question, is_debug=is_debug)
            for _, question in self._INFERENCE_FIELDS
        ]

        if is_debug:
            for label, value in zip(labels, values):
                print(f'{label}: {value}')

        return pd.DataFrame({
            'Data': labels,
            'Value': [str(value) for value in values],
        })