Michel-25's picture
Update app.py
0c4b567 verified
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import warnings
warnings.filterwarnings("ignore")
# Cache global para modelos
model_cache = {}
def load_model(model_name):
"""Carrega modelo com cache para evitar recarregamentos"""
if model_name in model_cache:
return model_cache[model_name]
try:
if "blenderbot" in model_name.lower():
# Para modelos de conversação
pipe = pipeline(
"conversational",
model=model_name,
tokenizer=model_name,
device=-1, # CPU
max_length=150,
truncation=True
)
else:
# Para modelos de geração de texto
pipe = pipeline(
"text-generation",
model=model_name,
tokenizer=model_name,
device=-1, # CPU
max_length=100,
truncation=True,
pad_token_id=50256
)
model_cache[model_name] = pipe
return pipe
except Exception as e:
return None
def chat_with_model(message, history, model_name):
"""Conversa com o modelo selecionado"""
try:
# Carregar modelo
pipe = load_model(model_name)
if pipe is None:
return "❌ Erro ao carregar modelo. Tente outro modelo."
# Gerar resposta
if "blenderbot" in model_name.lower():
# Conversational pipeline
from transformers import Conversation
conversation = Conversation(message)
result = pipe(conversation)
response = result.generated_responses[-1]
else:
# Text generation pipeline
result = pipe(
message,
max_new_tokens=50,
temperature=0.7,
do_sample=True,
pad_token_id=50256,
eos_token_id=50256
)
generated_text = result[0]['generated_text']
# Remover o input original da resposta
if generated_text.startswith(message):
response = generated_text[len(message):].strip()
else:
response = generated_text.strip()
# Se resposta vazia, usar parte da geração
if not response:
response = generated_text[-100:].strip()
return response if response else "🤔 Modelo não gerou uma resposta clara."
except Exception as e:
return f"🚨 Erro: {str(e)[:100]}... Tente um modelo menor."
# Modelos otimizados para Spaces (menores e mais rápidos)
MODELOS_DISPONIVEIS = [
"distilgpt2",
"gpt2",
"microsoft/DialoGPT-small",
"facebook/blenderbot_small-90M"
]
def responder(message, history, modelo_selecionado):
if not message.strip():
return history, ""
# Adicionar mensagem do usuário
history.append([message, "🤖 Gerando resposta..."])
# Obter resposta
resposta = chat_with_model(message, history, modelo_selecionado)
# Atualizar com resposta real
history[-1][1] = resposta
return history, ""
def limpar_chat():
return []
def info_modelo(modelo):
info = {
"distilgpt2": "⚡ **DistilGPT-2** - Rápido e eficiente",
"gpt2": "📝 **GPT-2** - Geração de texto criativo",
"microsoft/DialoGPT-small": "💬 **DialoGPT** - Conversação natural",
"facebook/blenderbot_small-90M": "🤖 **BlenderBot** - Chat empático"
}
return info.get(modelo, "🤖 Modelo selecionado")
# Interface Gradio
with gr.Blocks(
title="🤖 Chat Multi-Modelo Local",
theme=gr.themes.Soft()
) as demo:
gr.Markdown("""
# 🤖 **Chat Multi-Modelo Local**
### ⚡ Modelos rodando diretamente no Hugging Face Spaces
""")
with gr.Row():
modelo = gr.Dropdown(
choices=MODELOS_DISPONIVEIS,
value="distilgpt2",
label="🎯 Escolha o Modelo",
info="Modelos otimizados para Spaces"
)
info_display = gr.Markdown(
info_modelo("distilgpt2")
)
chatbot = gr.Chatbot(
height=400,
label="💬 Conversa",
avatar_images=["👤", "🤖"]
)
with gr.Row():
msg = gr.Textbox(
placeholder="Digite sua mensagem...",
container=False,
scale=4
)
send = gr.Button("📤", scale=1, variant="primary")
clear = gr.Button("🗑️ Limpar Chat")
# Exemplos
gr.Examples(
examples=[
"Olá! Como você está?",
"Me conte uma piada",
"Qual é a capital do Brasil?",
"Escreva um poema curto"
],
inputs=msg
)
# Eventos
send.click(responder, [msg, chatbot, modelo], [chatbot, msg])
msg.submit(responder, [msg, chatbot, modelo], [chatbot, msg])
clear.click(limpar_chat, outputs=chatbot)
modelo.change(info_modelo, inputs=modelo, outputs=info_display)
gr.Markdown("""
---
### ℹ️ **Informações:**
- 🏃‍♂️ **Modelos locais** - Sem dependência de APIs externas
- ⚡ **Otimizados** - Para funcionar bem no Spaces
- 🔄 **Primeiro uso** - Pode demorar para carregar modelo
""")
if __name__ == "__main__":
demo.launch()