import gradio as gr
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
import torch

# ← swap in a real QA model; this is an extractive QA checkpoint fine-tuned on SQuAD 2.0
MODEL_NAME = "deepset/roberta-base-squad2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
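# A minimal alternative, assuming the standard transformers pipeline API:
# the "question-answering" pipeline wraps the same tokenize -> forward ->
# span-decode steps performed manually in answer_question() below.
#   from transformers import pipeline
#   qa = pipeline("question-answering", model=MODEL_NAME)
#   qa(question="Where is the Eiffel Tower?", context="The Eiffel Tower is in Paris.")
#   # -> {"score": ..., "start": ..., "end": ..., "answer": "Paris"}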
def answer_question(context, question):
    # Tokenize the (question, context) pair; contexts longer than 512 tokens
    # are truncated. Passing stride/return_overflowing_tokens here would add
    # an overflow_to_sample_mapping entry that model(**inputs) cannot accept,
    # so this sketch keeps a single window.
    inputs = tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():
        outputs = model(**inputs)
    # argmax over the raw logits selects the same indices as argmax over
    # softmax probabilities, so no normalization is needed. Note this simple
    # decoding neither enforces start <= end nor restricts the span to the
    # context tokens.
    start_idx = torch.argmax(outputs.start_logits)
    end_idx = torch.argmax(outputs.end_logits) + 1
    answer = tokenizer.decode(
        inputs["input_ids"][0][start_idx:end_idx],
        skip_special_tokens=True
    )
    return answer or "No answer found."
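# Quick sanity check outside the UI (illustrative only; the exact span
# depends on the checkpoint weights):
#   answer_question("The Eiffel Tower is located in Paris, France.",
#                   "Where is the Eiffel Tower?")
#   # expected: a span such as "Paris, France"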
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 RAFT: Retrieval-Augmented Fine-Tuning for QA")
    gr.Markdown("Ask a question based on the provided context…")
    with gr.Row():
        context_input = gr.Textbox(lines=5, label="Context")
        question_input = gr.Textbox(lines=2, label="Question")
    answer_output = gr.Textbox(label="Answer", interactive=False)
    gr.Button("Generate Answer").click(
        answer_question,
        inputs=[context_input, question_input],
        outputs=answer_output
    )

demo.launch()
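# To run this Space locally (assumed dependencies, not pinned here):
#   pip install gradio transformers torch
#   python app.py
# demo.launch() serves the UI at http://127.0.0.1:7860 by default.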