import gradio as gr import torch from transformers import AutoTokenizer, AutoModelForCausalLM titulo = """# 🤖 Bienvenido al Chatbot con Yi-9B""" descripcion = """Este chatbot utiliza el modelo Yi de 9B parámetros para generar respuestas. Puedes mantener una conversación fluida y realizar preguntas sobre diversos temas.""" # Definir el dispositivo y la ruta del modelo dispositivo = "cuda" if torch.cuda.is_available() else "cpu" ruta_modelo = "01-ai/Yi-9B-Chat" # Cargar el tokenizador y el modelo tokenizador = AutoTokenizer.from_pretrained(ruta_modelo) modelo = AutoModelForCausalLM.from_pretrained(ruta_modelo, device_map="auto").eval() def generar_respuesta(historial, usuario_input, max_longitud): mensajes = [ {"role": "system", "content": "Eres un asistente útil y amigable. Proporciona respuestas claras y concisas."} ] for entrada in historial: mensajes.append({"role": "user", "content": entrada[0]}) mensajes.append({"role": "assistant", "content": entrada[1]}) mensajes.append({"role": "user", "content": usuario_input}) texto = tokenizador.apply_chat_template( mensajes, tokenize=False, add_generation_prompt=True ) entradas_modelo = tokenizador([texto], return_tensors="pt").to(dispositivo) ids_generados = modelo.generate( entradas_modelo.input_ids, max_new_tokens=max_longitud, eos_token_id=tokenizador.eos_token_id ) ids_generados = [ output_ids[len(input_ids):] for input_ids, output_ids in zip(entradas_modelo.input_ids, ids_generados) ] respuesta = tokenizador.batch_decode(ids_generados, skip_special_tokens=True)[0] historial.append((usuario_input, respuesta)) return historial, "" def interfaz_gradio(): with gr.Blocks() as interfaz: gr.Markdown(titulo) gr.Markdown(descripcion) chatbot = gr.Chatbot(label="Historial de chat") msg = gr.Textbox(label="Tu mensaje") clear = gr.Button("Limpiar") max_longitud_slider = gr.Slider(minimum=1, maximum=1000, value=500, label="Longitud máxima de la respuesta") msg.submit(generar_respuesta, [chatbot, msg, max_longitud_slider], [chatbot, msg]) clear.click(lambda: None, None, chatbot, queue=False) return interfaz if __name__ == "__main__": interfaz = interfaz_gradio() interfaz.queue() interfaz.launch()