S-Dreamer committed
Commit 7b5db99 · verified · 1 Parent(s): 4b144ed

Update app.py

Files changed (1)
  1. app.py +27 -21
app.py CHANGED
@@ -3,39 +3,45 @@ from transformers import AutoModelForQuestionAnswering, AutoTokenizer
 import torch
 import torch.nn.functional as F
 
-# Load model and tokenizer
-MODEL_NAME = "S-Dreamer/raft-qa-space"
+# ←–– swap in a real QA model
+MODEL_NAME = "deepset/roberta-base-squad2"
+
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
+model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
 
 def answer_question(context, question):
     inputs = tokenizer(
-        question, context, return_tensors="pt", truncation=True, max_length=512, stride=128, return_overflowing_tokens=True
+        question,
+        context,
+        return_tensors="pt",
+        truncation=True,
+        max_length=512,
+        stride=128,
+        return_overflowing_tokens=True
     )
     with torch.no_grad():
         outputs = model(**inputs)
-
-    start_probs = F.softmax(outputs.start_logits, dim=-1)
-    end_probs = F.softmax(outputs.end_logits, dim=-1)
-    start_idx = torch.argmax(start_probs)
-    end_idx = torch.argmax(end_probs) + 1
 
-    answer = tokenizer.decode(inputs["input_ids"][0][start_idx:end_idx], skip_special_tokens=True)
-
-    return answer if answer.strip() else "No answer found."
+    start_idx = torch.argmax(F.softmax(outputs.start_logits, dim=-1))
+    end_idx = torch.argmax(F.softmax(outputs.end_logits, dim=-1)) + 1
+
+    answer = tokenizer.decode(
+        inputs["input_ids"][0][start_idx:end_idx],
+        skip_special_tokens=True
+    )
+    return answer or "No answer found."
 
-# Define UI
 with gr.Blocks() as demo:
     gr.Markdown("# 🤖 RAFT: Retrieval-Augmented Fine-Tuning for QA")
-    gr.Markdown("Ask a question based on the provided context and see how RAFT improves response accuracy!")
-
+    gr.Markdown("Ask a question based on the provided context")
     with gr.Row():
-        context_input = gr.Textbox(lines=5, label="Context", placeholder="Enter background text here...")
-        question_input = gr.Textbox(lines=2, label="Question", placeholder="What is the main idea?")
-
+        context_input = gr.Textbox(lines=5, label="Context")
+        question_input = gr.Textbox(lines=2, label="Question")
     answer_output = gr.Textbox(label="Answer", interactive=False)
-
-    submit_btn = gr.Button("Generate Answer")
-    submit_btn.click(answer_question, inputs=[context_input, question_input], outputs=answer_output)
+    gr.Button("Generate Answer").click(
+        answer_question,
+        inputs=[context_input, question_input],
+        outputs=answer_output
+    )
 
 demo.launch()
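
A note on the new tokenizer call: `return_overflowing_tokens=True` splits a long context into overlapping chunks, but the committed code still decodes only chunk 0, `torch.argmax` over the stacked logits returns a flattened index once more than one chunk comes back, and fast tokenizers also attach an `overflow_to_sample_mapping` entry that the model's forward pass does not accept. Below is a minimal sketch (not part of this commit) of one way to score every chunk; it assumes the same deepset/roberta-base-squad2 checkpoint and switches to truncation="only_second" so only the context, never the question, is chunked:

import torch
import torch.nn.functional as F
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

MODEL_NAME = "deepset/roberta-base-squad2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)

def answer_question(context, question):
    inputs = tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation="only_second",  # chunk the context, keep the question intact
        max_length=512,
        stride=128,
        return_overflowing_tokens=True,
    )
    # Tokenizer bookkeeping, not a model input: drop it before the forward pass.
    inputs.pop("overflow_to_sample_mapping", None)
    with torch.no_grad():
        outputs = model(**inputs)

    # Keep the (chunk, start, end) triple with the highest joint probability.
    best_score, best = float("-inf"), None
    for i in range(inputs["input_ids"].size(0)):
        start_probs = F.softmax(outputs.start_logits[i], dim=-1)
        end_probs = F.softmax(outputs.end_logits[i], dim=-1)
        start = int(torch.argmax(start_probs))
        end = int(torch.argmax(end_probs))
        score = float(start_probs[start] * end_probs[end])
        if end >= start and score > best_score:
            best_score, best = score, (i, start, end)

    if best is None:
        return "No answer found."
    i, start, end = best
    answer = tokenizer.decode(
        inputs["input_ids"][i][start : end + 1], skip_special_tokens=True
    )
    return answer.strip() or "No answer found."

For anything beyond a demo, transformers' built-in question-answering pipeline already implements this chunked span scoring (doc stride, overflow handling, answer ranking), so `pipeline("question-answering", model=MODEL_NAME)` is likely the simpler route.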