|
import gradio as gr
|
|
import torch
|
|
from transformers import AutoTokenizer, BertForSequenceClassification, AutoModel
|
|
from torch import nn
|
|
import re
|
|
|
|
|
|
class MLP(nn.Module):
    """Classification head: one linear layer from an embedding to class logits."""

    def __init__(self, target_size=3, input_size=768):
        super().__init__()
        self.num_classes = target_size
        self.input_size = input_size
        # The attribute name `fc1` determines the state-dict keys, so it must
        # stay in sync with the checkpoint loaded in `_load_models`.
        self.fc1 = nn.Linear(input_size, target_size)

    def forward(self, x):
        """Return the raw (un-normalized) class scores for `x`."""
        return self.fc1(x)


# Display label for each class index produced by the classifier.
_SENTENCE_LABELS = ('abnormal', 'normal', 'not much information')

# Filled on first use so the weights are read from disk once per process
# instead of once per request.
_MODEL_CACHE = {}


def _load_models():
    """Return the cached `(encoder, tokenizer, classifier, device)` bundle."""
    bundle = _MODEL_CACHE.get('bundle')
    if bundle is None:
        # Fall back to CPU so the demo still runs on hosts without a GPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        model = AutoModel.from_pretrained("./trained_model/fine_tunning_encoder_v2")
        model.to(device)
        model.eval()

        tokenizer = AutoTokenizer.from_pretrained('zzxslp/RadBERT-RoBERTa-4m')

        classifier = MLP(target_size=3, input_size=768)
        classifier.load_state_dict(
            torch.load('./trained_model/fine_tunning_classifier', map_location=device)
        )
        classifier.to(device)
        classifier.eval()

        bundle = (model, tokenizer, classifier, device)
        # Stored only after everything loaded, so a failed load is retried
        # on the next call rather than leaving a half-built cache behind.
        _MODEL_CACHE['bundle'] = bundle
    return bundle


def paragraph_leveling(text):
    """Label every sentence of a report and append an overall verdict.

    The text is split on "." and each non-blank piece is classified as
    'abnormal', 'normal' or 'not much information'. Blank pieces (such as
    the empty string after a trailing period) are skipped so they cannot
    influence the overall verdict.

    Args:
        text: Report text. `None` is treated like an empty string.

    Returns:
        A list of `(text, label)` tuples, where `label` is `None` for
        unlabeled spans: a leading newline entry, one entry per classified
        sentence, a newline entry, then `('FINAL LABEL: ', None)` followed
        by `('ABNORMAL', 'abnormal')` if any sentence was abnormal, else
        `('NORMAL', 'normal')`.
    """
    sentences = [piece for piece in (text or "").split(".") if piece.strip()]

    output_list = [("\n", None)]
    result = []

    # Only touch the (expensive) models when there is something to classify.
    if sentences:
        model, tokenizer, classifier, device = _load_models()
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for sentence in sentences:
                encoding = tokenizer(
                    sentence,
                    return_tensors='pt',
                    padding='max_length',
                    truncation=True,
                    max_length=120,
                )
                encoded = model(**encoding.to(device))
                # Element 1 of the encoder output feeds the classifier
                # (presumably the pooled sentence embedding - confirm);
                # [0] selects the single sentence in the batch.
                logits = classifier(encoded[1])[0]
                label_id = int(logits.argmax(-1))
                output_list.append((sentence, _SENTENCE_LABELS[label_id]))
                result.append(label_id)

    output_list.append(('\n', None))
    output_list.append(('FINAL LABEL: ', None))
    # A single abnormal sentence makes the whole report abnormal.
    if 0 in result:
        output_list.append(('ABNORMAL', 'abnormal'))
    else:
        output_list.append(('NORMAL', 'normal'))

    return output_list
|
|
|
|
|
|
# Multi-line text area where the user pastes the report to analyse.
report_box = gr.Textbox(
    label="Medical Report",
    info="You can put any types of medical report",
    lines=20,
    value=" ",
)

# Renders the (text, label) tuples returned by `paragraph_leveling`,
# colouring each span according to its label.
labeled_view = gr.HighlightedText(
    label="labeling",
    show_legend=True,
    show_label=True,
    color_map={
        "abnormal": "violet",
        "normal": "lightgreen",
        "not much information": "lightgray",
    },
)

# Wire the labeling function to its input and output widgets.
demo = gr.Interface(
    fn=paragraph_leveling,
    inputs=[report_box],
    outputs=labeled_view,
    theme=gr.themes.Base(),
)

if __name__ == "__main__":
    # share=True asks Gradio to create a public share link as well.
    demo.launch(share=True)
|
|
|
|
|