import os
import gradio as gr
import torch
import torch._dynamo
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import spaces
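# `spaces` provides the @spaces.GPU decorator (applied to `generate` below) so the
# function can request GPU time when running on Hugging Face Spaces.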
# Disable TorchDynamo to avoid compilation errors
torch._dynamo.config.suppress_errors = True
torch._dynamo.disable()
# Configuration
MODEL_ID = "somosnlp-hackathon-2025/iberotales-gemma-3-1b-it-es"
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 2048
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2048"))
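# Prompts longer than MAX_INPUT_TOKEN_LENGTH are truncated in generate() below.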
# Custom system prompt (kept in Spanish; it enforces the structured <think>/<SOLUTION> format)
DEFAULT_SYSTEM_MESSAGE = """Resuelve el siguiente problema.
Primero, piensa en voz alta qué debes hacer, paso por paso y de forma resumida, entre <think> y </think>.
Luego, da la respuesta final entre <SOLUTION> y </SOLUTION>.
No escribas nada fuera de ese formato."""
# Global variables
model = None
tokenizer = None
def load_model():
    """Load the model and tokenizer."""
    global model, tokenizer
    if torch.cuda.is_available():
        print(f"Loading model: {MODEL_ID}")
        try:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float32,
                device_map="auto",
                trust_remote_code=True,
            )
            # Fall back to the EOS token if the tokenizer defines no pad token
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            print("Model loaded successfully!")
            return True
        except Exception as e:
            print(f"Error loading model: {e}")
            return False
    else:
        print("CUDA not available")
        return False
# Load the model at startup
model_loaded = load_model()
@spaces.GPU
def generate(
    message: str,
    history: list,
    system_message: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
):
    """Generate a story with streaming output."""
    global model, tokenizer
    if model is None or tokenizer is None:
        yield "Error: Modelo no disponible. Por favor, reinicia la aplicación."
        return
    # Build the conversation in chat-message format
    conversation = []
    if system_message:
        conversation.append({"role": "system", "content": system_message})
    for msg in history:
        if isinstance(msg, dict) and "role" in msg and "content" in msg:
            conversation.append(msg)
    conversation.append({"role": "user", "content": message})
    try:
        input_ids = tokenizer.apply_chat_template(
            conversation,
            return_tensors="pt",
            add_generation_prompt=True,
        )
        # Keep only the most recent tokens if the prompt is too long
        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
            gr.Warning(f"Conversación recortada a {MAX_INPUT_TOKEN_LENGTH} tokens.")
        input_ids = input_ids.to(model.device)
        attention_mask = torch.ones_like(input_ids, device=model.device)
        streamer = TextIteratorStreamer(
            tokenizer,
            timeout=30.0,
            skip_prompt=True,
            skip_special_tokens=True
        )
        generate_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "do_sample": True,
            "top_p": top_p,
            "top_k": top_k,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }
        # Run generation in a background thread and stream partial text as it arrives
        generation_thread = Thread(target=model.generate, kwargs=generate_kwargs)
        generation_thread.start()
        outputs = []
        try:
            for new_text in streamer:
                outputs.append(new_text)
                yield "".join(outputs)
        except Exception as e:
            yield f"Error durante la generación: {str(e)}"
        finally:
            generation_thread.join(timeout=1)
    except Exception as e:
        yield f"Error: {str(e)}"
# Create the chat interface (user-facing labels stay in Spanish)
demo = gr.ChatInterface(
    fn=generate,
    title="Iberotales: Mitos y Leyendas Iberoamericanas",
    description="Genera historias y personajes basados en el patrimonio cultural de Iberoamérica usando GRPO.",
    chatbot=gr.Chatbot(
        height=600,
        show_copy_button=True,
    ),
    textbox=gr.Textbox(
        placeholder="Escribe una historia o personaje que quieras generar...",
        scale=7
    ),
    additional_inputs=[
        gr.Textbox(
            value=DEFAULT_SYSTEM_MESSAGE,
            label="Mensaje del sistema (formato estructurado requerido)"
        ),
        gr.Slider(
            label="Máximo de tokens",
            minimum=100,
            maximum=MAX_MAX_NEW_TOKENS,
            step=50,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperatura",
            minimum=0.1,
            maximum=2.0,
            step=0.1,
            value=0.7,
        ),
        gr.Slider(
            label="Top-p",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=100,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Penalización por repetición",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    examples=[
        ["Crea una historia corta sobre el Pombero, un personaje de la mitología guaraní."],
        ["Genera un personaje basado en la leyenda del Cadejo."],
        ["Inventa una narrativa en torno al Nahual en un entorno contemporáneo."],
    ],
    cache_examples=False,
)
if __name__ == "__main__":
    if model_loaded:
        print("Launching Gradio app...")
        demo.launch(
            share=False,
            show_error=True
        )
    else:
        print("Error loading the model. The application cannot start.")