zavavan commited on
Commit
75ed445
·
verified ·
1 Parent(s): dabb5c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -1
app.py CHANGED
@@ -4,6 +4,7 @@ import gradio as gr
4
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig
5
  from peft import PeftModel, PeftConfig
6
  import torch
 
7
 
8
  # Load PEFT adapter configuration
9
  peft_config = PeftConfig.from_pretrained("unica/CLiMA")
@@ -57,11 +58,38 @@ def format_prompt(user_input, entity1, entity2):
57
 
58
  # Prediction function
59
  def generate_relations(text, entity1, entity2):
 
 
 
60
  prompt = format_prompt(text, entity1, entity2)
61
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
62
  outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)
63
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
64
- return response[len(prompt):].strip() # remove prompt from output if echoed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  # Gradio UI
67
  demo = gr.Interface(
 
4
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig
5
  from peft import PeftModel, PeftConfig
6
  import torch
7
+ import regex as re
8
 
9
  # Load PEFT adapter configuration
10
  peft_config = PeftConfig.from_pretrained("unica/CLiMA")
 
58
 
59
  # Prediction function
60
  def generate_relations(text, entity1, entity2):
61
+ answer_label_regex_pattern = re.compile(r'LABEL:?\s?(\d+)')
62
+
63
+
64
  prompt = format_prompt(text, entity1, entity2)
65
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
66
  outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)
67
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
68
+ modelOut = response[len(prompt):].strip() # remove prompt from output if echoed
69
+ answer_match = answer_label_regex_pattern.search(modelOut)
70
+ if answer_match:
71
+ if answer_match.group(1)=='0'
72
+ return f"""{entity1} causes {entity2}"""
73
+ elif answer_match.group(1)=='1'
74
+ return f"""{entity2} causes {entity1}"""
75
+ elif answer_match.group(1)=='2'
76
+ return f"""{entity1} enables {entity2}"""
77
+ elif answer_match.group(1)=='3'
78
+ return f"""{entity2} enables {entity1}"""
79
+ elif answer_match.group(1)=='4'
80
+ return f"""{entity1} prevents {entity2}"""
81
+ elif answer_match.group(1)=='5'
82
+ return f"""{entity2} prevents {entity1}"""
83
+ elif answer_match.group(1)=='6'
84
+ return f"""{entity1} hinders {entity2}"""
85
+ elif answer_match.group(1)=='7'
86
+ return f"""{entity2} hinders {entity1}"""
87
+ elif answer_match.group(1)=='8'
88
+ return f"""No causal relation between {entity1} and {entity2}"""
89
+ else:
90
+ return 'No causal relation could be extracted'
91
+
92
+
93
 
94
  # Gradio UI
95
  demo = gr.Interface(