import gradio as gr
import torch
from transformers import pipeline, set_seed
# Import AutoPipelineForText2Image for compatibility across different model types
from diffusers import AutoPipelineForText2Image
import openai
import os
import time
import traceback # For detailed error logging
# ---- Configuration & API Key ----
# Check for OpenAI API Key in Hugging Face Secrets
api_key = os.environ.get("OPENAI_API_KEY")
openai_client = None
openai_available = False
if api_key:
try:
# Starting with openai v1, client instantiation is preferred
openai_client = openai.OpenAI(api_key=api_key)
# Simple test to check if the key is valid (optional, but good)
# openai_client.models.list() # This call might incur small cost/quota usage
openai_available = True
print("OpenAI API key found and client initialized.")
except Exception as e:
print(f"Error initializing OpenAI client: {e}")
print("Proceeding without OpenAI features.")
else:
print("WARNING: OPENAI_API_KEY secret not found. Prompt enhancement via OpenAI is disabled.")
# Force CPU usage
device = "cpu"
print(f"Using device: {device}")
# ---- Model Loading (CPU Focused) ----
# 1. Speech-to-text model (Whisper) - bonus feature
asr_pipeline = None
try:
print("Loading ASR pipeline (Whisper) on CPU...")
# Force CPU usage with device=-1 or device="cpu"
    # fp16 would be faster but requires a GPU; use float32 on CPU
asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device, torch_dtype=torch.float32)
print("ASR pipeline loaded successfully on CPU.")
except Exception as e:
print(f"Could not load ASR pipeline: {e}. Voice input will be disabled.")
traceback.print_exc() # Print full traceback for debugging
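# Quick standalone sanity check for the ASR pipeline (sketch only; "sample.wav" is a
# hypothetical local file). The transformers ASR pipeline accepts a file path and
# returns a dict containing a "text" key:
# if asr_pipeline:
#     print(asr_pipeline("sample.wav")["text"])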
# 2. Text-to-image model (tiny test checkpoint) - resource-friendly choice
image_generator_pipe = None
image_model_available = False
# Use the extremely lightweight tiny text-to-image test model
model_id = "hf-internal-testing/tiny-text-to-image"

class DummyPipe:
    """Placeholder that raises a descriptive error if the real pipeline failed to load."""
    def __init__(self, error_message):
        self.error_message = error_message

    def __call__(self, *args, **kwargs):
        raise RuntimeError(self.error_message)
try:
print(f"Loading Text-to-Image pipeline ({model_id}) on CPU...")
print("NOTE: Using a very small model for resource efficiency. Image quality will be lower than Stable Diffusion.")
    # AutoPipelineForText2Image auto-detects the correct pipeline class for this checkpoint
image_generator_pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch.float32)
    image_generator_pipe = image_generator_pipe.to(device)
    image_model_available = True
    print(f"Text-to-Image pipeline ({model_id}) loaded successfully on CPU.")
except Exception as e:
print(f"CRITICAL: Could not load Text-to-Image pipeline ({model_id}): {e}. Image generation will fail.")
traceback.print_exc() # Print full traceback for debugging
    # Substitute a placeholder so later calls raise a clear, descriptive error
    # (defining it here, inside the except block, would leave `e` out of scope at call time)
    image_generator_pipe = DummyPipe(f"Text-to-Image model failed to load: {e}")
# ---- Core Function Definitions ----
# Step 1: Prompt-to-Prompt (using OpenAI API)
def enhance_prompt_openai(short_prompt, style_modifier="cinematic", quality_boost="photorealistic, highly detailed"):
"""Uses OpenAI API to enhance the short description."""
    if not short_prompt:
        raise gr.Error("Input description cannot be empty.")
    if not openai_available or not openai_client:
        # Fallback when the OpenAI key is missing or invalid: append the modifiers directly
        print("OpenAI not available. Returning original prompt with modifiers.")
        return f"{short_prompt}, {style_modifier}, {quality_boost}"
# Construct the prompt for the OpenAI model
system_message = (
"You are an expert prompt engineer for AI image generation models. "
"Expand the user's short description into a detailed, vivid, and coherent prompt, suitable for smaller, faster text-to-image models. "
"Focus on clear subjects, objects, and main scene elements. "
"Incorporate the requested style and quality keywords naturally, but keep the overall prompt concise enough for smaller models. Avoid conversational text."
# Adjusting guidance for smaller models
)
user_message = (
f"Enhance this description: \"{short_prompt}\". "
f"Style: '{style_modifier}'. Quality: '{quality_boost}'."
)
print(f"Sending request to OpenAI for prompt enhancement: {short_prompt}")
try:
response = openai_client.chat.completions.create(
model="gpt-3.5-turbo", # Cost-effective choice
messages=[
{"role": "system", "content": system_message},
{"role": "user", "content": user_message},
],
temperature=0.7, # Controls creativity vs predictability
max_tokens=100, # Limit output length - reduced for potentially shorter prompts for smaller models
n=1, # Generate one response
stop=None # Let the model decide when to stop
)
enhanced_prompt = response.choices[0].message.content.strip()
print("OpenAI enhancement successful.")
# Basic cleanup: remove potential quotes around the whole response
if enhanced_prompt.startswith('"') and enhanced_prompt.endswith('"'):
enhanced_prompt = enhanced_prompt[1:-1]
return enhanced_prompt
except openai.AuthenticationError:
print("OpenAI Authentication Error: Invalid API key?")
raise gr.Error("OpenAI Authentication Error: Check your API key.")
except openai.RateLimitError:
print("OpenAI Rate Limit Error: You've exceeded your quota or rate limit.")
raise gr.Error("OpenAI Error: Rate limit exceeded.")
except openai.APIError as e:
print(f"OpenAI API Error: {e}")
raise gr.Error(f"OpenAI API Error: {e}")
except Exception as e:
print(f"An unexpected error occurred during OpenAI call: {e}")
traceback.print_exc()
raise gr.Error(f"Prompt enhancement failed: {e}")
# Step 2: Prompt-to-Image (CPU)
def generate_image_cpu(prompt, negative_prompt, guidance_scale, num_inference_steps):
"""Generates image using the loaded model on CPU."""
    # Fail fast if the pipeline never loaded; DummyPipe carries the original load error.
    # (An isinstance check against AutoPipelineForText2Image would not work here, because
    # from_pretrained returns a concrete pipeline class, not the Auto class itself.)
    if not image_model_available:
        raise gr.Error("Image generation pipeline is not available (failed to load model).")
    if not prompt or "Error:" in prompt:
        # Guard against the prompt being an error message carried over from a previous step
        raise gr.Error("Cannot generate image due to invalid or missing prompt.")
print(f"Generating image on CPU for prompt: {prompt[:100]}...") # Log truncated prompt
# Note: Negative prompt and guidance scale might have less impact or behave differently
# on very small models like tiny-text-to-image.
print(f"Negative prompt: {negative_prompt}") # Will likely be ignored by tiny model
print(f"Guidance scale: {guidance_scale}, Steps: {num_inference_steps}") # Steps might be fixed internally by tiny model
start_time = time.time()
try:
# Use torch.inference_mode() or torch.no_grad() for efficiency
with torch.no_grad():
# Seed for reproducibility (optional, but good practice)
# generator = torch.Generator(device=device).manual_seed(int(time.time())) # Tiny model might not use generator param
# Tiny Text-to-Image pipeline call structure might be simpler
# Check model specific documentation if parameters like guidance_scale, num_inference_steps, negative_prompt
# are actually supported. They might be ignored.
# Using a simple call that is generally compatible
output = image_generator_pipe(prompt=prompt) # Tiny model might only take prompt
        # Pipeline output structures vary: standard diffusers pipelines return an object
        # with an `.images` list, while some minimal test pipelines return a nested tuple
        if hasattr(output, "images") and output.images:
            image = output.images[0]
        else:
            image = output[0][0]
end_time = time.time()
print(f"Image generated successfully on CPU in {end_time - start_time:.2f} seconds (using {model_id}).")
return image
except Exception as e:
print(f"Error during image generation on CPU ({model_id}): {e}")
traceback.print_exc()
# Propagate error to Gradio UI
raise gr.Error(f"Image generation failed on CPU ({model_id}): {e}")
# Bonus: Voice-to-Text (CPU)
def transcribe_audio(audio_file_path):
"""Transcribes audio to text using Whisper on CPU."""
if not asr_pipeline:
# This case should ideally be handled by hiding the control, but double-check
return "[Error: ASR model not loaded]", audio_file_path
if audio_file_path is None:
return "", audio_file_path # No audio input
print(f"Transcribing audio file: {audio_file_path} on CPU...")
start_time = time.time()
try:
# Ensure the pipeline uses the correct device (should be CPU based on loading)
# Ensure input is in expected format for Whisper pipeline (filepath or audio array)
transcription = asr_pipeline(audio_file_path)["text"]
end_time = time.time()
print(f"Transcription successful in {end_time - start_time:.2f} seconds.")
print(f"Transcription result: {transcription}")
return transcription, audio_file_path
except Exception as e:
print(f"Error during audio transcription on CPU: {e}")
traceback.print_exc()
# Return error message in the expected tuple format
return f"[Error: Transcription failed: {e}]", audio_file_path
# ---- Gradio Application Flow ----
def process_input(input_text, audio_file, style_choice, quality_choice, neg_prompt, guidance, steps):
"""Main function triggered by Gradio button."""
final_text_input = ""
enhanced_prompt = ""
generated_image = None
status_message = "" # To gather status/errors for the prompt box
# 1. Determine Input (Text or Audio)
if input_text and input_text.strip():
final_text_input = input_text.strip()
print(f"Using text input: '{final_text_input}'")
elif audio_file is not None:
print("Processing audio input...")
try:
            # With type="filepath", Gradio passes a file path string. A tuple means the
            # component was configured as type="numpy" (samplerate, data), which
            # transcribe_audio cannot consume directly.
            if isinstance(audio_file, tuple):
                status_message = '[Error: Unexpected audio format (tuple). Configure the Audio component with type="filepath".]'
                print(status_message)
                return status_message, None
            audio_filepath_to_transcribe = audio_file
transcribed_text, _ = transcribe_audio(audio_filepath_to_transcribe)
if "[Error:" in transcribed_text:
# Display transcription error clearly
status_message = transcribed_text
print(status_message)
return status_message, None # Return error in prompt field, no image
elif transcribed_text:
final_text_input = transcribed_text
print(f"Using transcribed audio input: '{final_text_input}'")
else:
status_message = "[Error: Audio input received but transcription was empty.]"
print(status_message)
return status_message, None # Return error
except Exception as e:
status_message = f"[Unexpected Audio Transcription Error: {e}]"
print(status_message)
traceback.print_exc()
return status_message, None # Return error
else:
status_message = "[Error: No input provided. Please enter text or record audio.]"
print(status_message)
return status_message, None # Return error
# 2. Enhance Prompt (using OpenAI if available)
if final_text_input:
try:
enhanced_prompt = enhance_prompt_openai(final_text_input, style_choice, quality_choice)
status_message = enhanced_prompt # Display the prompt initially
print(f"Enhanced prompt: {enhanced_prompt}")
except gr.Error as e:
# Catch Gradio-specific errors from enhancement function
status_message = f"[Prompt Enhancement Error: {e}]"
print(status_message)
# Return the error, no image generation attempt
return status_message, None
except Exception as e:
# Catch any other unexpected errors
status_message = f"[Unexpected Prompt Enhancement Error: {e}]"
print(status_message)
traceback.print_exc()
return status_message, None
    # 3. Generate Image (if prompt is valid)
    # Enhancement errors return early above, so a non-empty enhanced prompt is safe to use
    if enhanced_prompt:
try:
# Show "Generating..." message while waiting
gr.Info(f"Starting image generation on CPU using {model_id}. This should be fast but quality is low.")
generated_image = generate_image_cpu(enhanced_prompt, neg_prompt, guidance, steps)
gr.Info("Image generation complete!")
except gr.Error as e:
# Catch Gradio errors from generation function
# Prepend original enhanced prompt to the error message for context
status_message = f"{enhanced_prompt}\n\n[Image Generation Error: {e}]"
print(f"Image Generation Error: {e}")
generated_image = None # Ensure image is None on error
except Exception as e:
# Catch any other unexpected errors
status_message = f"{enhanced_prompt}\n\n[Unexpected Image Generation Error: {e}]"
print(f"Unexpected Image Generation Error: {e}")
traceback.print_exc()
generated_image = None # Ensure image is None on error
else:
# If prompt enhancement failed, status_message already contains the error
# In this case, we just return the existing status_message and None image
print("Skipping image generation due to prompt enhancement failure.")
# 4. Return results to Gradio UI
# Return the status message (enhanced prompt or error) and the image (or None if error)
return status_message, generated_image
# ---- Gradio Interface Construction ----
style_options = ["cinematic", "photorealistic", "anime", "fantasy art", "cyberpunk", "steampunk", "watercolor", "illustration", "low poly"]
quality_options = ["highly detailed", "sharp focus", "intricate details", "4k", "masterpiece", "best quality", "professional lighting"]
# Tiny model is very fast, steps/guidance might be ignored or have less effect
# Keep sliders but note their limited impact on this specific model
default_steps = 10 # Tiny model often uses few steps internally
max_steps = 20 # Limit max steps as they might not matter much
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# AI Image Generator (Resource-Friendly CPU Version)")
gr.Markdown(
"**Enter a short description or use voice input.** The app uses OpenAI (if API key is provided) "
f"to create a detailed prompt, then generates an image using a **small, fast model ({model_id}) on the CPU**."
)
# Add specific warning about image quality for the tiny model
gr.HTML("<p style='color:orange;font-weight:bold;'>⚠️ Note: Using a small model for compatibility. Image quality and resolution will be significantly lower than models like Stable Diffusion.</p>")
# Display OpenAI availability status
if not openai_available:
gr.Markdown("**Note:** OpenAI API key not found or invalid. Prompt enhancement will use a basic fallback.")
else:
gr.Markdown("**Note:** OpenAI API key found. Prompt will be enhanced using OpenAI.")
# Display Model loading status
    if not image_model_available:
gr.Markdown(f"**CRITICAL:** Image generation model ({model_id}) failed to load. Image generation is disabled. Check logs.")
with gr.Row():
with gr.Column(scale=1):
# --- Inputs ---
inp_text = gr.Textbox(label="Enter short description", placeholder="e.g., A cute robot drinking coffee on Mars")
# Only show Audio input if ASR model loaded successfully
if asr_pipeline:
inp_audio = gr.Audio(sources=["microphone"], type="filepath", label="Or record your idea (clears text box if used)")
else:
gr.Markdown("**Voice input disabled:** Whisper model failed to load.")
# Using gr.State as a placeholder that holds None
inp_audio = gr.State(None)
# --- Controls (Step 3 requirements met) ---
# Note: These controls might have limited effect on the small model
gr.Markdown("*(Optional controls - Note: These may have limited or no effect on the small model used)*")
# Control 1: Dropdown
inp_style = gr.Dropdown(label="Base Style", choices=style_options, value="cinematic", interactive=True)
# Control 2: Radio
inp_quality = gr.Radio(label="Quality Boost", choices=quality_options, value="highly detailed", interactive=True)
# Control 3: Textbox (Negative Prompt)
inp_neg_prompt = gr.Textbox(label="Negative Prompt (optional)", placeholder="e.g., blurry, low quality, text, watermark", interactive=True)
# Control 4: Slider (Guidance Scale)
inp_guidance = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, value=3.0, label="Guidance Scale (CFG)", interactive=True) # Lower default for small model
# Control 5: Slider (Inference Steps) - Reduced max/default
inp_steps = gr.Slider(minimum=1, maximum=max_steps, step=1, value=default_steps, label=f"Inference Steps (lower = faster but less detail, max {max_steps})", interactive=True)
# --- Action Button ---
# Disable button if model failed to load
            btn_generate = gr.Button("Generate Image", variant="primary", interactive=image_model_available)
with gr.Column(scale=1):
# --- Outputs ---
out_prompt = gr.Textbox(label="Generated Prompt / Status", interactive=False, lines=5) # Show prompt or error status here
out_image = gr.Image(label="Generated Image", type="pil", show_label=True) # Ensure label is shown
# --- Event Handling ---
    # Build the inputs list; inp_audio is either a real Audio component or the
    # gr.State(None) placeholder, so it can be passed unconditionally
    inputs_list = [inp_text, inp_audio]
inputs_list.extend([inp_style, inp_quality, inp_neg_prompt, inp_guidance, inp_steps])
# Link button click to processing function
btn_generate.click(
fn=process_input,
inputs=inputs_list,
outputs=[out_prompt, out_image]
)
# Clear text input if audio is used (only if ASR is available)
if asr_pipeline:
def clear_text_on_audio_change(audio_data):
# Check if audio_data is not None or empty (depending on how Gradio signals recording)
if audio_data is not None:
print("Audio input detected, clearing text box.")
return "" # Clear text box
# If audio_data becomes None (e.g., recording cleared), don't clear text
return gr.update()
# .change event fires when the value changes, including becoming None if cleared
inp_audio.change(fn=clear_text_on_audio_change, inputs=inp_audio, outputs=inp_text, api_name="clear_text_on_audio")
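    # Possible extension (sketch only, not wired up): a ClearButton to reset inputs and
    # outputs in one click; gr.ClearButton is available in recent Gradio releases:
    # gr.ClearButton(components=[inp_text, inp_neg_prompt, out_prompt, out_image], value="Clear")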
# ---- Application Launch ----
if __name__ == "__main__":
# Final check before launch
    if not image_model_available:
print("\n" + "="*50)
print("CRITICAL WARNING:")
print(f"Image generation model ({model_id}) failed to load during startup.")
print("The Gradio UI will launch, but the 'Generate Image' button will be disabled.")
print("Check the logs above for the specific model loading error.")
print("="*50 + "\n")
# Launch the Gradio app
# Running on 0.0.0.0 is necessary for Hugging Face Spaces
    demo.launch(share=False, server_name="0.0.0.0", server_port=7860)