import os
import sys
import random
import torch
import pickle
import numpy as np
from PIL import Image
import torch.nn.functional as F
import gradio as gr
from omegaconf import OmegaConf
from scipy.stats import truncnorm
import subprocess
# First run the download_models.py script if models haven't been downloaded
if not os.path.exists('data/state_epoch_1220.pth') or not os.path.exists('data/text_encoder200.pth'):
    print("Downloading necessary model files...")
    try:
        subprocess.check_call([sys.executable, "download_models.py"])
    except subprocess.CalledProcessError as e:
        print(f"Error downloading models: {e}")
        print("Please run download_models.py manually before starting the app.")
# Add the code directory to the Python path
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "DF-GAN/code"))
# Import necessary modules from the DF-GAN code
from models.DAMSM import RNN_ENCODER
from models.GAN import NetG
# Utility functions
def load_model_weights(model, weights, multi_gpus=False, train=False):
    """Load model weights, handling the 'module.' prefix from multi-GPU checkpoints"""
    pretrained_with_multi_gpu = list(weights.keys())[0].find('module') != -1
    if not multi_gpus or not train:
        if pretrained_with_multi_gpu:
            # Strip the 7-character 'module.' prefix from every key
            state_dict = {
                key[7:]: value
                for key, value in weights.items()
            }
        else:
            state_dict = weights
    else:
        state_dict = weights
    model.load_state_dict(state_dict)
    return model
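# Illustrative note: checkpoints saved from an nn.DataParallel-wrapped model
# prefix every key with 'module.' (e.g. 'module.rnn.weight_ih_l0'), which is
# why key[7:] above drops that prefix so the weights load into a single-device
# model. The key name shown is an example, not taken from these checkpoints.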
def get_tokenizer():
    """Get the NLTK regexp tokenizer"""
    from nltk.tokenize import RegexpTokenizer
    tokenizer = RegexpTokenizer(r'\w+')
    return tokenizer
def truncated_noise(batch_size=1, dim_z=100, truncation=1.0, seed=None):
    """Generate truncated noise"""
    state = None if seed is None else np.random.RandomState(seed)
    values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state).astype(np.float32)
    return truncation * values
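# Background: this is the BigGAN-style "truncation trick" -- z is drawn from a
# normal distribution truncated to [-2, 2] and scaled by `truncation`, trading
# sample diversity for fidelity. With trunc_rate=0.88 (see the config below),
# noise values stay within roughly [-1.76, 1.76].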
def tokenize_and_build_captions(input_text, wordtoix):
    """Tokenize text and convert it to word indices using the wordtoix mapping"""
    tokenizer = get_tokenizer()
    tokens = tokenizer.tokenize(input_text.lower())
    cap = []
    for t in tokens:
        t = t.encode('ascii', 'ignore').decode('ascii')
        if len(t) > 0 and t in wordtoix:
            cap.append(wordtoix[t])
    # Create a padded array for the caption
    max_len = 18  # As defined in the bird.yml
    cap_array = np.zeros(max_len, dtype='int64')
    cap_len = len(cap)
    if cap_len <= max_len:
        cap_array[:cap_len] = cap
    else:
        # Truncate if too long (keep the result a numpy array, not a Python list)
        cap_array = np.array(cap[:max_len], dtype='int64')
        cap_len = max_len
    return cap_array, cap_len
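# Illustrative example (the word indices are hypothetical):
#   "A small bird!" -> tokens ['a', 'small', 'bird'] -> e.g. [12, 87, 5],
#   zero-padded to the fixed length 18, with cap_len = 3.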
def encode_caption(caption, caption_len, text_encoder, device):
    """Encode a caption using the text encoder"""
    with torch.no_grad():
        caption = torch.tensor([caption]).to(device)
        caption_len = torch.tensor([caption_len]).to(device)
        hidden = text_encoder.init_hidden(1)
        _, sent_emb = text_encoder(caption, caption_len, hidden)
    return sent_emb
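# Note: RNN_ENCODER is the DAMSM text encoder from the DF-GAN repo; per that
# encoder's usual convention, sent_emb should be a sentence embedding of shape
# (1, cond_dim) that conditions the generator.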
def save_img(img_tensor):
    """Convert an image tensor to a PIL Image"""
    im = img_tensor.data.cpu().numpy()
    # [-1, 1] --> [0, 255]
    im = (im + 1.0) * 127.5
    im = im.astype(np.uint8)
    im = np.transpose(im, (1, 2, 0))
    im = Image.fromarray(im)
    return im
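# Worked example of the rescaling: a pixel value of -1.0 maps to
# (-1.0 + 1.0) * 127.5 = 0, and 1.0 maps to (1.0 + 1.0) * 127.5 = 255.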
# Load configuration
config = {
    'z_dim': 100,
    'cond_dim': 256,
    'imsize': 256,
    'nf': 32,
    'ch_size': 3,
    'truncation': True,
    'trunc_rate': 0.88,
}
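# These values are assumed to mirror DF-GAN's bird.yml (CUB birds) config,
# consistent with the max_len=18 comment in tokenize_and_build_captions above.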
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Load vocab and models
def load_models():
    # Load vocabulary
    with open('data/captions_DAMSM.pickle', 'rb') as f:
        x = pickle.load(f)
        wordtoix = x[3]
        ixtoword = x[2]
        del x
    # Initialize text encoder
    text_encoder = RNN_ENCODER(len(wordtoix), nhidden=config['cond_dim'])
    text_encoder_path = 'data/text_encoder200.pth'
    state_dict = torch.load(text_encoder_path, map_location='cpu')
    text_encoder = load_model_weights(text_encoder, state_dict)
    text_encoder.to(device)
    for p in text_encoder.parameters():
        p.requires_grad = False
    text_encoder.eval()
    # Initialize generator
    netG = NetG(config['nf'], config['z_dim'], config['cond_dim'], config['imsize'], config['ch_size'])
    netG_path = 'data/state_epoch_1220.pth'
    state_dict = torch.load(netG_path, map_location='cpu')
    netG = load_model_weights(netG, state_dict['model']['netG'])
    netG.to(device)
    netG.eval()
    return wordtoix, ixtoword, text_encoder, netG
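# Note on the pickle layout: captions_DAMSM.pickle is assumed to follow the
# AttnGAN/DF-GAN convention of [train_captions, test_captions, ixtoword,
# wordtoix], which is why x[2] is ixtoword and x[3] is wordtoix.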
wordtoix, ixtoword, text_encoder, netG = load_models()
def generate_image(text_input, num_images=1, seed=None):
    """Generate images from a text description"""
    if not text_input.strip():
        return []
    cap_array, cap_len = tokenize_and_build_captions(text_input, wordtoix)
    if cap_len == 0:
        # No tokens matched the vocabulary; return red placeholder images
        return [Image.new('RGB', (256, 256), color='red')] * num_images
    sent_emb = encode_caption(cap_array, cap_len, text_encoder, device)
    # Set random seed if provided
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    # Generate multiple images if requested
    result_images = []
    with torch.no_grad():
        for _ in range(num_images):
            # Generate noise
            if config['truncation']:
                noise = truncated_noise(1, config['z_dim'], config['trunc_rate'])
                noise = torch.tensor(noise, dtype=torch.float).to(device)
            else:
                noise = torch.randn(1, config['z_dim']).to(device)
            # Generate image
            fake_img = netG(noise, sent_emb)
            img = save_img(fake_img[0])
            result_images.append(img)
    return result_images
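# Minimal usage sketch (assumes the model files in data/ are present):
#   images = generate_image("a small bird with a red head and black wings",
#                           num_images=2, seed=42)
#   images[0].save("bird_0.png")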
# Create Gradio interface
def generate_images_interface(text, num_images, random_seed):
    # Guard against whitespace-only seeds and float slider values
    seed = int(random_seed) if random_seed.strip() else None
    return generate_image(text, int(num_images), seed)
with gr.Blocks(title="Bird Image Generator") as demo:
    gr.Markdown("# Bird Image Generator using DF-GAN")
    gr.Markdown("Enter a description of a bird and the model will generate corresponding images.")
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Bird Description",
                placeholder="Enter a description of a bird (e.g., 'a small bird with a red head and black wings')",
                lines=3
            )
            num_images = gr.Slider(minimum=1, maximum=4, value=1, step=1, label="Number of Images")
            seed = gr.Textbox(label="Random Seed (optional)", placeholder="Leave empty for random results")
            submit_btn = gr.Button("Generate Image")
        with gr.Column():
            image_output = gr.Gallery(label="Generated Images").style(grid=2, height="auto")
    submit_btn.click(
        fn=generate_images_interface,
        inputs=[text_input, num_images, seed],
        outputs=image_output
    )
    gr.Markdown("## Example Descriptions")
    example_descriptions = [
        "this bird has an orange bill, a white belly and white eyebrows",
        "a small bird with a red head, breast, and belly and black wings",
        "this bird is yellow with black and has a long, pointy beak",
        "this bird is white in color, and has an orange beak"
    ]
    gr.Examples(
        examples=[[desc, 1, ""] for desc in example_descriptions],
        inputs=[text_input, num_images, seed],
        outputs=image_output,
        fn=generate_images_interface
    )
# Launch the app with appropriate configurations for Hugging Face Spaces
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # Bind to all network interfaces
        share=False,            # Don't use share links
        # Note: favicon_path normally expects a local file path; a remote URL
        # may be ignored by Gradio.
        favicon_path="https://raw.githubusercontent.com/tobran/DF-GAN/main/framework.png"
    )