import os
import sys
import random
import torch
import pickle
import numpy as np
from PIL import Image
import torch.nn.functional as F
import gradio as gr
from omegaconf import OmegaConf
from scipy.stats import truncnorm
import subprocess
import traceback
import time

# Create a flag to track model loading status
models_loaded_successfully = False

# First run the download_models.py script if models haven't been downloaded
if not os.path.exists('data/state_epoch_1220.pth') or not os.path.exists('data/text_encoder200.pth') or not os.path.exists('data/captions_DAMSM.pickle'):
    print("Downloading necessary model files...")
    try:
        subprocess.check_call([sys.executable, "download_models.py"])
    except subprocess.CalledProcessError as e:
        print(f"Error downloading models: {e}")
        print("Please check the error message above. The application will attempt to continue with fallback settings.")

# Setup system paths
try:
    # Add the code directory to the Python path
    sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "DF-GAN/code"))

    # Import necessary modules from the DF-GAN code
    from models.DAMSM import RNN_ENCODER
    from models.GAN import NetG
except ImportError as e:
    print(f"Error importing required modules: {e}")
    print("The application may not function correctly.")


# Utility functions
def load_model_weights(model, weights, multi_gpus=False, train=False):
    """Load model weights with proper handling of the 'module.' prefix."""
    try:
        if list(weights.keys())[0].find('module') == -1:
            pretrained_with_multi_gpu = False
        else:
            pretrained_with_multi_gpu = True

        if not multi_gpus or not train:
            if pretrained_with_multi_gpu:
                # Strip the 'module.' prefix added by DataParallel training
                state_dict = {key[7:]: value for key, value in weights.items()}
            else:
                state_dict = weights
        else:
            state_dict = weights

        model.load_state_dict(state_dict)
    except Exception as e:
        print(f"Error loading model weights: {e}")
        print("Using model with random weights instead.")
    return model


def get_tokenizer():
    """Get NLTK tokenizer"""
    from nltk.tokenize import RegexpTokenizer
    tokenizer = RegexpTokenizer(r'\w+')
    return tokenizer


def truncated_noise(batch_size=1, dim_z=100, truncation=1.0, seed=None):
    """Generate truncated noise"""
    state = None if seed is None else np.random.RandomState(seed)
    values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state).astype(np.float32)
    return truncation * values


def tokenize_and_build_captions(input_text, wordtoix):
    """Tokenize text and convert it to indices using the wordtoix mapping."""
    tokenizer = get_tokenizer()
    tokens = tokenizer.tokenize(input_text.lower())
    cap = []
    for t in tokens:
        t = t.encode('ascii', 'ignore').decode('ascii')
        if len(t) > 0 and t in wordtoix:
            cap.append(wordtoix[t])

    # Create padded array for the caption
    max_len = 18  # As defined in bird.yml
    cap_array = np.zeros(max_len, dtype='int64')
    cap_len = len(cap)
    if cap_len <= max_len:
        cap_array[:cap_len] = cap
    else:
        # Truncate if too long
        cap_array[:] = cap[:max_len]
        cap_len = max_len
    return cap_array, cap_len


def encode_caption(caption, caption_len, text_encoder, device):
    """Encode a caption using the text encoder."""
    try:
        with torch.no_grad():
            caption = torch.tensor([caption]).to(device)
            caption_len = torch.tensor([caption_len]).to(device)
            hidden = text_encoder.init_hidden(1)
            _, sent_emb = text_encoder(caption, caption_len, hidden)
            return sent_emb
    except Exception as e:
        print(f"Error encoding caption: {e}")
        # Return a random embedding as fallback
        return torch.randn(1, 256).to(device)
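# Illustrative sketch of the caption pipeline above (not executed; assumes the
# vocabulary, text encoder, and device defined further below have been loaded):
#
#   cap_array, cap_len = tokenize_and_build_captions("a small red bird", wordtoix)
#   # cap_array: int64 array of length 18, zero-padded; cap_len: number of in-vocabulary tokens
#   sent_emb = encode_caption(cap_array, cap_len, text_encoder, device)
#   # sent_emb: tensor of shape (1, cond_dim) == (1, 256), the sentence embedding
#   # that conditions the generator.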
def save_img(img_tensor):
    """Convert an image tensor to a PIL Image."""
    try:
        im = img_tensor.data.cpu().numpy()
        # [-1, 1] --> [0, 255]
        im = (im + 1.0) * 127.5
        im = im.astype(np.uint8)
        im = np.transpose(im, (1, 2, 0))
        im = Image.fromarray(im)
        return im
    except Exception as e:
        print(f"Error converting image tensor to PIL Image: {e}")
        # Return a red placeholder image as fallback
        return Image.new('RGB', (256, 256), color='red')


# Load configuration
config = {
    'z_dim': 100,
    'cond_dim': 256,
    'imsize': 256,
    'nf': 32,
    'ch_size': 3,
    'truncation': True,
    'trunc_rate': 0.88,
}

# Determine device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Global variables for models
wordtoix = {}
ixtoword = {}
text_encoder = None
netG = None
models_loaded = False


# Load vocab and models
def load_models():
    global wordtoix, ixtoword, text_encoder, netG, models_loaded, models_loaded_successfully
    try:
        # Load vocabulary
        if os.path.exists('data/captions_DAMSM.pickle'):
            with open('data/captions_DAMSM.pickle', 'rb') as f:
                x = pickle.load(f)
                wordtoix = x[3]
                ixtoword = x[2]
                del x
        else:
            print("Warning: captions_DAMSM.pickle not found. Using fallback vocabulary.")
            # Fallback vocabulary
            wordtoix = {"the": 1, "bird": 2, "is": 3, "a": 4, "with": 5, "and": 6, "red": 7, "black": 8, "yellow": 9}
            ixtoword = {v: k for k, v in wordtoix.items()}

        # Initialize text encoder
        text_encoder = RNN_ENCODER(len(wordtoix), nhidden=config['cond_dim'])
        text_encoder_path = 'data/text_encoder200.pth'
        if os.path.exists(text_encoder_path):
            state_dict = torch.load(text_encoder_path, map_location='cpu')
            text_encoder = load_model_weights(text_encoder, state_dict)
        else:
            print("Warning: text_encoder200.pth not found. Using random weights.")
        text_encoder.to(device)
        for p in text_encoder.parameters():
            p.requires_grad = False
        text_encoder.eval()

        # Initialize generator
        netG = NetG(config['nf'], config['z_dim'], config['cond_dim'], config['imsize'], config['ch_size'])
        netG_path = 'data/state_epoch_1220.pth'
        if os.path.exists(netG_path):
            state_dict = torch.load(netG_path, map_location='cpu')
            if 'model' in state_dict and 'netG' in state_dict['model']:
                netG = load_model_weights(netG, state_dict['model']['netG'])
                models_loaded_successfully = True
            else:
                print("Warning: state_epoch_1220.pth has unexpected format. Using random weights.")
        else:
            print("Warning: state_epoch_1220.pth not found. Using random weights.")
        netG.to(device)
        netG.eval()

        models_loaded = True
        return wordtoix, ixtoword, text_encoder, netG
    except Exception as e:
        print(f"Error loading models: {e}")
        traceback.print_exc()
        print("Using fallback models instead.")

        # Fallback vocabulary
        wordtoix = {"the": 1, "bird": 2, "is": 3, "a": 4, "with": 5, "and": 6, "red": 7, "black": 8, "yellow": 9}
        ixtoword = {v: k for k, v in wordtoix.items()}

        # Create fallback models
        try:
            text_encoder = RNN_ENCODER(len(wordtoix), nhidden=config['cond_dim']).to(device)
            netG = NetG(config['nf'], config['z_dim'], config['cond_dim'], config['imsize'], config['ch_size']).to(device)
            models_loaded = False
        except Exception as e2:
            print(f"Failed to create fallback models: {e2}")

        return wordtoix, ixtoword, text_encoder, netG
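# Note on the generator checkpoint layout assumed by load_models above: the DF-GAN
# training checkpoint is expected to be a nested dict with the generator weights
# under state_dict['model']['netG']. An illustrative way to verify this before
# loading (not executed here):
#
#   ckpt = torch.load('data/state_epoch_1220.pth', map_location='cpu')
#   print(ckpt.keys())           # should include 'model'
#   print(ckpt['model'].keys())  # should include 'netG'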
# Try to load the models
try:
    wordtoix, ixtoword, text_encoder, netG = load_models()
except Exception as e:
    print(f"Error during model loading: {e}")
    print("The application will attempt to continue but may not function correctly.")


def generate_image(text_input, num_images=1, seed=None):
    """Generate images from a text description."""
    if not text_input.strip():
        return [Image.new('RGB', (256, 256), color='lightgray')] * num_images

    try:
        cap_array, cap_len = tokenize_and_build_captions(text_input, wordtoix)
        if cap_len == 0:
            return [Image.new('RGB', (256, 256), color='red')] * num_images

        sent_emb = encode_caption(cap_array, cap_len, text_encoder, device)

        # Set random seed if provided
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        # Generate multiple images if requested
        result_images = []
        with torch.no_grad():
            for _ in range(num_images):
                # Generate noise
                if config['truncation']:
                    noise = truncated_noise(1, config['z_dim'], config['trunc_rate'])
                    noise = torch.tensor(noise, dtype=torch.float).to(device)
                else:
                    noise = torch.randn(1, config['z_dim']).to(device)

                # Generate image
                try:
                    fake_img = netG(noise, sent_emb)
                    img = save_img(fake_img[0])
                    result_images.append(img)
                except Exception as e:
                    print(f"Error generating image: {e}")
                    # Return a placeholder image as fallback
                    img = Image.new('RGB', (256, 256), color=(255, 200, 200))
                    result_images.append(img)

        return result_images
    except Exception as e:
        print(f"Error in generate_image: {e}")
        traceback.print_exc()
        return [Image.new('RGB', (256, 256), color='orange')] * num_images
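# Illustrative usage of generate_image (not executed at import time; meaningful
# output requires the pretrained weights to have loaded, and the output filename
# below is hypothetical):
#
#   images = generate_image("a small bird with a red head and black wings",
#                           num_images=2, seed=42)
#   images[0].save("bird_0.png")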
# Create a simple message for model loading status
model_status = (
    "✅ Models loaded successfully"
    if models_loaded_successfully
    else "⚠️ Using fallback models (random weights) - generated images will not be meaningful"
)


# Function to render the error page if needed
def serve_error_page():
    if os.path.exists('error_page.html'):
        with open('error_page.html', 'r') as f:
            return f.read()
    else:
        return "The application failed to load the required models."


# Create Gradio interface
def generate_images_interface(text, num_images, random_seed):
    seed = int(random_seed) if random_seed and random_seed.strip().isdigit() else None
    return generate_image(text, num_images, seed)


# Create the Gradio interface
with gr.Blocks(title="Bird Image Generator") as demo:
    if models_loaded_successfully:
        # Normal interface when models loaded successfully
        gr.Markdown("# Bird Image Generator using DF-GAN")
        gr.Markdown("Enter a description of a bird and the model will generate corresponding images.")
        gr.Markdown(f"**Model Status:** {model_status}")

        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    label="Bird Description",
                    placeholder="Enter a description of a bird (e.g., 'a small bird with a red head and black wings')",
                    lines=3
                )
                num_images = gr.Slider(minimum=1, maximum=4, value=1, step=1, label="Number of Images")
                seed = gr.Textbox(label="Random Seed (optional)", placeholder="Leave empty for random results")
                submit_btn = gr.Button("Generate Image")
            with gr.Column():
                image_output = gr.Gallery(label="Generated Images").style(grid=2, height="auto")

        submit_btn.click(
            fn=generate_images_interface,
            inputs=[text_input, num_images, seed],
            outputs=image_output
        )

        gr.Markdown("## Example Descriptions")
        example_descriptions = [
            "this bird has an orange bill, a white belly and white eyebrows",
            "a small bird with a red head, breast, and belly and black wings",
            "this bird is yellow with black and has a long, pointy beak",
            "this bird is white in color, and has a orange beak"
        ]
        gr.Examples(
            examples=[[desc, 1, ""] for desc in example_descriptions],
            inputs=[text_input, num_images, seed],
            outputs=image_output,
            fn=generate_images_interface
        )
    else:
        # Modified interface with warning when models failed to load
        gr.Markdown("# ⚠️ Bird Image Generator - Limited Functionality")
        gr.Markdown("The pre-trained models could not be loaded correctly. The application will run with randomly initialized models.")

        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    label="Bird Description",
                    placeholder="Enter a description of a bird (e.g., 'a small bird with a red head and black wings')",
                    lines=3
                )
                num_images = gr.Slider(minimum=1, maximum=4, value=1, step=1, label="Number of Images")
                seed = gr.Textbox(label="Random Seed (optional)", placeholder="Leave empty for random results")
                submit_btn = gr.Button("Generate Image (Results will be random shapes)")
            with gr.Column():
                image_output = gr.Gallery(label="Generated Images (Random)").style(grid=2, height="auto")

        submit_btn.click(
            fn=generate_images_interface,
            inputs=[text_input, num_images, seed],
            outputs=image_output
        )

        gr.Markdown("""
        ### Model Loading Error

        The application encountered an error while loading the pre-trained models. This could be due to:

        1. Network connectivity issues
        2. The model hosting service might be temporarily unavailable
        3. The model files might have been moved or deleted

        Please try refreshing the page or contact the Space owner if the issue persists.
        """)

# Launch the app with appropriate configurations for Hugging Face Spaces
if __name__ == "__main__":
    # Wait a moment before starting to make sure all logs are printed
    time.sleep(1)
    demo.launch(
        server_name="0.0.0.0",  # Bind to all network interfaces
        share=False,            # Don't use share links
        favicon_path="https://raw.githubusercontent.com/tobran/DF-GAN/main/framework.png"
    )