from transformers import pipeline, TextIteratorStreamer
import torch
from threading import Thread
import gradio as gr
import spaces
import re
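# openai_harmony implements the Harmony chat format (roles, channels, special
# tokens) that gpt-oss models expect.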
from openai_harmony import (
    load_harmony_encoding,
    HarmonyEncodingName,
    Role,
    Message,
    Conversation,
)
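
# Load gpt-oss-20b with the transformers text-generation pipeline;
# torch_dtype="auto" and device_map="auto" let accelerate pick the precision
# and device placement for the Space's hardware.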
model_id = "openai/gpt-oss-20b"
pipe = pipeline(
"text-generation",
model=model_id,
torch_dtype="auto",
device_map="auto",
trust_remote_code=True,
)
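
# The Harmony encoding renders conversations with the role/channel special
# tokens gpt-oss was trained on.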
enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
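
# Normalize Gradio "messages"-style history into plain role/content dicts.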
def format_conversation_history(chat_history):
    messages = []
    for item in chat_history:
        role = item["role"]
        content = item["content"]
        if isinstance(content, list):
            content = content[0]["text"] if content and "text" in content[0] else str(content)
        messages.append({"role": role, "content": content})
    return messages
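
# For example, a hypothetical history entry
#   {"role": "user", "content": [{"type": "text", "text": "Hi"}]}
# flattens to {"role": "user", "content": "Hi"}.
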
# Build an OpenAI Harmony Conversation from plain role/content dicts.
def build_harmony_conversation_from_messages(messages):
    harmony_messages = []
    role_map = {"system": Role.SYSTEM, "user": Role.USER, "assistant": Role.ASSISTANT}
    for m in messages:
        role = m["role"].lower()
        if role in role_map:
            harmony_messages.append(Message.from_role_and_content(role_map[role], m["content"]))
    return Conversation.from_messages(harmony_messages)
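
# On ZeroGPU Spaces, spaces.GPU() allocates a GPU only for the duration of
# each decorated call.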
@spaces.GPU()
def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
    new_message = {"role": "user", "content": input_data}
    system_message = [{"role": "system", "content": system_prompt}] if system_prompt else []
    processed_history = format_conversation_history(chat_history)
    messages = system_message + processed_history + [new_message]
    conversation = build_harmony_conversation_from_messages(messages)
    # Render the conversation up to the assistant turn, then decode the Harmony
    # special tokens back to text for the text-generation pipeline.
    prompt_tokens = enc.render_conversation_for_completion(conversation, Role.ASSISTANT)
    prompt_text = pipe.tokenizer.decode(prompt_tokens, skip_special_tokens=False)
    streamer = TextIteratorStreamer(pipe.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
        "return_full_text": False,
    }
    # Run generation on a background thread so tokens can be streamed as they arrive.
    thread = Thread(target=pipe, args=(prompt_text,), kwargs=generation_kwargs)
    thread.start()
    # gpt-oss emits its chain of thought on the "analysis" channel before
    # switching to the "final" channel; with special tokens stripped, the
    # boundary surfaces as the literal string "assistantfinal" in the stream.
    thinking = ""
    final = ""
    started_final = False
    for chunk in streamer:
        if not started_final:
            if "assistantfinal" in chunk.lower():
                split_parts = re.split(r"(?i)assistantfinal", chunk, maxsplit=1)
                thinking += split_parts[0]
                final += split_parts[1]
                started_final = True
            else:
                thinking += chunk
        else:
            final += chunk
        # Strip the leading "analysis" channel marker and re-render on every
        # chunk so the UI streams both the thinking process and the answer.
        clean_thinking = re.sub(r"^analysis\s*", "", thinking, flags=re.I).strip()
        clean_final = final.strip()
        formatted = f"<details open><summary>Click to view Thinking Process</summary>\n\n{clean_thinking}\n\n</details>\n\n{clean_final}"
        yield formatted
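
# Streaming chat UI; type="messages" makes Gradio pass history as a list of
# role/content dicts, the format format_conversation_history expects.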
demo = gr.ChatInterface(
    fn=generate_response,
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=64, maximum=4096, step=1, value=2048),
        gr.Textbox(
            label="System Prompt",
            value="You are a helpful assistant. Reasoning: medium",
            lines=4,
            placeholder="Change system prompt",
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0),
    ],
    examples=[
        [{"text": "Explain Newton's laws clearly and concisely"}],
        [{"text": "What are the benefits of open-weight AI models?"}],
        [{"text": "Write a Python function to calculate the Fibonacci sequence"}],
    ],
    cache_examples=False,
    type="messages",
    description="""# gpt-oss-20b Demo
Give it a couple of seconds to start. You can adjust the reasoning level in the system prompt, e.g. "Reasoning: high". The thinking process panel is open by default; click its summary to collapse it.""",
    fill_height=True,
    textbox=gr.Textbox(
        label="Query Input",
        placeholder="Type your prompt",
    ),
    stop_btn="Stop Generation",
    multimodal=False,
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    demo.launch()