# app.py from the Hugging Face Space junhyun01/Onpremise_LLM_Normal_Detection
import gradio as gr
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer


# Load the fine-tuned sentence encoder (local), the base RadBERT tokenizer, and
# the linear classification head once at start-up instead of on every request.
model = AutoModel.from_pretrained("./trained_model/fine_tunning_encoder_v2")
tokenizer = AutoTokenizer.from_pretrained("zzxslp/RadBERT-RoBERTa-4m")
model.eval()


class MLP(nn.Module):
    """Linear head mapping the encoder's 768-dim pooled output to 3 classes."""

    def __init__(self, target_size=3, input_size=768):
        super().__init__()
        self.num_classes = target_size
        self.input_size = input_size
        self.fc1 = nn.Linear(input_size, target_size)

    def forward(self, x):
        return self.fc1(x)


classifier = MLP(target_size=3, input_size=768)
classifier.load_state_dict(
    torch.load("./trained_model/fine_tunning_classifier",
               map_location=torch.device("cpu")))
classifier.eval()


def paragraph_leveling(text):
    """Classify each sentence of a medical report as abnormal, normal, or
    uncertain and return (text, label) pairs for gr.HighlightedText."""

    output_list = []             # (text, label) pairs for gr.HighlightedText
    text_list = text.split(".")  # naive sentence split on periods
    result = []                  # per-sentence class indices

    output_list.append(("\n", None))

    for sentence in text_list:
        if not sentence.strip():
            continue  # skip empty fragments produced by the split on "."

        encoding = tokenizer(
            sentence,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=120)
        with torch.no_grad():
            encoder_output = model(**encoding)
            logits = classifier(encoder_output[1])[0]  # index 1 is the pooled output

        pred = logits.argmax(-1).item()  # 0: abnormal, 1: normal, 2: uncertain
        if pred == 0:
            output_list.append((sentence, "abnormal"))
        elif pred == 1:
            output_list.append((sentence, "normal"))
        else:
            output_list.append((sentence, "uncertain"))
        result.append(pred)

    output_list.append(("\n", None))
    output_list.append(("FINAL LABEL: ", None))
    # Any abnormal sentence marks the whole report as abnormal.
    if 0 in result:
        output_list.append(("ABNORMAL", "abnormal"))
    else:
        output_list.append(("NORMAL", "normal"))

    return output_list


demo = gr.Interface(
    paragraph_leveling,
    [
        gr.Textbox(
            label="Medical Report",
            info="Paste any type of medical report.",
            lines=20,
            value=" ",
        ),
    ],
    gr.HighlightedText(
        label="Labeling",
        show_legend=True,
        show_label=True,
        color_map={"abnormal": "violet", "normal": "lightgreen", "uncertain": "lightgray"}),
    theme=gr.themes.Base()
)
if __name__ == "__main__":
    demo.launch()
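
# Quick sanity check without the web UI; the sample report below is illustrative
# only and assumes the weights under ./trained_model are available locally:
#
#     print(paragraph_leveling("Heart size is normal. No focal consolidation is seen."))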