import gradio as gr
import onnxruntime_genai as og
import time
import os
from huggingface_hub import snapshot_download
import argparse
import logging
import numpy as np  # Import numpy

# --- Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Configuration ---
MODEL_REPO = "microsoft/Phi-4-mini-instruct-onnx"

# --- Defaulting to CPU INT4 for Hugging Face Spaces ---
EXECUTION_PROVIDER = "cpu"  # Corresponds to installing 'onnxruntime-genai'
MODEL_VARIANT_GLOB = "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/*"
# --- --- --- --- --- --- --- --- --- --- --- --- --- --- ---

# --- (Optional) Alternative GPU Configuration ---
# EXECUTION_PROVIDER = "cuda"  # Corresponds to installing 'onnxruntime-genai-cuda'
# MODEL_VARIANT_GLOB = "gpu/gpu-int4-rtn-block-32/*"
# --- --- --- --- --- --- --- --- --- --- --- --- --- --- ---

LOCAL_MODEL_DIR = "./phi4-mini-onnx-model"  # Directory within the Space
HF_LOGO_URL = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"
HF_MODEL_URL = f"https://huggingface.co/{MODEL_REPO}"
ORT_GENAI_URL = "https://github.com/microsoft/onnxruntime-genai"
PHI_LOGO_URL = "https://microsoft.github.io/phi/assets/img/logo-final.png"  # Phi logo for bot avatar

# Global variables for model and tokenizer
model = None
tokenizer = None
model_variant_name = os.path.basename(os.path.dirname(MODEL_VARIANT_GLOB))  # For display
model_status = "Initializing..."


# --- Model Download and Load ---
def initialize_model():
    """Downloads and loads the ONNX model and tokenizer."""
    global model, tokenizer, model_status
    logging.info("--- Initializing ONNX Runtime GenAI ---")
    model_status = "Downloading model..."
    logging.info(model_status)

    # --- Download ---
    model_variant_dir = os.path.join(LOCAL_MODEL_DIR, os.path.dirname(MODEL_VARIANT_GLOB))
    if os.path.exists(model_variant_dir) and os.listdir(model_variant_dir):
        logging.info(f"Model variant found in {model_variant_dir}. Skipping download.")
        model_path = model_variant_dir
    else:
        logging.info(f"Downloading model variant '{MODEL_VARIANT_GLOB}' from {MODEL_REPO}...")
        try:
            snapshot_download(
                MODEL_REPO,
                allow_patterns=[MODEL_VARIANT_GLOB],
                local_dir=LOCAL_MODEL_DIR,
                local_dir_use_symlinks=False
            )
            model_path = model_variant_dir
            logging.info(f"Model downloaded to: {model_path}")
        except Exception as e:
            logging.error(f"Error downloading model: {e}", exc_info=True)
            model_status = f"Error downloading model: {e}"
            raise RuntimeError(f"Failed to download model: {e}")

    # --- Load ---
    model_status = f"Loading model ({EXECUTION_PROVIDER.upper()})..."
    logging.info(model_status)
    try:
        # The simple constructor often works by detecting the installed ORT package.
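        # (Provider selection is implicit: installing 'onnxruntime-genai' targets CPU,
        #  while 'onnxruntime-genai-cuda' targets CUDA, matching EXECUTION_PROVIDER above.)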
        logging.info(f"Using provider based on installed package (expecting: {EXECUTION_PROVIDER})")
        model = og.Model(model_path)  # Simplified model loading
        tokenizer = og.Tokenizer(model)
        model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})"
        logging.info("Model and Tokenizer loaded successfully.")
    except AttributeError as ae:
        logging.error(f"AttributeError during model/tokenizer init: {ae}", exc_info=True)
        logging.error("This might indicate an installation issue or version incompatibility with onnxruntime_genai.")
        model_status = f"Init Error: {ae}"
        raise RuntimeError(f"Failed to initialize model/tokenizer: {ae}")
    except Exception as e:
        logging.error(f"Error loading model or tokenizer: {e}", exc_info=True)
        model_status = f"Error loading model: {e}"
        raise RuntimeError(f"Failed to load model: {e}")


# --- Generation Function (Core Logic) ---
def generate_response_stream(prompt, history, max_length, temperature, top_p, top_k):
    """Generates a response using the Phi-4 ONNX model, yielding text chunks."""
    global model_status
    if not model or not tokenizer:
        model_status = "Error: Model not initialized!"
        yield "Error: Model not initialized. Please check logs."
        return

    # --- Prepare the prompt using the Phi-4 instruct format ---
    full_prompt = ""
    # History format is [[user1, bot1], [user2, bot2], ...]
    for user_msg, assistant_msg in history:  # history here is *before* the current prompt
        full_prompt += f"<|user|>\n{user_msg}<|end|>\n"
        if assistant_msg:  # Append assistant message only if it exists
            full_prompt += f"<|assistant|>\n{assistant_msg}<|end|>\n"

    # Add the current user prompt and the trigger for the assistant's response
    full_prompt += f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n"

    logging.info(f"Generating response (MaxL: {max_length}, Temp: {temperature}, TopP: {top_p}, TopK: {top_k})")

    try:
        input_tokens_list = tokenizer.encode(full_prompt)  # Encode returns a list/array
        # Ensure input_tokens is a numpy array of the correct type (int32 is common)
        input_tokens = np.array(input_tokens_list, dtype=np.int32)
        # Reshape to (batch_size, sequence_length), which is (1, N) for single prompt
        input_tokens = input_tokens.reshape((1, -1))

        search_options = {
            "max_length": max_length,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "do_sample": True,
        }

        params = og.GeneratorParams(model)
        params.set_search_options(**search_options)

        # FIX: Create a dictionary mapping input names to tensors (numpy arrays)
        # and pass this dictionary to set_inputs.
        # Assuming the standard input name "input_ids".
        inputs = {"input_ids": input_tokens}
        logging.info(f"Setting inputs with keys: {inputs.keys()} and shape for 'input_ids': {inputs['input_ids'].shape}")
        params.set_inputs(inputs)

        start_time = time.time()
        # Create generator AFTER setting parameters including inputs
        generator = og.Generator(model, params)
        model_status = "Generating..."  # Update status indicator
        logging.info("Streaming response...")

        first_token_time = None
        token_count = 0
        # Rely primarily on generator.is_done()
        while not generator.is_done():
            try:
                generator.compute_logits()
                generator.generate_next_token()
                if first_token_time is None:
                    first_token_time = time.time()  # Record time to first token

                next_token = generator.get_next_tokens()[0]
                decoded_chunk = tokenizer.decode([next_token])
                token_count += 1

                # Secondary check: Stop if the model explicitly generates the <|end|> string literal.
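                # (generator.is_done() already handles normal termination, i.e. the EOS token
                #  or the max_length search option; this string check is only a fallback.)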
                if decoded_chunk == "<|end|>":
                    logging.info("Assistant explicitly generated <|end|> token string.")
                    break

                yield decoded_chunk  # Yield just the text chunk
            except Exception as loop_error:
                logging.error(f"Error inside generation loop: {loop_error}", exc_info=True)
                yield f"\n\nError during token generation: {loop_error}"
                break  # Exit loop on error

        end_time = time.time()
        ttft = (first_token_time - start_time) * 1000 if first_token_time else -1
        total_time = end_time - start_time
        tps = (token_count / total_time) if total_time > 0 else 0
        logging.info(f"Generation complete. Tokens: {token_count}, Total Time: {total_time:.2f}s, TTFT: {ttft:.2f}ms, TPS: {tps:.2f}")
        model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})"  # Reset status

    except TypeError as te:
        # Catch type errors specifically during setup if the input format is still wrong
        logging.error(f"TypeError during generation setup: {te}", exc_info=True)
        logging.error("Check if the input format {'input_ids': token_array} is correct.")
        model_status = f"Generation Setup TypeError: {te}"
        yield f"\n\nSorry, a TypeError occurred setting up generation: {te}"
    except AttributeError as ae:
        # Catch potential future API changes or issues during generation setup
        logging.error(f"AttributeError during generation setup: {ae}", exc_info=True)
        model_status = f"Generation Setup Error: {ae}"
        yield f"\n\nSorry, an error occurred setting up generation: {ae}"
    except Exception as e:
        logging.error(f"Error during generation: {e}", exc_info=True)
        model_status = f"Error during generation: {e}"
        yield f"\n\nSorry, an error occurred during generation: {e}"  # Yield error message


# --- Gradio Interface Functions ---

# 1. Function to add user message to chat history
def add_user_message(user_message, history):
    """Adds the user's message to the chat history for display."""
    if not user_message:
        return "", history  # Clear input, return unchanged history
    history = history + [[user_message, None]]  # Append user message, leave bot response None
    return "", history  # Clear input textbox, return updated history


# 2. Function to handle bot response generation and streaming
def generate_bot_response(history, max_length, temperature, top_p, top_k):
    """Generates the bot's response based on the history and streams it."""
    if not history or history[-1][1] is not None:
        return history

    user_prompt = history[-1][0]  # Get the latest user prompt
    model_history = history[:-1]  # Prepare history for the model

    response_stream = generate_response_stream(
        user_prompt, model_history, max_length, temperature, top_p, top_k
    )

    history[-1][1] = ""  # Initialize the bot response string in the history
    for chunk in response_stream:
        history[-1][1] += chunk  # Append the chunk to the bot's message in history
        yield history  # Yield the *entire updated history* back to Chatbot


# 3. Function to clear chat
def clear_chat():
    """Clears the chat history and input."""
    global model_status
    if model and tokenizer and not model_status.startswith("Error") and not model_status.startswith("FATAL"):
        model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})"
    return None, [], model_status  # Clear Textbox, Chatbot history, and update status display


# --- Initialize Model on App Start ---
try:
    initialize_model()
except Exception as e:
    # Record the failure so the status box (and clear_chat's FATAL check) reflect it.
    model_status = f"FATAL: Model initialization failed: {e}"
    logging.critical(model_status)

# --- Gradio Interface ---
logging.info("Creating Gradio Interface...")

theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="sky",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
)

with gr.Blocks(theme=theme, title="Phi-4 Mini ONNX Chat") as demo:
    # Header Section
    with gr.Row(equal_height=False):
        with gr.Column(scale=3):
            gr.Markdown(f"""
            # Phi-4 Mini Instruct ONNX Chat 🤖
            Interact with the quantized `{model_variant_name}` version of [`{MODEL_REPO}`]({HF_MODEL_URL}) running efficiently via [`onnxruntime-genai`]({ORT_GENAI_URL}) ({EXECUTION_PROVIDER.upper()}).
            """)
        with gr.Column(scale=1, min_width=150):
            gr.Image(HF_LOGO_URL, elem_id="hf-logo", show_label=False, show_download_button=False, container=False, height=50)
            model_status_text = gr.Textbox(value=model_status, label="Model Status", interactive=False, max_lines=2)

    # Main Layout
    with gr.Row():
        # Chat Column
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Conversation",
                height=600,
                layout="bubble",
                bubble_full_width=False,
                avatar_images=(None, PHI_LOGO_URL)
            )
            with gr.Row():
                prompt_input = gr.Textbox(
                    label="Your Message",
                    placeholder="<|user|>\nType your message here...\n<|end|>",
                    lines=4,
                    scale=9
                )
                with gr.Column(scale=1, min_width=120):
                    submit_button = gr.Button("Send", variant="primary", size="lg")
                    clear_button = gr.Button("đŸ—‘ī¸ Clear Chat", variant="secondary")

        # Settings Column
        with gr.Column(scale=1, min_width=250):
            gr.Markdown("### âš™ī¸ Generation Settings")
            with gr.Group():
                max_length = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, label="Max Length", info="Max tokens in response.")
                temperature = gr.Slider(minimum=0.0, maximum=1.5, value=0.7, step=0.05, label="Temperature", info="0.0 = deterministic\n>1.0 = more random")
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-P", info="Nucleus sampling probability.")
                top_k = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K", info="Limit to K most likely tokens (0=disable).")
            gr.Markdown("---")
            gr.Markdown("â„šī¸ **Note:** Uses Phi-4 instruction format: \n`<|user|>\nPROMPT<|end|>\n<|assistant|>`")
            gr.Markdown(f"Running on **{EXECUTION_PROVIDER.upper()}**.")

    # Event Listeners
    bot_response_inputs = [chatbot, max_length, temperature, top_p, top_k]

    submit_event = prompt_input.submit(
        fn=add_user_message,
        inputs=[prompt_input, chatbot],
        outputs=[prompt_input, chatbot],
        queue=False,
    ).then(
        fn=generate_bot_response,
        inputs=bot_response_inputs,
        outputs=[chatbot],
        api_name="chat"
    )

    submit_button.click(
        fn=add_user_message,
        inputs=[prompt_input, chatbot],
        outputs=[prompt_input, chatbot],
        queue=False,
    ).then(
        fn=generate_bot_response,
        inputs=bot_response_inputs,
        outputs=[chatbot],
        api_name=False
    )

    clear_button.click(
        fn=clear_chat,
        inputs=None,
        outputs=[prompt_input, chatbot, model_status_text],
        queue=False
    )

# Launch the Gradio app
logging.info("Launching Gradio App...")
demo.queue(max_size=20)
demo.launch(show_error=True, max_threads=40)
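# Dependency sketch for this Space (assumed; pin versions as needed):
#   gradio
#   huggingface_hub
#   numpy
#   onnxruntime-genai   # or onnxruntime-genai-cuda for the GPU variant configured above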