import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
from datasets import load_dataset
from peft import PeftModel
import os

title = "Gemma-2b SciQ"
description = """
Gemma-2b fine-tuned on SciQ
"""

article = "GitHub repository: https://github.com/P-Zande/nlp-team-4"

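# Gemma-2b is a gated model on the Hugging Face Hub, so an HF_TOKEN (e.g. a Space secret) must be set.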
model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))
base_model = AutoModelForCausalLM.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))

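# Load the LoRA adapter stored alongside this script and merge it into the base weights.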
model = PeftModel.from_pretrained(base_model, "./")
model = model.merge_and_unload()


# Pick five random rows from the SciQ test split to use as UI examples
# (the original selected the first five, leaving the `random` import unused).
dataset = load_dataset("allenai/sciq")
random_indices = random.sample(range(len(dataset["test"])), 5)
random_test_samples = dataset["test"].select(random_indices)

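# Each sample yields two examples: context only, and context plus the gold answer.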
examples = []
for row in random_test_samples:
    examples.append([row['support'], ""])
    examples.append([row['support'], row['correct_answer']])


def predict(context="", answer=""):
    """Generate a question and three distractors from the context (and optional answer)."""
    # The fine-tuned model expects newline-separated fields: the context, then optionally the answer.
    formatted = context.replace('\n', ' ') + "\n"
    if answer != "":
        formatted += answer.replace('\n', ' ') + "\n"

    inputs = tokenizer(formatted, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=100)
    decoded_outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)
    split_outputs = decoded_outputs.split("\n")

    # A well-formed generation has exactly six lines:
    # context, answer, question, and three distractors.
    if len(split_outputs) == 6:
        return tuple(split_outputs)

    return ("ERROR: " + decoded_outputs, None, None, None, None, None)
    

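# Input components: the support passage and an optional target answer.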
support_gr = gr.TextArea(
    label="Context",
    value="Bananas are yellow and curved."
)

answer_gr = gr.Text(
    label="Answer (optional)",
    value="yellow"
)

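# Output components mirror the six newline-separated fields produced by the model.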
context_output_gr = gr.Text(
    label="Context"
)
answer_output_gr = gr.Text(
    label="Answer"
)
question_output_gr = gr.Text(
    label="Question"
)
distractor1_output_gr = gr.Text(
    label="Distractor 1"
)
distractor2_output_gr = gr.Text(
    label="Distractor 2"
)
distractor3_output_gr = gr.Text(
    label="Distractor 3"
)

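# Assemble and launch the Gradio interface.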
gr.Interface(
    fn=predict,
    inputs=[support_gr, answer_gr],
    outputs=[context_output_gr, answer_output_gr, question_output_gr, distractor1_output_gr, distractor2_output_gr, distractor3_output_gr],
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch()