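"""Gradio demo for Gemma-2b fine-tuned on SciQ.

Given a context passage and an optional target answer, the model generates
a question and three distractor answers.
"""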
import os

import gradio as gr
from datasets import load_dataset
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

title = "Gemma-2b SciQ"
description = """
Gemma-2b fine-tuned on the SciQ dataset: given a context passage and an
optional answer, it generates a question and three distractors.
"""
article = "GitHub repository: https://github.com/P-Zande/nlp-team-4"
model_id = "google/gemma-2b"
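
# Gemma is a gated model: access requires an HF_TOKEN set in the environment
# (e.g. as a Space secret).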
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))
base_model = AutoModelForCausalLM.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))
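
# Load the fine-tuned PEFT adapter from the repository root and merge its
# weights into the base model so inference runs on a plain transformers model.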
model = PeftModel.from_pretrained(base_model, "./")
model = model.merge_and_unload()
dataset = load_dataset("allenai/sciq")
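
# Build example inputs from the first five test rows: each SciQ "support"
# passage appears once without an answer and once with its correct answer.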
test_samples = dataset["test"].select(range(5))
examples = []
for row in test_samples:
    examples.append([row["support"], ""])
    examples.append([row["support"], row["correct_answer"]])


def predict(context="", answer=""):
    """Generate a question and three distractors for `context`.

    The prompt is the context on one line, optionally followed by the target
    answer on a second line; the model is expected to complete the remaining
    lines (question and three distractors).
    """
    formatted = context.replace("\n", " ") + "\n"
    if answer != "":
        formatted += answer.replace("\n", " ") + "\n"
    inputs = tokenizer(formatted, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=100)
    decoded_outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # A well-formed completion has exactly six lines:
    # context, answer, question, and three distractors.
    split_outputs = decoded_outputs.split("\n")
    if len(split_outputs) == 6:
        return tuple(split_outputs)
    return ("ERROR: " + decoded_outputs, None, None, None, None, None)
support_gr = gr.TextArea(
    label="Context",
    value="Bananas are yellow and curved.",
)
answer_gr = gr.Text(
    label="Answer (optional)",
    value="yellow",
)
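
# Output components mirror the six lines the model is expected to produce.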
context_output_gr = gr.Text(label="Context")
answer_output_gr = gr.Text(label="Answer")
question_output_gr = gr.Text(label="Question")
distractor1_output_gr = gr.Text(label="Distractor 1")
distractor2_output_gr = gr.Text(label="Distractor 2")
distractor3_output_gr = gr.Text(label="Distractor 3")
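
# Wire inputs, outputs, and examples into a single Gradio interface.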
gr.Interface(
    fn=predict,
    inputs=[support_gr, answer_gr],
    outputs=[
        context_output_gr,
        answer_output_gr,
        question_output_gr,
        distractor1_output_gr,
        distractor2_output_gr,
        distractor3_output_gr,
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch()