# gemma-2b-sciq / app.py
# Darwinkel's picture
# Update app.py
# 5052a13 verified
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
from datasets import load_dataset
from peft import PeftModel
import os
# Text shown on the Gradio interface page (title, blurb, and footer link).
title = "Gemma-2b SciQ"
description = """
Gemma-2b fine-tuned on SciQ
"""
article = "GitHub repository: https://github.com/P-Zande/nlp-team-4"
# Base model on the Hugging Face Hub; Gemma is gated, so HF_TOKEN must be
# set in the environment for both downloads below.
model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))
base_model = AutoModelForCausalLM.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))
# The fine-tuned LoRA adapter lives next to this script ("./"); merging it
# into the base weights yields a plain model for generation.
model = PeftModel.from_pretrained(base_model, "./")
model = model.merge_and_unload()
# Build example inputs for the Gradio interface from the SciQ test split.
dataset = load_dataset("allenai/sciq")
# Select 5 genuinely random test rows. The original code took the first five
# via range(5) despite the variable name and the (otherwise unused) `random`
# import — sample real random indices instead.
random_test_samples = dataset["test"].select(
    random.sample(range(len(dataset["test"])), 5)
)
examples = []
for row in random_test_samples:
    # Two examples per row: one without a target answer (the model chooses
    # one) and one pinned to the dataset's correct answer.
    examples.append([row['support'], ""])
    examples.append([row['support'], row['correct_answer']])
def predict(context = "", answer = ""):
    """Generate a question plus three distractors for a science passage.

    Args:
        context: Supporting passage the question should be based on.
        answer: Optional correct answer; when non-empty it is appended to the
            prompt so the model writes a question targeting that answer.

    Returns:
        A 6-tuple (context, answer, question, distractor 1-3) parsed from the
        model output, or ``("ERROR: <raw output>", None, None, None, None,
        None)`` when the output does not have the expected 6-line shape.
    """
    # Newlines separate fields in the fine-tuned prompt format, so flatten
    # them out of the user-supplied text. (The original built `formatted`
    # twice — once unconditionally, once again when an answer was given.)
    parts = [context.replace('\n', ' ')]
    if answer != "":
        parts.append(answer.replace('\n', ' '))
    formatted = "\n".join(parts) + "\n"
    inputs = tokenizer(formatted, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=100)
    # Decoded output echoes the prompt, so a well-formed completion has
    # exactly six newline-separated fields in total.
    decoded_outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)
    split_outputs = decoded_outputs.split("\n")
    if len(split_outputs) == 6:
        return tuple(split_outputs)
    # Malformed output: surface the raw text in the first output slot.
    return ("ERROR: " + decoded_outputs, None, None, None, None, None)
# Input widgets, pre-filled with a small demo example.
support_gr = gr.TextArea(label="Context", value="Bananas are yellow and curved.")
answer_gr = gr.Text(label="Answer (optional)", value="yellow")

# Output widgets — one per field of the 6-tuple returned by predict().
context_output_gr = gr.Text(label="Context")
answer_output_gr = gr.Text(label="Answer")
question_output_gr = gr.Text(label="Question")
distractor1_output_gr = gr.Text(label="Distractor 1")
distractor2_output_gr = gr.Text(label="Distractor 2")
distractor3_output_gr = gr.Text(label="Distractor 3")
# Wire inputs/outputs to predict() and start the app server.
demo = gr.Interface(
    fn=predict,
    inputs=[support_gr, answer_gr],
    outputs=[
        context_output_gr,
        answer_output_gr,
        question_output_gr,
        distractor1_output_gr,
        distractor2_output_gr,
        distractor3_output_gr,
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch()