k8o1 / app.py
robiro's picture
Update app.py
8227d8e verified
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
# --- Konfiguration ---
MODEL_ID = "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B"
HF_TOKEN = os.getenv("HF_TOKEN") # Optional: Für private Modelle oder Zugriffsbeschränkungen
# --- Lade Modell und Tokenizer (explizit auf CPU) ---
print(f"Lade Tokenizer: {MODEL_ID}")
# Stelle sicher, dass trust_remote_code=True gesetzt ist, da Qwen3 dies oft benötigt
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
if tokenizer.pad_token is None:
print("pad_token nicht gesetzt, verwende eos_token als pad_token.")
tokenizer.pad_token = tokenizer.eos_token
print(f"Lade Modell: {MODEL_ID} auf CPU. Dies kann einige Zeit dauern...")
try:
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="cpu",
trust_remote_code=True,
token=HF_TOKEN
)
except Exception as e:
print(f"Fehler beim Laden mit bfloat16 ({e}), versuche float32...")
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.float32,
device_map="cpu",
trust_remote_code=True,
token=HF_TOKEN
)
model.eval()
print("Modell und Tokenizer erfolgreich geladen.")
# --- Vorhersagefunktion für das ChatInterface ---
def predict(message, history):
messages_for_template = []
for user_msg, ai_msg in history: # history ist jetzt eine Liste von Listen/Tupeln
messages_for_template.append({"role": "user", "content": user_msg})
messages_for_template.append({"role": "assistant", "content": ai_msg})
messages_for_template.append({"role": "user", "content": message})
try:
prompt = tokenizer.apply_chat_template(
messages_for_template,
tokenize=False,
add_generation_prompt=True
)
except Exception as e:
print(f"Fehler beim Anwenden des Chat-Templates: {e}")
prompt_parts = []
for turn in messages_for_template:
prompt_parts.append(f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>")
prompt = "\n".join(prompt_parts) + "\n<|im_start|>assistant\n"
inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to("cpu")
generation_kwargs = {
"max_new_tokens": 512,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 50,
"do_sample": True,
"pad_token_id": tokenizer.eos_token_id,
}
print("Generiere Antwort...")
with torch.no_grad():
outputs = model.generate(**inputs, **generation_kwargs)
response_ids = outputs[0][inputs.input_ids.shape[-1]:]
response = tokenizer.decode(response_ids, skip_special_tokens=True)
print(f"Antwort: {response}")
return response
# --- Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft(), title="DeepSeek Qwen3 8B (CPU)") as demo:
gr.Markdown(
"""
# DeepSeek Qwen3 8B Chat (CPU)
Dies ist eine Demo des `deepseek-ai/DeepSeek-R1-0528-Qwen3-8B` Modells, das auf einer CPU läuft.
**Achtung:** Antworten können aufgrund der CPU-Inferenz **sehr langsam** sein (mehrere Minuten pro Antwort sind möglich).
Bitte habe Geduld.
"""
)
chatbot_interface = gr.ChatInterface(
fn=predict,
chatbot=gr.Chatbot(
height=600,
label="Chat",
show_label=False,
# bubble_full_width=False, # Entfernt, da veraltet
# type="messages" # Wichtig, um die Warnung zu beheben, aber history-Format in predict() muss passen
# Da predict bereits die history als [[user, ai], [user, ai]] erwartet (Standard für ChatInterface),
# lassen wir type hier weg, damit es mit dem Format von predict harmoniert.
# Wenn predict `history` als [{"role": "user", ...}, {"role": "assistant", ...}] erwarten würde,
# dann wäre `type="messages"` hier richtig.
# Da die Warnung sich auf die Standardeinstellung bezieht, die bald "messages" sein wird,
# und unsere predict-Funktion bereits das "tuples"-Format verarbeitet, ist das OK für jetzt.
# Man könnte predict anpassen, um das "messages" Format direkt zu verarbeiten, wenn man type="messages" setzt.
),
textbox=gr.Textbox(
placeholder="Stelle mir eine Frage...",
container=False,
scale=7
),
examples=[
["Hallo, wer bist du?"],
["Was ist die Hauptstadt von Frankreich?"],
["Schreibe ein kurzes Gedicht über KI."]
],
# Entferne die nicht unterstützten Button-Argumente:
# retry_btn="Wiederholen",
# undo_btn="Letzte entfernen",
# clear_btn="Chat löschen",
)
gr.Markdown("Modell von [deepseek-ai](https://huggingface.co/deepseek-ai) auf Hugging Face.")
if __name__ == "__main__":
demo.launch()