emr-distillation committed
Commit a8af44f · verified · 1 Parent(s): 5a8e4e6

Create app.py

Files changed (1)
app.py +127 -0
app.py ADDED
@@ -0,0 +1,127 @@
+ import gradio as gr
+ import torch
+ from torch import nn
+ from transformers import AutoTokenizer, AutoModel
+
+
+ def paragraph_leveling(text):
+     # Fine-tuned RadBERT encoder plus the tokenizer it was trained with.
+     # Note: the weights are reloaded on every call; acceptable for a demo Space.
+     model_name = "./trained_model/fine_tunning_encoder_v2"
+     model = AutoModel.from_pretrained(model_name)
+     model.eval()
+     tokenizer = AutoTokenizer.from_pretrained('zzxslp/RadBERT-RoBERTa-4m')
+
+     # Linear classification head: pooled encoder output -> 3 classes
+     # (0 = abnormal, 1 = normal, 2 = uncertain).
+     class MLP(nn.Module):
+         def __init__(self, target_size=3, input_size=768):
+             super(MLP, self).__init__()
+             self.num_classes = target_size
+             self.input_size = input_size
+             self.fc1 = nn.Linear(input_size, target_size)
+
+         def forward(self, x):
+             return self.fc1(x)
+
+     classifier = MLP(target_size=3, input_size=768)
+     classifier.load_state_dict(
+         torch.load('./trained_model/fine_tunning_classifier',
+                    map_location=torch.device('cpu')))
+     classifier.eval()
+
+     output_list = []  # (text, label) pairs for gr.HighlightedText
+     result = []       # per-sentence class indices, used for the final label
+     text_list = text.split(".")
+
+     output_list.append(("\n", None))
+
+     # Classify the report sentence by sentence.
+     with torch.no_grad():
+         for idx_sentence in text_list:
+             if not idx_sentence.strip():
+                 continue  # skip empty fragments produced by the split
+             train_encoding = tokenizer(
+                 idx_sentence,
+                 return_tensors='pt',
+                 padding='max_length',
+                 truncation=True,
+                 max_length=120)
+             output = model(**train_encoding)
+             output = classifier(output[1])  # output[1] is the pooled representation
+             output = output[0]
+
+             if output.argmax(-1) == 0:
+                 output_list.append((idx_sentence, 'abnormal'))
+                 result.append(0)
+             elif output.argmax(-1) == 1:
+                 output_list.append((idx_sentence, 'normal'))
+                 result.append(1)
+             else:
+                 output_list.append((idx_sentence, 'uncertain'))
+                 result.append(2)
+
+     # The whole report is flagged ABNORMAL if any sentence is abnormal.
+     output_list.append(('\n', None))
+     output_list.append(('FINAL LABEL: ', None))
+     if 0 in result:
+         output_list.append(('ABNORMAL', 'abnormal'))
+     else:
+         output_list.append(('NORMAL', 'normal'))
+
+     return output_list
+
+
+ demo = gr.Interface(
+     paragraph_leveling,
+     [
+         gr.Textbox(
+             label="Medical Report",
+             info="You can paste any type of medical report",
+             lines=20,
+             value=" ",
+         ),
+     ],
+     gr.HighlightedText(
+         label="labeling",
+         show_legend=True,
+         show_label=True,
+         color_map={"abnormal": "violet", "normal": "lightgreen", "uncertain": "lightgray"},
+     ),
+     theme=gr.themes.Base(),
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
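A minimal sketch of how the classifier above can be exercised outside the Gradio UI, assuming the fine-tuned weights exist under ./trained_model/ as referenced in app.py; the sample report text is illustrative only.

    # hypothetical local smoke test, not part of the committed file
    from app import paragraph_leveling

    sample = "No pleural effusion. There is a small nodule in the right upper lobe."
    for span, label in paragraph_leveling(sample):
        print(repr(span), "->", label)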