sohanAI committed on
Commit
76d118b
·
verified ·
1 Parent(s): 65125c9

Upload 9 files

Files changed (9)
  1. .gitignore +36 -0
  2. .huggingface-space +9 -0
  3. README-HF.md +31 -0
  4. README.md +34 -13
  5. app.py +239 -0
  6. demo.ipynb +1 -0
  7. download_models.py +56 -0
  8. requirements.txt +16 -0
  9. startup.sh +10 -0
.gitignore ADDED
@@ -0,0 +1,36 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ env/
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib64/
+ parts/
+ sdist/
+ var/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # Data directories
+ data/
+ DF-GAN/
+
+ # Model files
+ *.pth
+ *.pickle
+ *.npz
+
+ # Generated images
+ samples/
+ .DS_Store
.huggingface-space ADDED
@@ -0,0 +1,9 @@
+ title: DF-GAN Bird Image Generator
+ emoji: 🐦
+ colorFrom: blue
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 3.50.0
+ app_file: app.py
+ pinned: false
+ license: cc-by-nc-sa-4.0
README-HF.md ADDED
@@ -0,0 +1,31 @@
+ # DF-GAN Bird Image Generator 🐦
+
+ This Hugging Face Space demonstrates the [DF-GAN model](https://arxiv.org/abs/2008.05865) for generating bird images from text descriptions.
+
+ ## How It Works
+
+ 1. Enter a text description of a bird
+ 2. Select how many images you want to generate (1-4)
+ 3. Optionally add a random seed for reproducible results
+ 4. Click "Generate Image"
+ 5. The model will generate realistic bird images based on your description
+
+ ## Example Descriptions
+
+ Try these example descriptions:
+ - "this bird has an orange bill, a white belly and white eyebrows"
+ - "a small bird with a red head, breast, and belly and black wings"
+ - "this bird is yellow with black and has a long, pointy beak"
+ - "this is a grey bodied bird with light grey wings and a white breast"
+
+ ## About the Model
+
+ The DF-GAN (Deep Fusion GAN) model is a text-to-image synthesis model introduced in the paper "DF-GAN: A Simple and Effective Baseline for Text-to-Image Synthesis" (CVPR 2022). This demo uses the pre-trained bird model that was trained on the CUB-200-2011 dataset.
+
+ This demo runs on CPU, so image generation may take a few seconds.
+
+ ## Credits
+
+ This Space uses the official implementation of DF-GAN from [tobran/DF-GAN](https://github.com/tobran/DF-GAN).
+
+ Made with ❤️ by [Your Name]
README.md CHANGED
@@ -1,13 +1,34 @@
- ---
- title: Df Gan Text To Image
- emoji: 🐨
- colorFrom: pink
- colorTo: purple
- sdk: gradio
- sdk_version: 5.23.0
- app_file: app.py
- pinned: false
- short_description: DF-GAN Text to Image Generation
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # DF-GAN Bird Image Generator
+
+ This application uses the DF-GAN (Deep Fusion GAN) model to generate bird images based on text descriptions. Just enter a description of a bird, and the model will generate a realistic image that matches your description.
+
+ ## About the Model
+
+ This application uses the pre-trained bird model from the [DF-GAN: A Simple and Effective Baseline for Text-to-Image Synthesis](https://arxiv.org/abs/2008.05865) paper (CVPR 2022). DF-GAN is a text-to-image synthesis model that can generate high-quality images from textual descriptions.
+
+ ## How to Use
+
+ 1. Enter a description of a bird in the text box (e.g., "a yellow bird with a black head")
+ 2. Choose how many images you want to generate (1-4)
+ 3. Optionally, set a random seed for reproducible results
+ 4. Click the "Generate Image" button
+ 5. View the generated bird images that match your description
+
+ ## Examples
+
+ Try these example descriptions:
+ - "this bird has an orange bill, a white belly and white eyebrows"
+ - "a small bird with a red head, breast, and belly and black wings"
+ - "this bird is yellow with black and has a long, pointy beak"
+ - "this bird is white in color, and has an orange beak"
+
+ ## Implementation Details
+
+ This application uses the following components:
+ - DF-GAN architecture for text-to-image synthesis
+ - DAMSM text encoder for embedding text descriptions
+ - Gradio for the web interface
+
+ ## Credits
+
+ This implementation is based on the official DF-GAN repository: [tobran/DF-GAN](https://github.com/tobran/DF-GAN)
app.py ADDED
@@ -0,0 +1,239 @@
+ import os
+ import sys
+ import random
+ import torch
+ import pickle
+ import numpy as np
+ from PIL import Image
+ import torch.nn.functional as F
+ import gradio as gr
+ from omegaconf import OmegaConf
+ from scipy.stats import truncnorm
+ import subprocess
+
+ # Run download_models.py first if any required model, vocabulary, or code file is missing
+ REQUIRED_FILES = [
+     'data/state_epoch_1220.pth',
+     'data/text_encoder200.pth',
+     'data/captions_DAMSM.pickle',
+     'DF-GAN/code/models/GAN.py',
+ ]
+ if not all(os.path.exists(path) for path in REQUIRED_FILES):
+     print("Downloading necessary model files...")
+     try:
+         subprocess.check_call([sys.executable, "download_models.py"])
+     except subprocess.CalledProcessError as e:
+         print(f"Error downloading models: {e}")
+         print("Please run download_models.py manually before starting the app.")
+
+ # Add the code directory to the Python path
+ sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "DF-GAN/code"))
+
+ # Import necessary modules from the DF-GAN code
+ from models.DAMSM import RNN_ENCODER
+ from models.GAN import NetG
+
+ # Utility functions
+ def load_model_weights(model, weights, multi_gpus=False, train=False):
+     """Load model weights with proper handling of module prefix"""
+     if list(weights.keys())[0].find('module')==-1:
+         pretrained_with_multi_gpu = False
+     else:
+         pretrained_with_multi_gpu = True
+
+     if (multi_gpus==False) or (train==False):
+         if pretrained_with_multi_gpu:
+             state_dict = {
+                 key[7:]: value
+                 for key, value in weights.items()
+             }
+         else:
+             state_dict = weights
+     else:
+         state_dict = weights
+
+     model.load_state_dict(state_dict)
+     return model
+
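A note on the prefix handling in `load_model_weights`: checkpoints saved from a model wrapped in `torch.nn.DataParallel` carry a `module.` prefix (7 characters) on every key, which is what `key[7:]` removes. The same idea can be sketched on a plain dictionary. `strip_module_prefix` is a hypothetical helper, not part of the DF-GAN code, and unlike the function above it only touches keys that actually start with the prefix.

```python
from collections import OrderedDict


def strip_module_prefix(weights, prefix="module."):
    """Return a copy of `weights` with a leading `prefix` removed from each key."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in weights.items()
    )


# Toy stand-in for a checkpoint; the values would normally be tensors.
saved = OrderedDict([("module.fc.weight", 1), ("module.fc.bias", 2)])
cleaned = strip_module_prefix(saved)
print(list(cleaned))
```

Checking each key individually avoids cutting characters from keys that never had the prefix, which the first-key test in `load_model_weights` does not guard against.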
+ def get_tokenizer():
+     """Get NLTK tokenizer"""
+     from nltk.tokenize import RegexpTokenizer
+     tokenizer = RegexpTokenizer(r'\w+')
+     return tokenizer
+
+ def truncated_noise(batch_size=1, dim_z=100, truncation=1.0, seed=None):
+     """Generate truncated noise"""
+     state = None if seed is None else np.random.RandomState(seed)
+     values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state).astype(np.float32)
+     return truncation * values
+
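`truncated_noise` draws standard-normal values restricted to [-2, 2] and scales them by `truncation`, so every component ends up within ±2 × `truncation`. The sketch below illustrates that bound and the effect of a seed using only the standard library. It swaps `scipy.stats.truncnorm` for plain rejection sampling, and `truncated_noise_py` is a hypothetical name used only here.

```python
import random


def truncated_noise_py(dim_z=100, truncation=1.0, seed=None, bound=2.0):
    """Sample `dim_z` normal values inside [-bound, bound], scaled by `truncation`."""
    rng = random.Random(seed)
    values = []
    while len(values) < dim_z:
        sample = rng.gauss(0.0, 1.0)
        if -bound <= sample <= bound:  # reject samples outside the truncation window
            values.append(truncation * sample)
    return values


noise_a = truncated_noise_py(dim_z=100, truncation=0.88, seed=0)
noise_b = truncated_noise_py(dim_z=100, truncation=0.88, seed=0)
print(noise_a == noise_b, max(abs(v) for v in noise_a) <= 0.88 * 2.0)
```

A smaller `truncation` keeps the noise closer to zero, which usually trades sample diversity for more typical-looking outputs.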
+ def tokenize_and_build_captions(input_text, wordtoix):
+     """Tokenize text and convert to indices using wordtoix mapping"""
+     tokenizer = get_tokenizer()
+     tokens = tokenizer.tokenize(input_text.lower())
+     cap = []
+     for t in tokens:
+         t = t.encode('ascii', 'ignore').decode('ascii')
+         if len(t) > 0 and t in wordtoix:
+             cap.append(wordtoix[t])
+
+     # Create padded array for the caption
+     max_len = 18  # As defined in the bird.yml
+     cap_array = np.zeros(max_len, dtype='int64')
+     cap_len = len(cap)
+     if cap_len <= max_len:
+         cap_array[:cap_len] = cap
+     else:
+         # Truncate if too long, keeping cap_array a fixed-size int64 array
+         cap_array[:] = cap[:max_len]
+         cap_len = max_len
+
+     return cap_array, cap_len
+
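The caption pipeline above (lowercase, split on `\w+`, drop out-of-vocabulary words, pad or truncate to 18 indices) can be tried without NLTK or NumPy. This is a simplified sketch: `re.findall` stands in for `RegexpTokenizer`, the ASCII filtering step is omitted, and `build_caption` and `toy_vocab` are hypothetical names with made-up indices.

```python
import re

MAX_LEN = 18  # same caption length as in tokenize_and_build_captions


def build_caption(text, wordtoix, max_len=MAX_LEN):
    """Return (padded index list, number of real tokens) for `text`."""
    tokens = re.findall(r"\w+", text.lower())
    indices = [wordtoix[t] for t in tokens if t in wordtoix][:max_len]
    padded = indices + [0] * (max_len - len(indices))  # pad with zeros on the right
    return padded, len(indices)


toy_vocab = {"a": 1, "small": 2, "red": 3, "bird": 4}
padded, length = build_caption("A small, red bird!", toy_vocab)
print(padded[:5], length)  # [1, 2, 3, 4, 0] 4
```

Words missing from the vocabulary are silently skipped, so a description made entirely of unknown words yields a length of 0, which is the case `generate_image` answers with red placeholder images.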
+ def encode_caption(caption, caption_len, text_encoder, device):
+     """Encode caption using text encoder"""
+     with torch.no_grad():
+         caption = torch.tensor([caption]).to(device)
+         caption_len = torch.tensor([caption_len]).to(device)
+         hidden = text_encoder.init_hidden(1)
+         _, sent_emb = text_encoder(caption, caption_len, hidden)
+     return sent_emb
+
+ def save_img(img_tensor):
+     """Convert image tensor to PIL Image"""
+     im = img_tensor.data.cpu().numpy()
+     # [-1, 1] --> [0, 255]
+     im = (im + 1.0) * 127.5
+     im = im.astype(np.uint8)
+     im = np.transpose(im, (1, 2, 0))
+     im = Image.fromarray(im)
+     return im
+
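The rescaling in `save_img` maps each value from [-1, 1] to [0, 255] via `(x + 1) * 127.5` and then truncates to an integer. A scalar sketch of that arithmetic follows, assuming the generator output really lies in [-1, 1] as the comment in `save_img` states. `to_uint8` is a hypothetical helper, and its clamp is an addition that `save_img` does not perform.

```python
def to_uint8(value):
    """Map a value from [-1, 1] to an integer pixel in [0, 255]."""
    scaled = (value + 1.0) * 127.5
    # int() truncates toward zero; the clamp guards against inputs outside [-1, 1].
    return max(0, min(255, int(scaled)))


print([to_uint8(v) for v in (-1.0, 0.0, 1.0)])  # [0, 127, 255]
```

Without a clamp, values slightly outside the expected range would not saturate when cast to an 8-bit type, so clipping before the cast is a cheap safeguard.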
+ # Load configuration
+ config = {
+     'z_dim': 100,
+     'cond_dim': 256,
+     'imsize': 256,
+     'nf': 32,
+     'ch_size': 3,
+     'truncation': True,
+     'trunc_rate': 0.88,
+ }
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+
+ # Load vocab and models
+ def load_models():
+     # Load vocabulary
+     with open('data/captions_DAMSM.pickle', 'rb') as f:
+         x = pickle.load(f)
+         wordtoix = x[3]
+         ixtoword = x[2]
+         del x
+
+     # Initialize text encoder
+     text_encoder = RNN_ENCODER(len(wordtoix), nhidden=config['cond_dim'])
+     text_encoder_path = 'data/text_encoder200.pth'
+     state_dict = torch.load(text_encoder_path, map_location='cpu')
+     text_encoder = load_model_weights(text_encoder, state_dict)
+     text_encoder.to(device)
+     for p in text_encoder.parameters():
+         p.requires_grad = False
+     text_encoder.eval()
+
+     # Initialize generator
+     netG = NetG(config['nf'], config['z_dim'], config['cond_dim'], config['imsize'], config['ch_size'])
+     netG_path = 'data/state_epoch_1220.pth'
+     state_dict = torch.load(netG_path, map_location='cpu')
+     netG = load_model_weights(netG, state_dict['model']['netG'])
+     netG.to(device)
+     netG.eval()
+
+     return wordtoix, ixtoword, text_encoder, netG
+
+ wordtoix, ixtoword, text_encoder, netG = load_models()
+
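`load_models` reads the vocabulary by position: index 2 of the unpickled object is `ixtoword` and index 3 is `wordtoix`. The sketch below round-trips a toy payload with the same layout through `pickle`. The contents of the first two slots (caption lists) and all words and indices are assumptions made up for illustration.

```python
import io
import pickle

toy_payload = [
    [[1, 2]],                            # assumed: training captions as index lists
    [[2, 1]],                            # assumed: test captions as index lists
    {0: "<end>", 1: "red", 2: "bird"},   # ixtoword, read as x[2]
    {"<end>": 0, "red": 1, "bird": 2},   # wordtoix, read as x[3]
]

# Serialize to memory instead of data/captions_DAMSM.pickle.
buffer = io.BytesIO()
pickle.dump(toy_payload, buffer)
buffer.seek(0)

x = pickle.load(buffer)
ixtoword, wordtoix = x[2], x[3]
round_trip_ok = all(ixtoword[index] == word for word, index in wordtoix.items())
print(len(wordtoix), round_trip_ok)
```

A consistency check like `round_trip_ok` is a quick way to confirm that the two mappings are inverses before the vocabulary size is handed to the text encoder.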
+ def generate_image(text_input, num_images=1, seed=None):
+     """Generate images from text description"""
+     num_images = int(num_images)  # Slider values may arrive as floats
+     if not text_input.strip():
+         return []  # An empty gallery for an empty description
+
+     cap_array, cap_len = tokenize_and_build_captions(text_input, wordtoix)
+
+     if cap_len == 0:
+         # No word of the description is in the vocabulary: return red placeholders
+         return [Image.new('RGB', (256, 256), color='red')] * num_images
+
+     sent_emb = encode_caption(cap_array, cap_len, text_encoder, device)
+
+     # Set random seed if provided
+     if seed is not None:
+         random.seed(seed)
+         np.random.seed(seed)
+         torch.manual_seed(seed)
+         if torch.cuda.is_available():
+             torch.cuda.manual_seed_all(seed)
+
+     # Generate multiple images if requested
+     result_images = []
+     with torch.no_grad():
+         for _ in range(num_images):
+             # Generate noise
+             if config['truncation']:
+                 noise = truncated_noise(1, config['z_dim'], config['trunc_rate'])
+                 noise = torch.tensor(noise, dtype=torch.float).to(device)
+             else:
+                 noise = torch.randn(1, config['z_dim']).to(device)
+
+             # Generate image
+             fake_img = netG(noise, sent_emb)
+             img = save_img(fake_img[0])
+             result_images.append(img)
+
+     return result_images
+
+ # Create Gradio interface
+ def generate_images_interface(text, num_images, random_seed):
+     seed = int(random_seed) if random_seed else None
+     return generate_image(text, num_images, seed)
+
+ with gr.Blocks(title="Bird Image Generator") as demo:
+     gr.Markdown("# Bird Image Generator using DF-GAN")
+     gr.Markdown("Enter a description of a bird and the model will generate corresponding images.")
+
+     with gr.Row():
+         with gr.Column():
+             text_input = gr.Textbox(
+                 label="Bird Description",
+                 placeholder="Enter a description of a bird (e.g., 'a small bird with a red head and black wings')",
+                 lines=3
+             )
+             num_images = gr.Slider(minimum=1, maximum=4, value=1, step=1, label="Number of Images")
+             seed = gr.Textbox(label="Random Seed (optional)", placeholder="Leave empty for random results")
+             submit_btn = gr.Button("Generate Image")
+
+         with gr.Column():
+             image_output = gr.Gallery(label="Generated Images", columns=2, height="auto")
+
+     submit_btn.click(
+         fn=generate_images_interface,
+         inputs=[text_input, num_images, seed],
+         outputs=image_output
+     )
+
+     gr.Markdown("## Example Descriptions")
+     example_descriptions = [
+         "this bird has an orange bill, a white belly and white eyebrows",
+         "a small bird with a red head, breast, and belly and black wings",
+         "this bird is yellow with black and has a long, pointy beak",
+         "this bird is white in color, and has an orange beak"
+     ]
+
+     gr.Examples(
+         examples=[[desc, 1, ""] for desc in example_descriptions],
+         inputs=[text_input, num_images, seed],
+         outputs=image_output,
+         fn=generate_images_interface
+     )
+
+ # Launch the app with appropriate configurations for Hugging Face Spaces
+ if __name__ == "__main__":
+     demo.launch(
+         server_name="0.0.0.0",  # Bind to all network interfaces
+         share=False,  # Don't use share links
+         favicon_path="https://raw.githubusercontent.com/tobran/DF-GAN/main/framework.png"
+     )
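In `generate_images_interface`, the seed textbox is converted with `int(random_seed)`, which raises `ValueError` for input such as `abc` or `1.5`. A more forgiving conversion could look like the sketch below. `parse_seed` is a hypothetical helper, and treating unparsable input as "no seed" is one possible policy, not the behaviour of the committed code.

```python
def parse_seed(raw):
    """Return an int seed, or None when the textbox is empty or not an integer."""
    text = (raw or "").strip()
    if not text:
        return None
    try:
        return int(text)
    except ValueError:
        return None  # fall back to unseeded generation


print(parse_seed(" 42 "), parse_seed(""), parse_seed("abc"))  # 42 None None
```

An alternative policy would be to surface an error message in the interface instead of silently ignoring a malformed seed.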
demo.ipynb ADDED
@@ -0,0 +1 @@
+
download_models.py ADDED
@@ -0,0 +1,56 @@
+ import os
+ import sys
+ import subprocess
+ import gdown
+ import shutil
+ import nltk
+ from pathlib import Path
+
+ # Install NLTK data
+ nltk.download('punkt')
+
+ # Create directories
+ os.makedirs('DF-GAN/code/models', exist_ok=True)
+ os.makedirs('data', exist_ok=True)
+
+ # Clone the DF-GAN repository unless its model code has already been copied
+ if not os.path.exists('DF-GAN/code/models/GAN.py'):
+     print("Cloning DF-GAN repository...")
+     shutil.rmtree('DF-GAN_temp', ignore_errors=True)  # Clear leftovers from an interrupted run
+     subprocess.run(["git", "clone", "https://github.com/tobran/DF-GAN.git", "DF-GAN_temp"], check=True)
+
+     # Copy only the necessary files to avoid duplicates
+     shutil.copytree('DF-GAN_temp/code/models', 'DF-GAN/code/models', dirs_exist_ok=True)
+     shutil.copytree('DF-GAN_temp/code/lib', 'DF-GAN/code/lib', dirs_exist_ok=True)
+
+     # Clean up
+     shutil.rmtree('DF-GAN_temp')
+
+     print("Repository cloned and organized.")
+
+ # Download model files
+ # DF-GAN pretrained bird model
+ bird_model_url = 'https://drive.google.com/uc?id=1rzfcCvGwU8vLCrn5reWxmrAMms6WQGA6'
+ bird_model_path = 'data/state_epoch_1220.pth'
+
+ # Text encoder for birds
+ text_encoder_url = 'https://drive.google.com/uc?id=1xwIyLPYtYn9YGPIcRuWXxaxcw_oPGQK4'
+ text_encoder_path = 'data/text_encoder200.pth'
+
+ # Captions DAMSM pickle file
+ captions_pickle_url = 'https://drive.google.com/uc?id=1FfNMRpOZGaO3mKYyj2VDVEW1ChZ12lJp'
+ captions_pickle_path = 'data/captions_DAMSM.pickle'
+
+ # Download if files don't exist
+ if not os.path.exists(bird_model_path):
+     print(f"Downloading bird model to {bird_model_path}...")
+     gdown.download(bird_model_url, bird_model_path, quiet=False)
+
+ if not os.path.exists(text_encoder_path):
+     print(f"Downloading text encoder to {text_encoder_path}...")
+     gdown.download(text_encoder_url, text_encoder_path, quiet=False)
+
+ if not os.path.exists(captions_pickle_path):
+     print(f"Downloading captions pickle to {captions_pickle_path}...")
+     gdown.download(captions_pickle_url, captions_pickle_path, quiet=False)
+
+ print("All model files downloaded and prepared successfully!")
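The three "download if missing" blocks in `download_models.py` follow one pattern: skip when the file exists, otherwise fetch it. A sketch of that pattern as a single helper that also verifies the file appeared afterwards is shown below. `ensure_file` and `fake_fetch` are hypothetical names, and the fetch callable stands in for a real downloader so the example runs offline.

```python
import os
import tempfile


def ensure_file(path, fetch):
    """Call fetch(path) only if path is missing; return True when a download ran."""
    if os.path.exists(path):
        return False
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    fetch(path)
    if not os.path.exists(path):
        raise RuntimeError(f"download did not produce {path}")
    return True


def fake_fetch(path):
    # Stand-in for a real downloader: writes a small placeholder file.
    with open(path, "wb") as handle:
        handle.write(b"weights")


with tempfile.TemporaryDirectory() as workdir:
    target = os.path.join(workdir, "data", "model.pth")
    first = ensure_file(target, fake_fetch)   # file missing, fetch runs
    second = ensure_file(target, fake_fetch)  # file present, fetch skipped
    print(first, second)  # True False
```

Raising when the file is still missing turns a silent download failure into an immediate error, rather than a later crash when the app tries to load the weights.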
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ flask==2.0.1
+ torch>=1.9.0
+ torchvision>=0.10.0
+ Pillow>=9.0.0
+ nltk>=3.6.0
+ gunicorn==20.1.0
+ python-dotenv==0.19.0
+ requests==2.26.0
+ matplotlib==3.5.1
+ tqdm>=4.62.0
+ numpy>=1.20.0
+ scipy>=1.7.0
+ omegaconf>=2.1.0
+ gradio>=3.50.0
+ easydict>=1.9
+ gdown>=4.6.0
startup.sh ADDED
@@ -0,0 +1,10 @@
+ #!/bin/bash
+
+ # Install NLTK data
+ python -c "import nltk; nltk.download('punkt')"
+
+ # Run the download_models.py script to get the models
+ python download_models.py
+
+ # Start the Gradio app
+ python app.py