"""Gradio demo: Gemma-2b fine-tuned on SciQ to generate a question and three distractors."""

import os
import random

import gradio as gr
from datasets import load_dataset
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

title = "Gemma-2b SciQ"
description = """
Gemma-2b fine-tuned on SciQ
"""
article = "GitHub repository: https://github.com/P-Zande/nlp-team-4"

model_id = "google/gemma-2b"
# Read the gated-model token once instead of once per from_pretrained call.
hf_token = os.environ.get("HF_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
base_model = AutoModelForCausalLM.from_pretrained(model_id, token=hf_token)

# The LoRA adapter lives next to this script ("./"); merging it into the base
# model lets inference run on a single plain CausalLM.
model = PeftModel.from_pretrained(base_model, "./")
model = model.merge_and_unload()

dataset = load_dataset("allenai/sciq")

# Pick five genuinely random test rows. The previous `select(range(5))`
# always showed the first five rows and left the `random` import unused.
random_test_samples = dataset["test"].select(
    random.sample(range(len(dataset["test"])), 5)
)
examples = []
for row in random_test_samples:
    # For each row, offer the context both without an answer (the model must
    # pick one) and with the gold answer supplied.
    examples.append([row['support'], ""])
    examples.append([row['support'], row['correct_answer']])


def predict(context="", answer=""):
    """Generate a question and three distractors from a context and optional answer.

    The fine-tuned model expects newline-separated fields, so newlines inside
    a field are flattened to spaces before prompting.

    Returns a 6-tuple (context, answer, question, distractor1, distractor2,
    distractor3); if the model output does not parse into exactly six
    newline-separated fields, the first slot carries an error string and the
    rest are None.
    """
    # Build the prompt once: context, then the answer only when provided,
    # each field newline-terminated.
    parts = [context.replace('\n', ' ')]
    if answer != "":
        parts.append(answer.replace('\n', ' '))
    formatted = "\n".join(parts) + "\n"

    inputs = tokenizer(formatted, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=100)
    # Decoded text echoes the prompt fields followed by the generated ones.
    decoded_outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)

    split_outputs = decoded_outputs.split("\n")
    if len(split_outputs) == 6:
        return (
            split_outputs[0],
            split_outputs[1],
            split_outputs[2],
            split_outputs[3],
            split_outputs[4],
            split_outputs[5],
        )
    # Unexpected shape: surface the raw generation for debugging.
    return ("ERROR: " + decoded_outputs, None, None, None, None, None)


# --- Gradio UI wiring -------------------------------------------------------
support_gr = gr.TextArea(
    label="Context", value="Bananas are yellow and curved."
)
answer_gr = gr.Text(
    label="Answer (optional)", value="yellow"
)
context_output_gr = gr.Text(
    label="Context"
)
answer_output_gr = gr.Text(
    label="Answer"
)
question_output_gr = gr.Text(
    label="Question"
)
distractor1_output_gr = gr.Text(
    label="Distractor 1"
)
distractor2_output_gr = gr.Text(
    label="Distractor 2"
)
distractor3_output_gr = gr.Text(
    label="Distractor 3"
)

gr.Interface(
    fn=predict,
    inputs=[support_gr, answer_gr],
    outputs=[
        context_output_gr,
        answer_output_gr,
        question_output_gr,
        distractor1_output_gr,
        distractor2_output_gr,
        distractor3_output_gr,
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch()