import gradio as gr
from huggingface_hub import InferenceClient as HubInferenceClient  # Renamed to avoid conflict
import os
import json
import base64
from PIL import Image
import io

# Smolagents imports
from smolagents import CodeAgent, Tool, LiteLLMModel, OpenAIServerModel, TransformersModel, InferenceClientModel as SmolInferenceClientModel
from smolagents.gradio_ui import stream_to_gradio

ACCESS_TOKEN = os.getenv("HF_TOKEN")
print("Access token loaded.")

# Function to encode an image to a base64 string
def encode_image(image_path):
    if not image_path:
        print("No image path provided")
        return None
    try:
        print(f"Encoding image from path: {image_path}")

        # If it's already a PIL Image, use it directly
        if isinstance(image_path, Image.Image):
            image = image_path
        else:
            # Otherwise try to open the image file
            image = Image.open(image_path)

        # Convert to RGB if the image has an alpha channel (RGBA)
        if image.mode == 'RGBA':
            image = image.convert('RGB')

        # Encode to base64
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        print("Image encoded successfully")
        return img_str
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None
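
# Note: encode_image returns a base64-encoded JPEG string; the agent pipeline below passes
# PIL Image objects to smolagents directly, so this helper is currently unused.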

# --- Smolagents Tool Definition ---
try:
    image_generation_tool = Tool.from_space(
        "black-forest-labs/FLUX.1-schnell",
        name="image_generator",
        description="Generates an image from a textual prompt. Use this tool if the user asks to generate, create, or draw an image.",
        token=ACCESS_TOKEN  # Pass token if the space might be private or has rate limits
    )
    print("Image generation tool loaded successfully.")
    SMOLAGENTS_TOOLS = [image_generation_tool]
except Exception as e:
    print(f"Error loading image generation tool: {e}. Proceeding without it.")
    SMOLAGENTS_TOOLS = []
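
# Tool.from_space wraps the FLUX.1-schnell Space as a callable smolagents tool named
# "image_generator". As a rough sketch (assuming the Space accepts a plain text prompt),
# it could be invoked directly for debugging:
#   image_generation_tool("an astronaut riding a horse")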

def respond(
    message,
    image_files,  # Changed parameter name and structure
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    provider,
    custom_api_key,
    custom_model,
    model_search_term,
    selected_model
):
    print(f"Received message: {message}")
    print(f"Received {len(image_files) if image_files else 0} images")
    # print(f"History: {history}")  # Can be very verbose
    print(f"System message: {system_message}")
    print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}")
    print(f"Frequency Penalty: {frequency_penalty}, Seed: {seed}")
    print(f"Selected provider: {provider}")
    print(f"Custom API Key provided: {bool(custom_api_key.strip())}")
    print(f"Selected model (custom_model): {custom_model}")
    print(f"Model search term: {model_search_term}")
    print(f"Selected model from radio: {selected_model}")

    # Determine which token to use
    token_to_use = custom_api_key if custom_api_key.strip() != "" else ACCESS_TOKEN
    if custom_api_key.strip() != "":
        print("USING CUSTOM API KEY: BYOK token provided by user is being used for authentication")
    else:
        print("USING DEFAULT API KEY: Environment variable HF_TOKEN is being used for authentication")

    # Determine which model to use, prioritizing custom_model if provided
    model_to_use = custom_model.strip() if custom_model.strip() != "" else selected_model
    print(f"Model selected for LLM: {model_to_use}")

    # Prepare parameters for the LLM
    llm_parameters = {
        "max_tokens": max_tokens,      # For LiteLLMModel, OpenAIServerModel
        "max_new_tokens": max_tokens,  # For TransformersModel, InferenceClientModel
        "temperature": temperature,
        "top_p": top_p,
        "frequency_penalty": frequency_penalty,
    }
    if seed != -1:
        llm_parameters["seed"] = seed

    # Initialize the smolagents model.
    # For simplicity, use InferenceClientModel when the provider is hf-inference (or unset),
    # and otherwise LiteLLMModel, which supports many providers. More sophisticated logic
    # could be added to select the right smolagents model class.
    if provider == "hf-inference" or provider is None or provider == "":  # provider can be None if custom_model is a URL
        smol_model = SmolInferenceClientModel(
            model_id=model_to_use,
            token=token_to_use,
            provider=provider if provider else None,  # Pass provider only if it's explicitly set
            **llm_parameters
        )
        print(f"Using SmolInferenceClientModel for LLM with provider: {provider or 'default'}")
    else:
        # Other providers are assumed to be LiteLLM compatible.
        # LiteLLM generally expects "provider/model_name" identifiers and `api_key` for the token.
        smol_model = LiteLLMModel(
            model_id=f"{provider}/{model_to_use}" if provider else model_to_use,
            api_key=token_to_use,
            **llm_parameters
        )
        print(f"Using LiteLLMModel for LLM with provider: {provider}")

    # Initialize the smolagent.
    # CodeAgent is used as it's generally more powerful; the system message from the UI
    # is folded into the task given to the agent.
    agent_task = message
    if system_message and system_message.strip():
        agent_task = f"System Instructions: {system_message}\n\nUser Task: {message}"

    print(f"Initializing CodeAgent with model: {model_to_use}")
    agent = CodeAgent(
        tools=SMOLAGENTS_TOOLS,  # Use the globally defined tools
        model=smol_model,
        stream_outputs=True  # Important for streaming
    )
    print("CodeAgent initialized.")

    # Prepare multimodal inputs for the agent if images are present
    agent_images = []
    if image_files and len(image_files) > 0:
        for img_path in image_files:
            if img_path:
                try:
                    # Smolagents expects PIL Image objects for images
                    pil_image = Image.open(img_path)
                    agent_images.append(pil_image)
                except Exception as e:
                    print(f"Error opening image for agent: {e}")

    print(f"Prepared {len(agent_images)} images for the agent.")

    # Build the response incrementally as the agent streams output
    response_text = ""
    print(f"Running agent with task: {agent_task}")
    try:
        # stream_to_gradio runs the agent and yields Gradio-compatible chat content.
        # Gradio's history (list of tuples) is not passed to agent.run(); conversational
        # context would live in the agent's own memory if reset_agent_memory were False.
        # For simplicity, each call here is treated as a new conversation for the agent.
        print("Streaming response from agent...")
        for content_chunk in stream_to_gradio(
            agent,
            task=agent_task,
            task_images=agent_images if agent_images else None,
            reset_agent_memory=True  # Treat each interaction as new for the agent
        ):
            # stream_to_gradio yields either a string (a text delta) or a ChatMessage object
            if isinstance(content_chunk, str):  # Text delta
                response_text += content_chunk
                yield response_text
            elif hasattr(content_chunk, 'content'):  # ChatMessage object
                if isinstance(content_chunk.content, dict) and 'path' in content_chunk.content:  # Image/Audio
                    # For streaming, yield a markdown representation of the generated file's path
                    yield f"![generated file]({content_chunk.content['path']})"
                elif isinstance(content_chunk.content, str):
                    response_text = content_chunk.content  # Replace if it's a full message
                    yield response_text
            else:  # Should not happen with stream_to_gradio's typical output
                print(f"Unexpected chunk type from stream_to_gradio: {type(content_chunk)}")
                yield str(content_chunk)

        print("\nCompleted response generation from agent.")
    except Exception as e:
        print(f"Error during agent execution: {e}")
        response_text += f"\nError: {str(e)}"
        yield response_text
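
# Note: respond() yields cumulative text (not token deltas), so the handlers below simply
# assign each yielded value to the last bot message to render a growing streamed reply.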

# Function to validate provider selection based on BYOK
def validate_provider(api_key, provider):
    if not api_key.strip() and provider != "hf-inference":
        return gr.update(value="hf-inference")
    return gr.update(value=provider)
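
# Example: validate_provider("", "cohere") forces the radio back to "hf-inference", while
# validate_provider("hf_xxx...", "cohere") leaves the selection unchanged.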

# GRADIO UI
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    # Create the chatbot component
    chatbot = gr.Chatbot(
        height=600,
        show_copy_button=True,
        placeholder="Select a model and begin chatting. Now supports multiple inference providers, multimodal inputs, and an image generation tool.",
        layout="panel",
        show_share_button=True  # Added for easy sharing
    )
    print("Chatbot interface created.")

    # Multimodal textbox for messages (combines text and file uploads)
    msg = gr.MultimodalTextbox(
        placeholder="Type a message or upload images... (e.g., 'generate an image of a cat playing chess')",
        show_label=False,
        container=False,
        scale=12,
        file_types=["image"],
        file_count="multiple",
        sources=["upload"]
    )

    # Create accordion for settings
    with gr.Accordion("Settings", open=False):
        # System message
        system_message_box = gr.Textbox(
            value="You are a helpful AI assistant that can understand images and text. If asked to generate an image, use the available image_generator tool.",
            placeholder="You are a helpful assistant.",
            label="System Prompt"
        )

        # Generation parameters
        with gr.Row():
            with gr.Column():
                max_tokens_slider = gr.Slider(
                    minimum=1,
                    maximum=4096,
                    value=1024,  # Increased default for potentially longer agent outputs
                    step=1,
                    label="Max tokens"
                )
                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=4.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
                top_p_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-P"
                )
            with gr.Column():
                frequency_penalty_slider = gr.Slider(
                    minimum=-2.0,
                    maximum=2.0,
                    value=0.0,
                    step=0.1,
                    label="Frequency Penalty"
                )
                seed_slider = gr.Slider(
                    minimum=-1,
                    maximum=65535,
                    value=-1,
                    step=1,
                    label="Seed (-1 for random)"
                )

        # Provider selection
        providers_list = [
            "hf-inference",  # Default Hugging Face Inference
            "cerebras",      # Cerebras provider
            "together",      # Together AI
            "sambanova",     # SambaNova
            "novita",        # Novita AI
            "cohere",        # Cohere
            "fireworks-ai",  # Fireworks AI
            "hyperbolic",    # Hyperbolic
            "nebius",        # Nebius
            # Add other providers supported by LiteLLM if desired
        ]
        provider_radio = gr.Radio(
            choices=providers_list,
            value="hf-inference",
            label="Inference Provider",
        )

        # BYOK textbox
        byok_textbox = gr.Textbox(
            value="",
            label="BYOK (Bring Your Own Key)",
            info="Enter a custom Hugging Face API key here. When empty, only the 'hf-inference' provider can be used. For other providers, this key will be used as their respective API key.",
            placeholder="Enter your API token",
            type="password"  # Hide the API key for security
        )

        # Custom model box
        custom_model_box = gr.Textbox(
            value="",
            label="Custom Model",
            info="(Optional) Provide a custom Hugging Face model path (e.g., 'meta-llama/Llama-3.3-70B-Instruct') or a model name compatible with the selected provider. Overrides any selected featured model.",
            placeholder="meta-llama/Llama-3.3-70B-Instruct"
        )

        # Model search
        model_search_box = gr.Textbox(
            label="Filter Models",
            placeholder="Search for a featured model...",
            lines=1
        )

        # Featured models list
        models_list = [
            "meta-llama/Llama-3.2-11B-Vision-Instruct",
            "meta-llama/Llama-3.3-70B-Instruct",
            "meta-llama/Llama-3.1-70B-Instruct",
            "meta-llama/Llama-3.0-70B-Instruct",
            "meta-llama/Llama-3.2-3B-Instruct",
            "meta-llama/Llama-3.2-1B-Instruct",
            "meta-llama/Llama-3.1-8B-Instruct",
            "NousResearch/Hermes-3-Llama-3.1-8B",
            "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
            "mistralai/Mistral-Nemo-Instruct-2407",
            "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "mistralai/Mistral-7B-Instruct-v0.3",
            "mistralai/Mistral-7B-Instruct-v0.2",
            "Qwen/Qwen3-235B-A22B",
            "Qwen/Qwen3-32B",
            "Qwen/Qwen2.5-72B-Instruct",
            "Qwen/Qwen2.5-3B-Instruct",
            "Qwen/Qwen2.5-0.5B-Instruct",
            "Qwen/QwQ-32B",
            "Qwen/Qwen2.5-Coder-32B-Instruct",
            "microsoft/Phi-3.5-mini-instruct",
            "microsoft/Phi-3-mini-128k-instruct",
            "microsoft/Phi-3-mini-4k-instruct",
        ]
        featured_model_radio = gr.Radio(
            label="Select a model below (or specify a custom one above)",
            choices=models_list,
            value="meta-llama/Llama-3.2-11B-Vision-Instruct",  # Default to a multimodal model
            interactive=True
        )

        gr.Markdown("[View all Text-to-Text models](https://huggingface.co/models?inference_provider=all&pipeline_tag=text-generation&sort=trending) | [View all multimodal models](https://huggingface.co/models?inference_provider=all&pipeline_tag=image-text-to-text&sort=trending)")

    # Chat history state
    chat_history = gr.State([])

    # Function to filter models
    def filter_models(search_term):
        print(f"Filtering models with search term: {search_term}")
        filtered = [m for m in models_list if search_term.lower() in m.lower()]
        print(f"Filtered models: {filtered}")
        return gr.update(choices=filtered)
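
    # The filter is a case-insensitive substring match, e.g. searching "qwen" keeps only the
    # Qwen/* entries in the featured model radio.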

    # Function to set the selected model from the radio
    # (its value feeds `respond` as selected_model and is overridden when the Custom Model box is filled)
    def set_selected_model_from_radio(selected):
        print(f"Featured model selected: {selected}")
        # This function's output is one of the inputs to `respond`
        return selected

    # Function for the chat interface
    def user(user_message_input, history):
        # user_message_input is a dict from MultimodalTextbox: {"text": str, "files": list[str]}
        print(f"User input received: {user_message_input}")

        text_content = user_message_input.get("text", "").strip()
        files = user_message_input.get("files", [])

        if not text_content and not files:
            print("Empty message, skipping history update.")
            return history  # Or gr.skip() if the Gradio version supports it well

        # Gradio's Chatbot could display the uploaded images themselves, either as (filepath,)
        # tuples or via markdown like ![image](path). For simplicity the history stores only
        # the text plus a note about how many images were uploaded; the actual image data is
        # passed separately to `respond`.
        display_message = text_content
        if files:
            display_message += f" ({len(files)} image(s) uploaded)"
        history.append([display_message, None])
        return history
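
    # The display string built above (e.g. "describe this (2 image(s) uploaded)") is only for
    # the chat window; the raw MultimodalTextbox value is what actually reaches the agent.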

    # Bot response handling.
    # Gradio's chatbot history only stores display strings, so the raw MultimodalTextbox value
    # (text plus uploaded file paths) cannot be reliably recovered from history[-1][0]. Rather
    # than a separate user/bot handler pair that would have to re-parse that display string,
    # the `msg` component's value is passed straight into a single combined handler below, which
    # appends the user turn, calls `respond` with the raw text and files plus the history of all
    # previous turns, and streams the agent's output into the last bot message.

    # Event handlers: one combined handler receives the raw `msg` value (text + files),
    # the current chatbot history, and all generation settings.
    def combined_user_and_bot(msg_val, chatbot_history, system_msg, max_tokens_val, temperature_val, top_p_val, freq_penalty_val, seed_val, provider_val, api_key_val, custom_model_val, search_term_val, selected_model_val):
        # Ignore empty submissions so the previous bot reply is not overwritten
        if not (msg_val.get("text", "").strip() or msg_val.get("files")):
            yield chatbot_history
            return

        # 1. Call `user` to append the user's message to the chatbot display
        updated_chatbot_history = user(msg_val, chatbot_history)
        yield updated_chatbot_history  # Show the user message immediately

        # 2. Call `respond` (the core generation logic) with the history of all turns
        #    *before* the current user input.
        if updated_chatbot_history and updated_chatbot_history[-1] is not None:
            updated_chatbot_history[-1][1] = ""  # Placeholder for the bot's streamed response

        history_for_api = updated_chatbot_history[:-1] if updated_chatbot_history else []

        for chunk in respond(
            message=msg_val.get("text", ""),
            image_files=msg_val.get("files", []),
            history=history_for_api,
            system_message=system_msg,
            max_tokens=max_tokens_val,
            temperature=temperature_val,
            top_p=top_p_val,
            frequency_penalty=freq_penalty_val,
            seed=seed_val,
            provider=provider_val,
            custom_api_key=api_key_val,
            custom_model=custom_model_val,
            selected_model=selected_model_val,
            model_search_term=search_term_val  # Not used by respond directly, but part of its signature
        ):
            if updated_chatbot_history and updated_chatbot_history[-1] is not None:
                updated_chatbot_history[-1][1] = chunk
            yield updated_chatbot_history

    msg.submit(
        combined_user_and_bot,
        [msg, chatbot, system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
         frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box,
         model_search_box, featured_model_radio],  # Pass `msg` (the value of the MultimodalTextbox)
        [chatbot]
    ).then(
        lambda: {"text": "", "files": []},  # Clear inputs after submission
        None,
        [msg]
    )
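
    # In this chain, combined_user_and_bot streams updates into `chatbot`; the follow-up lambda
    # then resets the MultimodalTextbox so the next message starts from an empty input.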

    # Connect the model filter to update the radio choices
    model_search_box.change(
        fn=filter_models,
        inputs=model_search_box,
        outputs=featured_model_radio
    )
    print("Model search box change event linked.")

    # Connect the featured model radio to the custom model box: selecting a featured model
    # populates custom_model_box, and the user can then override it by typing.
    featured_model_radio.change(
        fn=lambda selected_model_from_radio: selected_model_from_radio,  # Pass the value through directly
        inputs=featured_model_radio,
        outputs=custom_model_box
    )
    print("Featured model radio button change event linked.")

    # Connect the BYOK textbox to validate provider selection
    byok_textbox.change(
        fn=validate_provider,
        inputs=[byok_textbox, provider_radio],
        outputs=provider_radio
    )
    print("BYOK textbox change event linked.")

    # Also validate the provider when the radio changes, to keep the two in sync
    provider_radio.change(
        fn=validate_provider,
        inputs=[byok_textbox, provider_radio],
        outputs=provider_radio
    )
    print("Provider radio button change event linked.")

print("Gradio interface initialized.")

if __name__ == "__main__":
    print("Launching the demo application.")
    demo.launch(show_api=True, share=True)  # Added share=True for easier testing