emr-distillation's picture
Create app.py
a8af44f verified
raw
history blame
2.86 kB
Hugging Face's logo
Hugging Face
Search models, datasets, users...
Models
Datasets
Spaces
Posts
Docs
Pricing
Spaces:
junhyun01
/
Onpremise_LLM_Normal_Detection
private
App
Files
Community
1
Settings
Onpremise_LLM_Normal_Detection
/
app.py
junhyun01's picture
junhyun01
Update app.py
18d8c3a
VERIFIED
24 days ago
raw
history
blame
edit
delete
No virus
2.5 kB
import functools
import re

import gradio as gr
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer, BertForSequenceClassification
# Directory holding the fine-tuned encoder weights (path kept exactly as on disk).
_ENCODER_PATH = "./trained_model/fine_tunning_encoder_v2"
# Tokenizer paired with the encoder.
_TOKENIZER_NAME = 'zzxslp/RadBERT-RoBERTa-4m'
# state_dict file for the linear classification head (path kept exactly as on disk).
_CLASSIFIER_PATH = './trained_model/fine_tunning_classifier'

# Sentence boundary: a "." unless it sits between two digits, so decimal
# measurements such as "1.2 cm" stay inside one sentence.
_SENTENCE_BOUNDARY = re.compile(r"(?<!\d)\.|\.(?!\d)")

# Classifier output index -> highlight label used by the UI colour map.
_LABELS = ('abnormal', 'normal', 'uncertain')


class MLP(nn.Module):
    """Single linear layer mapping an encoder embedding to class logits."""

    def __init__(self, target_size=3, input_size=768):
        super().__init__()
        self.num_classes = target_size
        self.input_size = input_size
        self.fc1 = nn.Linear(input_size, target_size)

    def forward(self, x):
        return self.fc1(x)


@functools.lru_cache(maxsize=1)
def _load_components():
    """Load encoder, tokenizer and classifier head once; reuse on later calls."""
    model = AutoModel.from_pretrained(_ENCODER_PATH)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_NAME)
    classifier = MLP(target_size=3, input_size=768)
    # NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
    classifier.load_state_dict(
        torch.load(_CLASSIFIER_PATH, map_location=torch.device('cpu')))
    classifier.eval()
    return model, tokenizer, classifier


def _split_sentences(text):
    """Return the non-blank "."-delimited segments of *text*, unstripped.

    Segments keep their leading whitespace so the tokenizer sees the same
    input as before; blank segments (e.g. the one after a final period) are
    dropped so they are never classified.
    """
    return [seg for seg in _SENTENCE_BOUNDARY.split(text) if seg.strip()]


def paragraph_leveling(text):
    """Classify each sentence of a medical report and derive a report label.

    Args:
        text: Free-text report. ``None`` is treated as an empty report.

    Returns:
        A list of ``(text, label)`` tuples for ``gr.HighlightedText``:
        a leading newline, one entry per sentence labelled ``'abnormal'``,
        ``'normal'`` or ``'uncertain'``, another newline, then
        ``'FINAL LABEL: '`` followed by ``('ABNORMAL', 'abnormal')`` if any
        sentence is abnormal, otherwise ``('NORMAL', 'normal')``.
    """
    model, tokenizer, classifier = _load_components()

    output_list = [("\n", None)]
    has_abnormal = False
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for sentence in _split_sentences(text or ""):
            encoding = tokenizer(
                sentence,
                return_tensors='pt',
                padding='max_length',
                truncation=True,
                max_length=120)
            outputs = model(**encoding)
            # outputs[1] is presumably the pooled sentence embedding -- confirm
            # against the encoder class; [0] takes the single batch row.
            logits = classifier(outputs[1])[0]
            label = _LABELS[int(logits.argmax(-1))]
            output_list.append((sentence, label))
            has_abnormal = has_abnormal or label == 'abnormal'

    output_list.append(('\n', None))
    output_list.append(('FINAL LABEL: ', None))
    # 'uncertain' sentences do not affect the report-level label.
    if has_abnormal:
        output_list.append(('ABNORMAL', 'abnormal'))
    else:
        output_list.append(('NORMAL', 'normal'))
    return output_list
# Free-text input for the report to analyse.
_report_input = gr.Textbox(
    label="Medical Report",
    info="You can put any types of medical report",
    lines=20,
    value=" ",
)

# Sentence-level highlighting; colours are keyed by the labels that
# paragraph_leveling attaches to each segment.
_labeled_output = gr.HighlightedText(
    label="labeling",
    show_legend=True,
    show_label=True,
    color_map={"abnormal": "violet", "normal": "lightgreen", "uncertain": "lightgray"},
)

demo = gr.Interface(
    fn=paragraph_leveling,
    inputs=[_report_input],
    outputs=_labeled_output,
    theme=gr.themes.Base(),
)

if __name__ == "__main__":
    demo.launch()