File size: 5,480 Bytes
0999443 8c01ffb 0999443 78ff932 0999443 75ed445 0999443 78ff932 0999443 78ff932 0999443 8c01ffb 0999443 8c01ffb 0999443 8c01ffb dabb5c1 0999443 dabb5c1 8c01ffb 0999443 75ed445 0999443 78ff932 0999443 75ed445 0052128 d7f4b9f 0052128 d7f4b9f 0052128 d7f4b9f 0052128 d7f4b9f 0052128 d7f4b9f 0052128 d7f4b9f 0052128 d7f4b9f 0052128 d7f4b9f 0052128 d7f4b9f 75ed445 8fe992b 0999443 98bfd8d 0999443 98bfd8d 0999443 98bfd8d 8676cfb db878c1 98bfd8d 9b5b26a 0999443 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
# File: app.py
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig
from peft import PeftModel, PeftConfig
import torch
import regex as re
# Hub repository that hosts the PEFT adapter.
adapter_repo = "unica/CLiMA"

# The adapter configuration names the base checkpoint the adapter sits on.
peft_config = PeftConfig.from_pretrained(adapter_repo)
base_checkpoint = peft_config.base_model_name_or_path

# BitsAndBytes settings: 4-bit NF4 weights with double quantisation.
# The compute dtype is bfloat16; float16 is the alternative if the GPU
# does not handle bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Base causal LM, loaded with the quantisation settings above.
base_model = AutoModelForCausalLM.from_pretrained(
    base_checkpoint,
    quantization_config=bnb_config,
    device_map="auto",
)

# Attach the adapter weights to the base model.
model = PeftModel.from_pretrained(base_model, adapter_repo)

# The tokenizer is loaded from the base checkpoint, not the adapter repo.
tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)
# Instruction text placed at the start of every prompt. It describes the task
# and the meaning of the nine relation labels (0-8), and asks for the answer
# in the form "LABEL: X".
# This is a plain string on purpose: it has no placeholders, and with an
# f-prefix any "{" added to the wording later would be parsed as a
# replacement field.
# NOTE(review): presumably this wording has to match the prompt the adapter
# was fine-tuned with -- confirm before changing the text.
prompt_instruction_drug_reviews = """Given a drug review enclosed in triple quotes and a pair of entities E1 corresponding to the drug name and E2 corresponding to the treated condition, classify the relation holding between E1 and E2.
The relations are identified with 9 labels from 0 to 8. The meaning of the labels is the following:
0 means that E1 causes E2
1 means that E2 causes E1
2 means that E1 enables E2
3 means that E2 enables E1
4 means that E1 prevents E2
5 means that E2 prevents E1
6 means that E1 hinders E2
7 means that E2 hinders E1
8 means that E1 and E2 are in a relation different than any of the previous ones.
Given X the label that you predicted, for the output use the format LABEL: X
"""
# Format prompt
def format_prompt(user_input: str, entity1: str, entity2: str) -> str:
    """Build the complete prompt sent to the model.

    The prompt is made of five pieces separated by single spaces: the
    ``<USER>`` tag, the shared instruction text, the review wrapped in
    triple single quotes, the two entities (each wrapped in triple single
    quotes, introduced on a new line), and the ``<ASSISTANT>`` tag.

    Args:
        user_input: The review or note text to classify.
        entity1: Text of entity E1.
        entity2: Text of entity E2.

    Returns:
        The assembled prompt string.
    """
    segments = (
        "<USER>",
        prompt_instruction_drug_reviews,
        f"Text:'''{user_input}'''",
        f"\nEntities: E1: '''{entity1}''', E2: '''{entity2}'''",
        "<ASSISTANT>",
    )
    return " ".join(segments)
# Prediction function

# Matches the "LABEL: X" answer format requested in the prompt. Compiled once
# at import time rather than on every request.
_LABEL_PATTERN = re.compile(r'LABEL:?\s?(\d+)')

# Returned whenever no usable label can be read from the model output.
_NO_RELATION_MESSAGE = 'No causal relation could be extracted'

# Predicted label -> sentence template; {e1}/{e2} are the two entity strings.
_RELATION_TEMPLATES = {
    '0': "'{e1}' causes '{e2}'",
    '1': "'{e2}' causes '{e1}'",
    '2': "'{e1}' enables '{e2}'",
    '3': "'{e2}' enables '{e1}'",
    '4': "'{e1}' prevents '{e2}'",
    '5': "'{e2}' prevents '{e1}'",
    '6': "'{e1}' hinders '{e2}'",
    '7': "'{e2}' hinders '{e1}'",
    '8': "No causal relation between '{e1}' and '{e2}'",
}


def generate_relations(text, entity1, entity2):
    """Classify the relation between two entities and describe it in words.

    Builds the prompt with ``format_prompt``, runs greedy generation on the
    module-level ``model``, looks for ``LABEL: <digits>`` in the generated
    text and turns the label into a sentence.

    Args:
        text: The review or note text.
        entity1: Text of entity E1.
        entity2: Text of entity E2.

    Returns:
        A sentence describing the predicted relation. If the output contains
        no label, or a label outside 0-8, the fixed message
        ``'No causal relation could be extracted'`` is returned, so the
        function always returns a string.
    """
    prompt = format_prompt(text, entity1, entity2)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)

    # The returned sequence starts with the prompt tokens. Drop them by token
    # count: slicing the decoded text by len(prompt) goes wrong whenever
    # decoding does not reproduce the prompt character for character
    # (e.g. special tokens are skipped).
    prompt_token_count = inputs["input_ids"].shape[-1]
    new_tokens = outputs[0][prompt_token_count:]
    completion = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    label_match = _LABEL_PATTERN.search(completion)
    if label_match is None:
        return _NO_RELATION_MESSAGE

    # Labels outside the table (e.g. "9" or "12") get the fallback message.
    template = _RELATION_TEMPLATES.get(label_match.group(1))
    if template is None:
        return _NO_RELATION_MESSAGE
    return template.format(e1=entity1, e2=entity2)
# Gradio UI

# Input widgets, in the order generate_relations receives its arguments.
text_box = gr.Textbox(lines=10, label="Clinical Note or Drug Review Text")
first_entity_box = gr.Textbox(label="Entity 1 (e.g., Drug)")
second_entity_box = gr.Textbox(label="Entity 2 (e.g., Condition or Symptom)")

# Single text output holding the sentence returned by generate_relations.
result_box = gr.Textbox(label="Extracted Causal Relations")

# Example rows: [text, entity 1, entity 2], matching the input widget order.
example_rows = [
    ["Odynophagia: Was presumed due to mucositis from recent chemotherapy.", "chemotherapy", "mucositis"],
    ["patient's wife noticed erythema on patient's face. On [**3-27**]the visiting nurse [**First Name (Titles) 8706**][**Last Name (Titles)11282**]of a rash on his arms as well. The patient was noted to be febrile and was admitted to the [**Company 191**] Firm. In the EW, patient's Dilantin was discontinued and he was given Tegretol instead.", "Dilantin", "erythema on patient's face"],
    ["i had a urinary tract infection so bad that when i pee it smells but when i started taking ciprofloxacin it worked it’s a good medicine for a urinary tract infections.", "ciprofloxacin", "urinary tract infection"],
    ["when i first started using ziana, i only had acne in between my eyebrows, chin, and the nose area. my acne worsened while using it and then it got better. but after about 4 months of using it, it became ineffective. so i now have acne between my eyebrows, chin, cheeks, forehead, and the nose area. its great at first but after a while it made my face even worse than before i used the product.", "ziana", "acne"],
]

demo = gr.Interface(
    fn=generate_relations,
    inputs=[text_box, first_entity_box, second_entity_box],
    outputs=result_box,
    title="Causal Relation Extractor with MedLlama",
    description="Paste your clinical note or drug review, and specify two target entities. This AI agent extracts drug-condition or symptom causal relations using a fine-tuned LLM adapter model.",
    examples=example_rows,
)

# Start the web app as soon as the script runs.
demo.launch()
|