"""Gradio app that labels each sentence of a medical report.

Source: ``app.py`` of the Hugging Face Space
``junhyun01/Onpremise_LLM_Normal_Detection``.

Every sentence is tagged ``abnormal``, ``normal`` or ``uncertain``; the report
as a whole is ``ABNORMAL`` as soon as one sentence is tagged ``abnormal``.
"""
import functools
import re

import gradio as gr
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer, BertForSequenceClassification

ENCODER_PATH = "./trained_model/fine_tunning_encoder_v2"
CLASSIFIER_PATH = "./trained_model/fine_tunning_classifier"
TOKENIZER_NAME = "zzxslp/RadBERT-RoBERTa-4m"
MAX_LENGTH = 120

# LABELS[i] is the tag shown for classifier output class i.
LABELS = ("abnormal", "normal", "uncertain")
ABNORMAL_INDEX = 0


class MLP(nn.Module):
    """Single linear layer mapping an encoder vector to class logits."""

    def __init__(self, target_size=3, input_size=768):
        super().__init__()
        self.num_classes = target_size
        self.input_size = input_size
        self.fc1 = nn.Linear(input_size, target_size)

    def forward(self, x):
        return self.fc1(x)


@functools.lru_cache(maxsize=1)
def _load_models():
    """Load tokenizer, encoder and classifier head once and reuse them.

    Loading is deferred to the first request (so importing this module stays
    cheap) and cached afterwards, instead of hitting the disk on every call.

    Returns:
        A ``(tokenizer, encoder, classifier)`` tuple, both models in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

    encoder = AutoModel.from_pretrained(ENCODER_PATH)
    encoder.eval()

    classifier = MLP(target_size=len(LABELS), input_size=768)
    # NOTE(review): torch.load unpickles the file; only load trusted weights.
    classifier.load_state_dict(
        torch.load(CLASSIFIER_PATH, map_location=torch.device("cpu"))
    )
    classifier.eval()

    return tokenizer, encoder, classifier


def paragraph_leveling(text):
    """Classify every sentence of ``text`` and derive a report-level label.

    The report is split on ``"."``. Blank fragments (for example the empty
    string after the final period) are skipped so that they cannot influence
    the final label.

    Args:
        text: The raw medical report.

    Returns:
        A list of ``(text, tag)`` tuples for ``gr.HighlightedText``: a newline,
        one entry per sentence, a newline, and the final label. ``tag`` is
        ``None`` for untagged separators. An empty list is returned when the
        report contains no sentence, because no label can be justified then.
    """
    sentences = [part for part in (text or "").split(".") if part.strip()]
    if not sentences:
        return []

    tokenizer, encoder, classifier = _load_models()

    output_list = [("\n", None)]
    result = []

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        for sentence in sentences:
            encoding = tokenizer(
                sentence,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=MAX_LENGTH,
            )
            # Element 1 of the encoder output feeds the classifier head
            # (presumably the pooled sentence vector -- TODO confirm).
            logits = classifier(encoder(**encoding)[1])[0]
            label_index = int(logits.argmax(-1))

            output_list.append((sentence, LABELS[label_index]))
            result.append(label_index)

    output_list.append(("\n", None))
    output_list.append(("FINAL LABEL: ", None))
    if ABNORMAL_INDEX in result:
        output_list.append(("ABNORMAL", "abnormal"))
    else:
        output_list.append(("NORMAL", "normal"))

    return output_list


demo = gr.Interface(
    paragraph_leveling,
    [
        gr.Textbox(
            label="Medical Report",
            info="You can put any types of medical report",
            lines=20,
            value=" ",
        ),
    ],
    gr.HighlightedText(
        label="labeling",
        show_legend=True,
        show_label=True,
        color_map={
            "abnormal": "violet",
            "normal": "lightgreen",
            "uncertain": "lightgray",
        },
    ),
    theme=gr.themes.Base(),
)

if __name__ == "__main__":
    demo.launch()