sohanAI's picture
Upload 9 files
76d118b verified
raw
history blame
8.51 kB
import os
import sys
import random
import torch
import pickle
import numpy as np
from PIL import Image
import torch.nn.functional as F
import gradio as gr
from omegaconf import OmegaConf
from scipy.stats import truncnorm
import subprocess
# On first start-up, fetch the pretrained weights if either checkpoint is
# missing. This is best effort: a failed download is reported and start-up
# continues (loading the models later will then fail if the files are absent).
if not all(os.path.exists(path) for path in ('data/state_epoch_1220.pth', 'data/text_encoder200.pth')):
    print("Downloading necessary model files...")
    try:
        subprocess.check_call([sys.executable, "download_models.py"])
    except subprocess.CalledProcessError as download_error:
        print(f"Error downloading models: {download_error}")
        print("Please run download_models.py manually before starting the app.")

# Put the bundled DF-GAN sources first on the import path (resolved relative to
# this file, not the working directory) so `models.DAMSM` / `models.GAN` import.
sys.path.insert(
    0,
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "DF-GAN/code"),
)
# Import necessary modules from the DF-GAN code
from models.DAMSM import RNN_ENCODER
from models.GAN import NetG
# Utility functions
def load_model_weights(model, weights, multi_gpus=False, train=False):
    """Load a state dict into ``model``, stripping a ``module.`` key prefix when needed.

    Checkpoints saved from a ``torch.nn.DataParallel`` wrapper prefix their keys
    with ``module.``. Unless the target is itself a multi-GPU model being loaded
    for training (``multi_gpus`` and ``train`` both true), that prefix is removed
    so the keys match a plain, unwrapped model.

    Args:
        model: Module whose ``load_state_dict`` receives the weights.
        weights: Mapping of parameter name to tensor.
        multi_gpus: Whether ``model`` is wrapped for multi-GPU use.
        train: Whether the model is being loaded for training.

    Returns:
        ``model`` itself, after loading.
    """
    prefix = 'module.'
    # Detect the wrapper by a genuine key prefix. A bare 'module' substring
    # (e.g. 'submodule.weight') is not a DataParallel prefix. An empty mapping
    # is handled instead of raising IndexError.
    pretrained_with_multi_gpu = any(key.startswith(prefix) for key in weights)
    if pretrained_with_multi_gpu and not (multi_gpus and train):
        # Strip the prefix only from keys that actually carry it.
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in weights.items()
        }
    else:
        state_dict = weights
    model.load_state_dict(state_dict)
    return model
def get_tokenizer():
    """Return an NLTK ``RegexpTokenizer`` that keeps runs of word characters.

    NLTK is imported lazily, only when a tokenizer is actually requested.
    """
    from nltk.tokenize import RegexpTokenizer

    word_pattern = r'\w+'
    return RegexpTokenizer(word_pattern)
def truncated_noise(batch_size=1, dim_z=100, truncation=1.0, seed=None):
    """Sample latent noise from a standard normal truncated to [-2, 2].

    Args:
        batch_size: Number of rows to draw.
        dim_z: Length of each noise vector.
        truncation: Factor the float32 samples are multiplied by.
        seed: If given, draws come from a private ``np.random.RandomState(seed)``
            and are reproducible. With ``None``, scipy uses its default
            random state.

    Returns:
        A float32 array of shape ``(batch_size, dim_z)``.
    """
    rng = np.random.RandomState(seed) if seed is not None else None
    samples = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=rng)
    return samples.astype(np.float32) * truncation
def tokenize_and_build_captions(input_text, wordtoix, max_len=18):
    """Tokenize text and convert it to a fixed-length array of vocabulary indices.

    Args:
        input_text: Free-form caption. It is lower-cased before tokenizing.
        wordtoix: Mapping from word to integer index. Tokens missing from it
            (after dropping non-ASCII characters) are skipped.
        max_len: Length of the returned array. The default of 18 matches the
            original note that it is defined in bird.yml. Longer captions are
            truncated to their first ``max_len`` known words.

    Returns:
        ``(cap_array, cap_len)``:
        * ``cap_array`` is always an int64 ndarray of length ``max_len``,
          zero-padded after the first ``cap_len`` entries. Previously the
          truncation path returned a plain list.
        * ``cap_len`` is the number of valid indices (``<= max_len``).
    """
    tokenizer = get_tokenizer()
    cap = []
    for token in tokenizer.tokenize(input_text.lower()):
        # Drop any non-ASCII characters before the vocabulary lookup.
        token = token.encode('ascii', 'ignore').decode('ascii')
        if token and token in wordtoix:
            cap.append(wordtoix[token])

    # Truncate, then pad, so both paths produce the same ndarray type and shape.
    cap = cap[:max_len]
    cap_array = np.zeros(max_len, dtype='int64')
    cap_array[:len(cap)] = cap
    return cap_array, len(cap)
def encode_caption(caption, caption_len, text_encoder, device):
    """Return the sentence embedding for one caption.

    The caption and its length are each wrapped into a batch of one and moved
    to ``device``. They are run through ``text_encoder`` with a fresh hidden
    state from ``text_encoder.init_hidden(1)``, with gradients disabled. The
    encoder is expected to return a pair; only its second element (the
    sentence embedding) is kept.
    """
    with torch.no_grad():
        batch_captions = torch.tensor([caption]).to(device)
        batch_lengths = torch.tensor([caption_len]).to(device)
        _words_emb, sentence_embedding = text_encoder(
            batch_captions, batch_lengths, text_encoder.init_hidden(1)
        )
        return sentence_embedding
def save_img(img_tensor):
    """Convert a generated image tensor into a PIL image.

    Despite the name, nothing is written to disk.

    Args:
        img_tensor: One image laid out channels-first, with values nominally
            in [-1, 1]. Presumably shape (3, H, W); TODO confirm against NetG.

    Returns:
        A ``PIL.Image.Image`` built from the HWC uint8 array.
    """
    im = img_tensor.data.cpu().numpy()
    # [-1, 1] --> [0, 255]. Clip before the uint8 cast: values that stray
    # outside [-1, 1] would otherwise overflow/wrap instead of saturating.
    im = np.clip((im + 1.0) * 127.5, 0, 255).astype(np.uint8)
    # CHW -> HWC, the layout PIL expects.
    im = np.transpose(im, (1, 2, 0))
    return Image.fromarray(im)
# Generator and sampling hyper-parameters. They are hard-coded here rather
# than read from a config file.
config = dict(
    z_dim=100,          # length of the noise vector fed to NetG
    cond_dim=256,       # sentence-embedding size; also the text encoder's nhidden
    imsize=256,         # passed to NetG
    nf=32,              # passed to NetG
    ch_size=3,          # passed to NetG
    truncation=True,    # sample truncated noise instead of torch.randn
    trunc_rate=0.88,    # scale applied to the truncated noise
)

# Use the GPU whenever torch can see one.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
# Load vocab and models
def load_models():
    """Load the vocabulary, the DAMSM text encoder and the DF-GAN generator.

    Reads three files via paths relative to the current working directory:
    ``data/captions_DAMSM.pickle``, ``data/text_encoder200.pth`` and
    ``data/state_epoch_1220.pth``. Both networks are moved to ``device`` and
    switched to eval mode. The text encoder's parameters are also frozen.

    Returns:
        ``(wordtoix, ixtoword, text_encoder, netG)``.
    """
    # NOTE(review): pickle.load can execute code embedded in the file; only
    # load a vocabulary file you trust.
    with open('data/captions_DAMSM.pickle', 'rb') as vocab_file:
        vocab_data = pickle.load(vocab_file)
    ixtoword, wordtoix = vocab_data[2], vocab_data[3]
    del vocab_data

    # Text encoder: vocabulary-sized input, hidden size = cond_dim, frozen.
    text_encoder = RNN_ENCODER(len(wordtoix), nhidden=config['cond_dim'])
    text_encoder = load_model_weights(
        text_encoder,
        torch.load('data/text_encoder200.pth', map_location='cpu'),
    )
    text_encoder.to(device)
    for weight in text_encoder.parameters():
        weight.requires_grad = False
    text_encoder.eval()

    # Generator: its weights live under checkpoint['model']['netG'].
    netG = NetG(config['nf'], config['z_dim'], config['cond_dim'], config['imsize'], config['ch_size'])
    checkpoint = torch.load('data/state_epoch_1220.pth', map_location='cpu')
    netG = load_model_weights(netG, checkpoint['model']['netG'])
    netG.to(device)
    netG.eval()

    return wordtoix, ixtoword, text_encoder, netG


# Load everything once at import time; generate_image reads these globals.
wordtoix, ixtoword, text_encoder, netG = load_models()
def generate_image(text_input, num_images=1, seed=None):
    """Generate bird images from a text description.

    Args:
        text_input: Bird description. ``None`` or whitespace-only text yields
            no images.
        num_images: Number of images to sample. It is coerced with ``int()``
            because a UI slider may deliver it as a float, which ``range`` and
            list repetition reject.
        seed: Optional integer used to seed Python's ``random``, NumPy and
            torch (plus all CUDA devices when available) before sampling, so
            the same text and seed reproduce the same images. It must be a
            value ``np.random.seed`` accepts.

    Returns:
        A list of PIL images:
        * ``[]`` for empty input. ``None`` placeholders are not valid image
          entries for a gallery output, so nothing is returned instead.
        * ``num_images`` solid red 256x256 placeholders when none of the words
          are in the vocabulary.
        * Otherwise, ``num_images`` generated images.
    """
    num_images = int(num_images)
    if text_input is None or not text_input.strip():
        return []

    cap_array, cap_len = tokenize_and_build_captions(text_input, wordtoix)
    if cap_len == 0:
        # No known words: signal this visually rather than conditioning on an
        # empty caption.
        return [Image.new('RGB', (256, 256), color='red')] * num_images

    sent_emb = encode_caption(cap_array, cap_len, text_encoder, device)

    # Seed every RNG the sampling below may touch.
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    result_images = []
    with torch.no_grad():
        for _ in range(num_images):
            if config['truncation']:
                # truncated_noise gets no seed of its own, so it relies on
                # scipy's default random state (NumPy's global RNG, seeded above).
                noise = truncated_noise(1, config['z_dim'], config['trunc_rate'])
                noise = torch.tensor(noise, dtype=torch.float).to(device)
            else:
                noise = torch.randn(1, config['z_dim']).to(device)
            fake_img = netG(noise, sent_emb)
            result_images.append(save_img(fake_img[0]))
    return result_images
# Create Gradio interface
def generate_images_interface(text, num_images, random_seed):
    """Gradio callback: parse the optional seed and delegate to ``generate_image``.

    Args:
        text: Bird description from the textbox.
        num_images: Value of the image-count slider.
        random_seed: Seed as entered by the user (normally a string). Blank or
            ``None`` means "random"; surrounding whitespace is ignored, and an
            integer 0 is a real seed rather than "no seed".

    Returns:
        The list produced by ``generate_image``.

    Raises:
        ValueError: If a non-blank seed cannot be converted with ``int()``.
    """
    if isinstance(random_seed, str):
        random_seed = random_seed.strip()
    if random_seed is None or random_seed == '':
        seed = None
    else:
        try:
            seed = int(random_seed)
        except ValueError:
            raise ValueError(f"Random seed must be an integer, got {random_seed!r}") from None
        # np.random.seed only accepts 0 <= seed < 2**32; wrap negative or huge
        # seeds into that range so they remain usable and deterministic.
        seed %= 2 ** 32
    return generate_image(text, num_images, seed)
# Build the UI: description box, image-count slider, optional seed and a
# Generate button on the left; a gallery of results on the right; example
# descriptions underneath. Component creation order defines the layout.
with gr.Blocks(title="Bird Image Generator") as demo:
    gr.Markdown("# Bird Image Generator using DF-GAN")
    gr.Markdown("Enter a description of a bird and the model will generate corresponding images.")
    with gr.Row():
        # Input column
        with gr.Column():
            text_input = gr.Textbox(
                label="Bird Description",
                placeholder="Enter a description of a bird (e.g., 'a small bird with a red head and black wings')",
                lines=3
            )
            # 1 to 4 images per request, in steps of 1
            num_images = gr.Slider(minimum=1, maximum=4, value=1, step=1, label="Number of Images")
            # Free-text seed; generate_images_interface converts a non-empty value to an int
            seed = gr.Textbox(label="Random Seed (optional)", placeholder="Leave empty for random results")
            submit_btn = gr.Button("Generate Image")
        # Output column
        with gr.Column():
            # NOTE(review): Component.style() is deprecated in later Gradio 3.x
            # releases and removed in Gradio 4 — pin gradio<4 or migrate to
            # constructor arguments; confirm the deployed version.
            image_output = gr.Gallery(label="Generated Images").style(grid=2, height="auto")
    # The inputs are passed positionally as (text, num_images, random_seed).
    submit_btn.click(
        fn=generate_images_interface,
        inputs=[text_input, num_images, seed],
        outputs=image_output
    )
    gr.Markdown("## Example Descriptions")
    example_descriptions = [
        "this bird has an orange bill, a white belly and white eyebrows",
        "a small bird with a red head, breast, and belly and black wings",
        "this bird is yellow with black and has a long, pointy beak",
        "this bird is white in color, and has a orange beak"
    ]
    # Each example row fills the three inputs: description, 1 image, empty seed.
    gr.Examples(
        examples=[[desc, 1, ""] for desc in example_descriptions],
        inputs=[text_input, num_images, seed],
        outputs=image_output,
        fn=generate_images_interface
    )
# Launch the app with settings suited to Hugging Face Spaces. This only runs
# when the file is executed directly, not when it is imported.
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # Bind to all network interfaces
        share=False,  # Don't use share links
        # favicon_path is intentionally not set: Gradio expects a local image
        # file there, so a remote http(s) URL cannot be served as the favicon.
    )