zavavan commited on
Commit
0999443
·
verified ·
1 Parent(s): a91bdd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -25
app.py CHANGED
@@ -1,39 +1,54 @@
1
- import gradio as gr
2
- from transformers import pipeline
3
-
4
- from Gradio_UI import GradioUI
5
-
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- # Load your fine-tuned model from Hugging Face Hub
9
- model = pipeline("text2text-generation", model='unica/CLiMA') # Replace with your actual model repo name
10
-
11
 
12
- # Define your prompt template (customize as needed)
13
- def format_prompt(user_input):
14
- return f"Identify causal relations in the following clinical narrative:\n\n{user_input}\n\nCausal relations:" # Modify if your model uses a different template
15
 
16
- # Define prediction function
17
- def generate_relations(text):
18
- prompt = format_prompt(text)
19
- result = model(prompt, max_length=512, do_sample=False)
20
- return result[0]['generated_text']
21
 
22
-
 
 
 
 
 
 
23
 
24
- # Gradio interface
25
  demo = gr.Interface(
26
  fn=generate_relations,
27
- inputs=gr.Textbox(lines=10, label="Clinical Note or Drug Review Text"),
 
 
 
 
28
  outputs=gr.Textbox(label="Extracted Causal Relations"),
29
  title="Causal Relation Extractor with MedLlama",
30
- description="Paste your clinical note or drug review. This AI agent extracts drug-condition or symptom causal relations using a fine-tuned LLM.",
31
  examples=[
32
- ["Patient reported severe headaches after starting amitriptyline."],
33
- ["Lisinopril helped reduce the patient's blood pressure but caused persistent cough."],
34
- ["After using Metformin, the patient experienced gastrointestinal discomfort."]
35
  ]
36
  )
37
 
38
- # Launch the app
39
- demo.launch()
 
1
+ # File: app.py
 
 
 
 
2
 
3
+ import gradio as gr
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ from peft import PeftModel, PeftConfig
6
+ import torch
7
+
8
+ # Load PEFT adapter configuration
9
+ peft_config = PeftConfig.from_pretrained("unica/CLiMA")
10
+
11
+ # Load base model
12
+ base_model = AutoModelForCausalLM.from_pretrained(
13
+ peft_config.base_model_name_or_path,
14
+ torch_dtype=torch.float16,
15
+ device_map="auto"
16
+ )
17
 
18
+ # Load adapter weights
19
+ model = PeftModel.from_pretrained(base_model, "unica/CLiMA")
 
20
 
21
+ # Load tokenizer
22
+ tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
 
23
 
24
+ # Format prompt
25
+ def format_prompt(user_input, entity1, entity2):
26
+ return f"Identify causal relations in the following clinical narrative:\n\n{user_input}\n\nEntity 1: {entity1}\nEntity 2: {entity2}\n\nCausal relations:"
 
 
27
 
28
+ # Prediction function
29
+ def generate_relations(text, entity1, entity2):
30
+ prompt = format_prompt(text, entity1, entity2)
31
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
32
+ outputs = model.generate(**inputs, max_new_tokens=512, do_sample=False)
33
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
34
+ return response[len(prompt):].strip() # remove prompt from output if echoed
35
 
36
+ # Gradio UI
37
  demo = gr.Interface(
38
  fn=generate_relations,
39
+ inputs=[
40
+ gr.Textbox(lines=10, label="Clinical Note or Drug Review Text"),
41
+ gr.Textbox(label="Entity 1 (e.g., Drug)"),
42
+ gr.Textbox(label="Entity 2 (e.g., Condition or Symptom)")
43
+ ],
44
  outputs=gr.Textbox(label="Extracted Causal Relations"),
45
  title="Causal Relation Extractor with MedLlama",
46
+ description="Paste your clinical note or drug review, and specify two target entities. This AI agent extracts drug-condition or symptom causal relations using a fine-tuned LLM adapter model.",
47
  examples=[
48
+ ["Patient reported severe headaches after starting amitriptyline.", "amitriptyline", "headaches"],
49
+ ["Lisinopril helped reduce the patient's blood pressure but caused persistent cough.", "Lisinopril", "cough"],
50
+ ["After using Metformin, the patient experienced gastrointestinal discomfort.", "Metformin", "gastrointestinal discomfort"]
51
  ]
52
  )
53
 
54
+ demo.launch()