Murillex committed on
Commit
d458b05
·
verified ·
1 Parent(s): b9c70bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -16
app.py CHANGED
@@ -1,27 +1,29 @@
1
  import gradio as gr
2
- from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
3
- import torch
4
 
5
- # Modelos de resposta A e B
6
- generator_a = pipeline("text-generation", model="tiiuae/falcon-rw-1b", tokenizer="tiiuae/falcon-rw-1b", device=-1)
7
- generator_b = pipeline("text2text-generation", model="google/flan-t5-base", tokenizer="google/flan-t5-base", device=-1)
8
 
9
- # Modelo árbitro
10
- tokenizer_c = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
11
- model_c = AutoModelForSequenceClassification.from_pretrained("facebook/bart-large-mnli")
12
- classifier = pipeline("text-classification", model=model_c, tokenizer=tokenizer_c)
13
 
14
  def judge(prompt, response_a, response_b):
15
- # Avaliar qual resposta tem mais "entailment" com o prompt
16
- result_a = classifier(f"{prompt} </s> {response_a}")[0]
17
- result_b = classifier(f"{prompt} </s> {response_b}")[0]
18
- return response_a if result_a['score'] > result_b['score'] else response_b
 
 
 
 
19
 
20
  def chat(prompt):
21
- response_a = generator_a(prompt, max_length=100, do_sample=True)[0]['generated_text']
22
- response_b = generator_b(prompt, max_length=100, do_sample=True)[0]['generated_text']
23
  best = judge(prompt, response_a, response_b)
24
  return f"Resposta A:\n{response_a}\n\nResposta B:\n{response_b}\n\n✅ Melhor Resposta:\n{best}"
25
 
26
- iface = gr.Interface(fn=chat, inputs=gr.Textbox(label="Seu prompt"), outputs=gr.Textbox(label="Resposta escolhida"), title="Chatbot em Cascata")
27
  iface.launch()
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
+ from sentence_transformers import SentenceTransformer, util
4
 
5
# Answer models A and B: two seq2seq generation pipelines.
_MODEL_A = "google/flan-t5-base"
_MODEL_B = "declare-lab/flan-alpaca-base"
generator_a = pipeline("text2text-generation", model=_MODEL_A, tokenizer=_MODEL_A)
generator_b = pipeline("text2text-generation", model=_MODEL_B, tokenizer=_MODEL_B)

# Arbiter: sentence-embedding model used to score semantic similarity.
similarity_model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
 
 
11
 
12
def judge(prompt, response_a, response_b):
    """Pick the candidate response most semantically similar to the prompt.

    Encodes the prompt and both candidate responses with the module-level
    ``similarity_model`` and compares cosine similarities against the prompt.

    Args:
        prompt: The user's input text.
        response_a: Candidate answer produced by generator A.
        response_b: Candidate answer produced by generator B.

    Returns:
        ``response_a`` if its similarity score is greater than or equal to
        ``response_b``'s (ties favor A, preserving the original ``>=``);
        otherwise ``response_b``.
    """
    # Encode all three texts in one batch: a single forward pass instead of
    # three separate encode() calls.
    emb_prompt, emb_a, emb_b = similarity_model.encode(
        [prompt, response_a, response_b], convert_to_tensor=True
    )

    # util.cos_sim is the current API; pytorch_cos_sim is a deprecated alias
    # for the same function.
    score_a = util.cos_sim(emb_prompt, emb_a).item()
    score_b = util.cos_sim(emb_prompt, emb_b).item()

    # NOTE(review): similarity-to-prompt rewards echoing the prompt rather
    # than answering it — acceptable here as a cheap heuristic arbiter.
    return response_a if score_a >= score_b else response_b
21
 
22
def chat(prompt):
    """Run the prompt through both generators and report the winner.

    Generates one answer from each module-level pipeline, lets ``judge``
    select the better one, and formats all three pieces into a single
    display string for the UI.
    """
    candidates = [
        gen(prompt, max_length=100)[0]['generated_text']
        for gen in (generator_a, generator_b)
    ]
    resp_a, resp_b = candidates
    winner = judge(prompt, resp_a, resp_b)
    return (
        f"Resposta A:\n{resp_a}\n\n"
        f"Resposta B:\n{resp_b}\n\n"
        f"✅ Melhor Resposta:\n{winner}"
    )
27
 
28
# Build and launch the Gradio UI: one text box in, one text box out.
iface = gr.Interface(
    fn=chat,
    inputs=gr.Textbox(label="Pergunta"),
    outputs=gr.Textbox(label="Resposta Escolhida"),
    title="Chatbot em Cascata",
)
iface.launch()