raft-qa-space / app.py
S-Dreamer's picture
Update app.py
7b5db99 verified
raw
history blame
1.49 kB
import gradio as gr
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
import torch
import torch.nn.functional as F
# Hugging Face Hub checkpoint used for extractive question answering.
# Point MODEL_NAME at another question-answering checkpoint to swap models;
# the tokenizer and model are loaded from the same name so they stay paired.
MODEL_NAME = "deepset/roberta-base-squad2"

tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_NAME),
    AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME),
)
def answer_question(context, question):
inputs = tokenizer(
question,
context,
return_tensors="pt",
truncation=True,
max_length=512,
stride=128,
return_overflowing_tokens=True
)
with torch.no_grad():
outputs = model(**inputs)
start_idx = torch.argmax(F.softmax(outputs.start_logits, dim=-1))
end_idx = torch.argmax(F.softmax(outputs.end_logits, dim=-1)) + 1
answer = tokenizer.decode(
inputs["input_ids"][0][start_idx:end_idx],
skip_special_tokens=True
)
return answer or "No answer found."
# Gradio front end: a context box and a question box side by side, a
# read-only answer box, and a button that feeds both inputs to
# answer_question and writes its result into the answer box.
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 RAFT: Retrieval-Augmented Fine-Tuning for QA")
    gr.Markdown("Ask a question based on the provided context…")

    with gr.Row():
        context_input = gr.Textbox(lines=5, label="Context")
        question_input = gr.Textbox(lines=2, label="Question")

    answer_output = gr.Textbox(label="Answer", interactive=False)

    generate_button = gr.Button("Generate Answer")
    generate_button.click(
        answer_question,
        inputs=[context_input, question_input],
        outputs=answer_output,
    )

demo.launch()