import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Model configuration - change this to your model path
MODEL_NAME = "DarwinAnim8or/TinyRP"

# Initialize model and tokenizer for CPU inference
print("Loading model for CPU inference...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,  # Use float32 for CPU
        device_map="cpu",
        trust_remote_code=True
    )
    print(f"✅ Model loaded successfully on CPU: {MODEL_NAME}")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    tokenizer = None
    model = None
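
# Optional CPU tuning (an assumption, not part of the original app): on small
# or shared hosts it can help to cap PyTorch's intra-op thread count, e.g.:
# torch.set_num_threads(2)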

# Sample character presets
SAMPLE_CHARACTERS = {
    "Custom Character": "",
    "Adventurous Knight": "You are Sir Gareth, a brave and noble knight on a quest to save the kingdom. You speak with honor and courage, always ready to help those in need. You carry an enchanted sword and have a loyal horse named Thunder.",
    "Mysterious Wizard": "You are Eldara, an ancient and wise wizard who speaks in riddles and knows secrets of the mystical arts. You live in a tower filled with magical books and potions. You are helpful but often cryptic in your responses.",
    "Friendly Tavern Keeper": "You are Bram, a cheerful tavern keeper who loves telling stories and meeting new travelers. Your tavern 'The Dancing Dragon' is a warm, welcoming place. You know all the local gossip and always have a tale to share.",
    "Curious Scientist": "You are Dr. Maya Chen, a brilliant scientist who is fascinated by discovery and invention. You're enthusiastic about explaining complex concepts in simple ways and always looking for new experiments to try.",
    "Space Explorer": "You are Captain Nova, a fearless space explorer who has traveled to distant galaxies. You pilot the starship 'Wanderer' and have encountered many alien species. You're brave, curious, and always ready for the next adventure.",
    "Fantasy Princess": "You are Princess Lyra, kind-hearted royalty who cares deeply about her people. You're intelligent, diplomatic, and skilled in both politics and magic. You often sneak out of the castle to help citizens in need."
}

def build_chatml_conversation(message, history, character_description):
    """Build a conversation string in ChatML format."""
    conversation = ""

    # Add system message if a character is defined
    if character_description.strip():
        conversation += f"<|im_start|>system\n{character_description.strip()}<|im_end|>\n"

    # Add conversation history
    for user_msg, assistant_msg in history:
        if user_msg:
            conversation += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
        if assistant_msg:
            conversation += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"

    # Add current user message
    conversation += f"<|im_start|>user\n{message}<|im_end|>\n"

    # Start assistant response
    conversation += "<|im_start|>assistant\n"

    return conversation
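
# For reference (derived from the function above), a call like
#   build_chatml_conversation("Hello!", [], "You are Sir Gareth, a brave knight.")
# produces this prompt, which the model is expected to continue after the
# final assistant tag:
#
#   <|im_start|>system
#   You are Sir Gareth, a brave knight.<|im_end|>
#   <|im_start|>user
#   Hello!<|im_end|>
#   <|im_start|>assistant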

def generate_cpu_response(message, history, character_description, max_tokens, temperature, top_p, repetition_penalty):
    """Generate a response using local CPU inference with the ChatML format."""
    if model is None or tokenizer is None:
        return "❌ Error: Model not loaded properly. Please check the model path."

    if not message.strip():
        return "Please enter a message."

    try:
        # Build ChatML conversation
        conversation = build_chatml_conversation(message, history, character_description)

        # Tokenize the conversation
        inputs = tokenizer.encode(
            conversation,
            return_tensors="pt",
            truncation=True,
            max_length=1024 - int(max_tokens)  # Leave room for the response
        )

        print(f"🔄 Generating response... (Input length: {inputs.shape[1]} tokens)")

        # Generate response on CPU
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=int(max_tokens),
                temperature=float(temperature),
                top_p=float(top_p),
                repetition_penalty=float(repetition_penalty),
                do_sample=True,
                # pad_token_id can legitimately be 0, so compare against None
                pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
                num_return_sequences=1
            )

        # Decode the full response
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)

        # Extract just the assistant's response from the ChatML output
        if "<|im_start|>assistant\n" in full_response:
            # Split on the last assistant tag to get only the new response
            assistant_parts = full_response.split("<|im_start|>assistant\n")
            if len(assistant_parts) > 1:
                response = assistant_parts[-1]

                # Remove any trailing <|im_end|> or other tokens
                response = response.replace("<|im_end|>", "").strip()

                # Clean up any remaining special tokens
                response = response.replace("<|im_start|>", "").replace("<|im_end|>", "")
                response = response.replace("<s>", "").replace("</s>", "")
                response = response.strip()

                if response:
                    print(f"✅ Generated {len(response)} characters")
                    return response

        # Fallback: try to extract the text that follows the input prompt
        input_text = tokenizer.decode(inputs[0], skip_special_tokens=False)
        if len(full_response) > len(input_text):
            response = full_response[len(input_text):].strip()

            # Clean special tokens
            response = response.replace("<|im_start|>", "").replace("<|im_end|>", "")
            response = response.replace("<s>", "").replace("</s>", "")
            response = response.strip()

            if response:
                return response

        return "Sorry, I couldn't generate a proper response. Please try again."

    except Exception as e:
        print(f"❌ Generation error: {e}")
        return f"Error generating response: {str(e)}"

def load_character_preset(character_name):
    """Load a character preset description."""
    return SAMPLE_CHARACTERS.get(character_name, "")

def chat_function(message, history, character_description, max_tokens, temperature, top_p, repetition_penalty):
    """Main chat function that handles the conversation flow."""
    if not message.strip():
        return history, ""

    # Generate response using CPU inference
    response = generate_cpu_response(
        message,
        history,
        character_description,
        max_tokens,
        temperature,
        top_p,
        repetition_penalty
    )

    # Add the exchange to history and clear the input box
    history.append([message, response])
    return history, ""

# Custom CSS for better styling
css = """
.character-card {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    border-radius: 15px;
    padding: 20px;
    margin: 10px 0;
    color: white;
}
.title-text {
    text-align: center;
    font-size: 2.5em;
    font-weight: bold;
    background: linear-gradient(45deg, #667eea, #764ba2);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    margin-bottom: 20px;
}
.parameter-box {
    background: #f8f9fa;
    border-radius: 10px;
    padding: 15px;
    margin: 10px 0;
}
.cpu-badge {
    background: #28a745;
    color: white;
    padding: 5px 10px;
    border-radius: 15px;
    font-size: 0.8em;
    margin-left: 10px;
}
"""

# Create the Gradio interface
with gr.Blocks(css=css, title="TinyRP Chat Demo") as demo:
    gr.HTML('<div class="title-text">🎭 TinyRP Character Chat <span class="cpu-badge">CPU Inference</span></div>')

    gr.Markdown("""
    ### Welcome to TinyRP!
    This is a demo of a small but capable roleplay model running on CPU. Choose a character preset or create your own!

    **Tips for better roleplay:**
    - Be descriptive in your messages
    - Stay in character
    - The app applies the ChatML format automatically for best results
    - Adjust temperature for creativity vs. consistency

    ⚡ **Running on CPU** - Responses may take 10-30 seconds depending on your hardware.
    """)

    with gr.Row():
        with gr.Column(scale=2):
            # Chat interface
            chatbot = gr.Chatbot(
                label="Chat",
                height=500,
                show_label=False,
                avatar_images=("🧑", "🤖")
            )

            with gr.Row():
                msg = gr.Textbox(
                    label="Your message",
                    placeholder="Type your message here...",
                    lines=2,
                    scale=4
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)

        with gr.Column(scale=1):
            # Character selection
            with gr.Group():
                gr.Markdown("### 🎭 Character Setup")

                character_preset = gr.Dropdown(
                    choices=list(SAMPLE_CHARACTERS.keys()),
                    value="Custom Character",
                    label="Character Presets",
                    interactive=True
                )

                character_description = gr.Textbox(
                    label="Character Description",
                    placeholder="Describe your character's personality, background, and speaking style...",
                    lines=6,
                    value=""
                )

                load_preset_btn = gr.Button("Load Preset", variant="secondary")

            # Generation parameters
            with gr.Group():
                gr.Markdown("### ⚙️ Generation Settings")
                gr.Markdown("*Using ChatML format automatically*")

                max_tokens = gr.Slider(
                    minimum=16,
                    maximum=256,
                    value=100,
                    step=16,
                    label="Max Response Length",
                    info="Longer = more detailed responses (slower on CPU)"
                )

                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.9,
                    step=0.1,
                    label="Temperature",
                    info="Higher = more creative/random"
                )

                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.85,
                    step=0.05,
                    label="Top-p",
                    info="Focus on top % of likely words"
                )

                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=1.5,
                    value=1.1,
                    step=0.05,
                    label="Repetition Penalty",
                    info="Reduce repetitive text"
                )

            # Control buttons
            with gr.Group():
                clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")

    # Sample character cards
    with gr.Row():
        gr.Markdown("### 🎭 Featured Characters")

    with gr.Row():
        for char_name, char_desc in list(SAMPLE_CHARACTERS.items())[1:4]:  # Show the first three non-custom presets
            with gr.Column(scale=1):
                gr.Markdown(f"""
                <div class="character-card">
                <h4>{char_name}</h4>
                <p>{char_desc[:100]}...</p>
                </div>
                """)

    # Event handlers
    send_btn.click(
        chat_function,
        inputs=[msg, chatbot, character_description, max_tokens, temperature, top_p, repetition_penalty],
        outputs=[chatbot, msg]
    )

    msg.submit(
        chat_function,
        inputs=[msg, chatbot, character_description, max_tokens, temperature, top_p, repetition_penalty],
        outputs=[chatbot, msg]
    )

    load_preset_btn.click(
        load_character_preset,
        inputs=[character_preset],
        outputs=[character_description]
    )

    character_preset.change(
        load_character_preset,
        inputs=[character_preset],
        outputs=[character_description]
    )

    clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg])

if __name__ == "__main__":
    demo.launch()
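
# To run this demo locally (assuming this file is saved as app.py, the usual
# Hugging Face Spaces layout):
#   pip install gradio transformers torch
#   python app.py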