Murillex commited on
Commit
a3a2558
·
verified ·
1 Parent(s): fd9296a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -35
app.py CHANGED
@@ -1,26 +1,26 @@
1
- import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
 
 
 
4
 
5
- # Modelos A, B e Árbitro (modelo C)
6
- model_a_name = "google/flan-t5-base"
7
  model_b_name = "google/mt5-small"
8
- judge_model_name = "google/flan-t5-base"
9
 
10
- # Carrega modelos e tokenizers
11
  tokenizer_a = AutoTokenizer.from_pretrained(model_a_name)
12
- model_a = AutoModelForSeq2SeqLM.from_pretrained(model_a_name)
13
-
14
- from transformers import MT5Tokenizer
15
  tokenizer_b = MT5Tokenizer.from_pretrained(model_b_name, use_fast=False)
16
- model_b = AutoModelForSeq2SeqLM.from_pretrained(model_b_name)
17
 
18
- tokenizer_j = AutoTokenizer.from_pretrained(judge_model_name)
19
- model_j = AutoModelForSeq2SeqLM.from_pretrained(judge_model_name)
 
 
20
 
21
  def gerar_resposta(model, tokenizer, prompt):
22
  prompt_instruido = f"Question: {prompt}\nAnswer:"
23
- inputs = tokenizer(prompt_instruido, return_tensors="pt")
24
  with torch.no_grad():
25
  outputs = model.generate(
26
  **inputs,
@@ -33,12 +33,9 @@ def gerar_resposta(model, tokenizer, prompt):
33
  return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
34
 
35
  def limpar_resposta(resposta):
36
- if "<extra_id" in resposta.lower():
37
  return ""
38
- return resposta
39
-
40
- resp_a = limpar_resposta(gerar_resposta(model_a, tokenizer_a, prompt))
41
- resp_b = limpar_resposta(gerar_resposta(model_b, tokenizer_b, prompt))
42
 
43
  def julgar_respostas(prompt, resp_a, resp_b):
44
  prompt_julgamento = (
@@ -52,26 +49,38 @@ def julgar_respostas(prompt, resp_a, resp_b):
52
  return "B"
53
  return "A"
54
 
55
- def processar(prompt):
 
56
  resp_a = gerar_resposta(model_a, tokenizer_a, prompt)
57
  resp_b = gerar_resposta(model_b, tokenizer_b, prompt)
 
 
 
 
 
 
 
 
 
58
  melhor = julgar_respostas(prompt, resp_a, resp_b)
 
59
 
60
- final = resp_a if melhor == "A" else resp_b
61
- saida = (
62
- f"🟡 Prompt: {prompt}\n\n"
63
- f"🔹 Resposta A (Flan-T5):\n{resp_a}\n\n"
64
- f"🔸 Resposta B (mT5):\n{resp_b}\n\n"
65
- f"✅ Melhor resposta: Modelo {melhor} selecionado pelo árbitro.\n\n"
66
- f"💬 Resposta final:\n{final}"
67
- )
68
- return saida
 
 
 
 
69
 
70
  # Interface Gradio
71
- gr.Interface(
72
- fn=processar,
73
- inputs=gr.Textbox(lines=2, placeholder="Digite sua pergunta...", label="Prompt"),
74
- outputs=gr.Textbox(label="Resposta Final"),
75
- title="Chatbot em Cascata com Árbitro",
76
- description="Dois modelos geram respostas e um terceiro escolhe a melhor. Rodando localmente sem API externa."
77
- ).launch()
 
 
 
1
  import torch
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
+ from transformers import MT5Tokenizer, MT5ForConditionalGeneration
4
+ import gradio as gr
5
 
6
+ # Modelos
7
+ model_a_name = "google/flan-t5-small"
8
  model_b_name = "google/mt5-small"
9
+ model_j_name = "google/flan-t5-small" # Árbitro
10
 
11
+ # Tokenizers
12
  tokenizer_a = AutoTokenizer.from_pretrained(model_a_name)
 
 
 
13
  tokenizer_b = MT5Tokenizer.from_pretrained(model_b_name, use_fast=False)
14
+ tokenizer_j = AutoTokenizer.from_pretrained(model_j_name)
15
 
16
+ # Modelos carregados
17
+ model_a = AutoModelForSeq2SeqLM.from_pretrained(model_a_name)
18
+ model_b = MT5ForConditionalGeneration.from_pretrained(model_b_name)
19
+ model_j = AutoModelForSeq2SeqLM.from_pretrained(model_j_name)
20
 
21
  def gerar_resposta(model, tokenizer, prompt):
22
  prompt_instruido = f"Question: {prompt}\nAnswer:"
23
+ inputs = tokenizer(prompt_instruido, return_tensors="pt", padding=True, truncation=True)
24
  with torch.no_grad():
25
  outputs = model.generate(
26
  **inputs,
 
33
  return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
34
 
35
  def limpar_resposta(resposta):
36
+ if "<extra_id" in resposta.lower() or resposta.strip() == "":
37
  return ""
38
+ return resposta.strip()
 
 
 
39
 
40
  def julgar_respostas(prompt, resp_a, resp_b):
41
  prompt_julgamento = (
 
49
  return "B"
50
  return "A"
51
 
52
+ def responder(prompt):
53
+ # Geração
54
  resp_a = gerar_resposta(model_a, tokenizer_a, prompt)
55
  resp_b = gerar_resposta(model_b, tokenizer_b, prompt)
56
+
57
+ # Limpeza
58
+ resp_a = limpar_resposta(resp_a)
59
+ resp_b = limpar_resposta(resp_b)
60
+
61
+ # Se ambas estão vazias
62
+ if not resp_a and not resp_b:
63
+ return "⚠️ Nenhuma resposta válida foi gerada."
64
+
65
  melhor = julgar_respostas(prompt, resp_a, resp_b)
66
+ resposta_final = resp_a if melhor == "A" else resp_b
67
 
68
+ return f"""🟡 Prompt: {prompt}
69
+
70
+ 🔹 Resposta A (Flan-T5):
71
+ {resp_a or '[Resposta inválida]'}
72
+
73
+ 🔸 Resposta B (mT5):
74
+ {resp_b or '[Resposta inválida]'}
75
+
76
+ Melhor resposta: Modelo {melhor} selecionado pelo árbitro.
77
+
78
+ 💬 Resposta final:
79
+ {resposta_final or '[Nenhuma resposta válida]'}
80
+ """
81
 
82
  # Interface Gradio
83
+ iface = gr.Interface(fn=responder, inputs="text", outputs="text", title="Chatbot em Cascata com Árbitro")
84
+
85
+ if __name__ == "__main__":
86
+ iface.launch()