# Hugging Face Space: Gemma-2b fine-tuned on SciQ (question + distractor generation demo)
import gradio as gr
import os
import random

from datasets import load_dataset
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

# --- UI metadata -----------------------------------------------------------
title = "Gemma-2b SciQ"
description = """
Gemma-2b fine-tuned on SciQ
"""
article = "GitHub repository: https://github.com/P-Zande/nlp-team-4"

# --- Model -----------------------------------------------------------------
# Gemma-2b is a gated model, so HF_TOKEN must be set in the Space secrets.
model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))
base_model = AutoModelForCausalLM.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))

# The LoRA adapter weights live in the Space root ("./"); merge them into the
# base model so inference runs on a single plain model.
model = PeftModel.from_pretrained(base_model, "./")
model = model.merge_and_unload()

# --- Example inputs --------------------------------------------------------
dataset = load_dataset("allenai/sciq")

# Pick 5 genuinely random test rows. (`random` was imported and the variable
# name promises randomness, but `select(range(5))` only ever showed the
# first five rows of the test split.)
random_indices = random.sample(range(len(dataset["test"])), 5)
random_test_samples = dataset["test"].select(random_indices)

# Each sampled row yields two UI examples: context alone (model must also
# pick an answer) and context plus the gold answer.
examples = []
for row in random_test_samples:
    examples.append([row["support"], ""])
    examples.append([row["support"], row["correct_answer"]])
def predict(context="", answer=""):
    """Generate a question and three distractors for *context*.

    The fine-tuned model expects a prompt of newline-separated fields:
    the context on the first line and, optionally, the target answer on
    the second. The model is expected to complete the prompt so that the
    decoded text contains six newline-separated fields: context, answer,
    question, and three distractors.

    Returns:
        A 6-tuple ``(context, answer, question, d1, d2, d3)`` of strings.
        If the model output does not split into exactly six fields, the
        first element is ``"ERROR: "`` plus the raw decoded output and
        the remaining five elements are ``None``.
    """
    # Newlines are the field separators of the prompt format, so any
    # newlines inside the user-supplied text must be flattened to spaces.
    parts = [context.replace("\n", " ")]
    if answer:  # the answer field is optional
        parts.append(answer.replace("\n", " "))
    prompt = "\n".join(parts) + "\n"

    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=100)
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)

    fields = decoded.split("\n")
    if len(fields) == 6:
        return tuple(fields)
    # Model did not produce the expected six newline-separated fields.
    return ("ERROR: " + decoded, None, None, None, None, None)
# Input widgets: a context passage plus an optional answer to condition on.
support_gr = gr.TextArea(label="Context", value="Bananas are yellow and curved.")
answer_gr = gr.Text(label="Answer (optional)", value="yellow")

# Output widgets mirror the six newline-separated fields that predict()
# returns: context, answer, question, and three distractors.
output_labels = [
    "Context",
    "Answer",
    "Question",
    "Distractor 1",
    "Distractor 2",
    "Distractor 3",
]
output_components = [gr.Text(label=label) for label in output_labels]

gr.Interface(
    fn=predict,
    inputs=[support_gr, answer_gr],
    outputs=output_components,
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch()