Upload 9 files

- .gitignore +36 -0
- .huggingface-space +9 -0
- README-HF.md +31 -0
- README.md +34 -13
- app.py +239 -0
- demo.ipynb +1 -0
- download_models.py +56 -0
- requirements.txt +16 -0
- startup.sh +10 -0
.gitignore
ADDED
@@ -0,0 +1,36 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib64/
+parts/
+sdist/
+var/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# Data directories
+data/
+DF-GAN/
+
+# Model files
+*.pth
+*.pickle
+*.npz
+
+# Generated images
+samples/
+.DS_Store
.huggingface-space
ADDED
@@ -0,0 +1,9 @@
+title: DF-GAN Bird Image Generator
+emoji: 🐦
+colorFrom: blue
+colorTo: purple
+sdk: gradio
+sdk_version: 3.50.0
+app_file: app.py
+pinned: false
+license: cc-by-nc-sa-4.0
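Note: Hugging Face Spaces conventionally reads this metadata from a YAML front-matter block at the top of README.md rather than from a separate file, so if the Space settings are not picked up, the same keys can be duplicated there. A minimal sketch:

---
title: DF-GAN Bird Image Generator
emoji: 🐦
colorFrom: blue
colorTo: purple
sdk: gradio
sdk_version: 3.50.0
app_file: app.py
pinned: false
license: cc-by-nc-sa-4.0
---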
README-HF.md
ADDED
@@ -0,0 +1,31 @@
+# DF-GAN Bird Image Generator 🐦
+
+This Hugging Face Space demonstrates the [DF-GAN model](https://arxiv.org/abs/2008.05865) for generating bird images from text descriptions.
+
+## How It Works
+
+1. Enter a text description of a bird
+2. Select how many images you want to generate (1-4)
+3. Optionally add a random seed for reproducible results
+4. Click "Generate Image"
+5. The model will generate realistic bird images based on your description
+
+## Example Descriptions
+
+Try these example descriptions:
+- "this bird has an orange bill, a white belly and white eyebrows"
+- "a small bird with a red head, breast, and belly and black wings"
+- "this bird is yellow with black and has a long, pointy beak"
+- "this is a grey bodied bird with light grey wings and a white breast"
+
+## About the Model
+
+DF-GAN (Deep Fusion GAN) is a text-to-image synthesis model introduced in the paper "DF-GAN: A Simple and Effective Baseline for Text-to-Image Synthesis" (CVPR 2022). This demo uses the pre-trained bird model, which was trained on the CUB-200-2011 dataset.
+
+This demo runs on CPU, so image generation may take a few seconds.
+
+## Credits
+
+This Space uses the official implementation of DF-GAN from [tobran/DF-GAN](https://github.com/tobran/DF-GAN).
+
+Made with ❤️ by [Your Name]
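Beyond the web UI, a running Space can also be queried programmatically. A minimal sketch using the gradio_client package, assuming a hypothetical Space ID your-username/df-gan-birds and the default /predict endpoint (both would need to be adjusted for the real deployment):

from gradio_client import Client

# Hypothetical Space ID; replace with the real one once the Space is published
client = Client("your-username/df-gan-birds")
result = client.predict(
    "a small bird with a red head and black wings",  # bird description
    1,     # number of images
    "42",  # random seed, passed as text
    api_name="/predict",  # assumed endpoint name; check the Space's API page
)
print(result)  # local path(s) to the downloaded image(s)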
README.md
CHANGED
@@ -1,13 +1,34 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
+# DF-GAN Bird Image Generator
+
+This application uses the DF-GAN (Deep Fusion GAN) model to generate bird images based on text descriptions. Just enter a description of a bird, and the model will generate a realistic image that matches your description.
+
+## About the Model
+
+This application uses the pre-trained bird model from the [DF-GAN: A Simple and Effective Baseline for Text-to-Image Synthesis](https://arxiv.org/abs/2008.05865) paper (CVPR 2022). DF-GAN is a text-to-image synthesis model that can generate high-quality images from textual descriptions.
+
+## How to Use
+
+1. Enter a description of a bird in the text box (e.g., "a yellow bird with a black head")
+2. Choose how many images you want to generate (1-4)
+3. Optionally, set a random seed for reproducible results
+4. Click the "Generate Image" button
+5. View the generated bird images that match your description
+
+## Examples
+
+Try these example descriptions:
+- "this bird has an orange bill, a white belly and white eyebrows"
+- "a small bird with a red head, breast, and belly and black wings"
+- "this bird is yellow with black and has a long, pointy beak"
+- "this bird is white in color, and has an orange beak"
+
+## Implementation Details
+
+This application uses the following components:
+- DF-GAN architecture for text-to-image synthesis
+- DAMSM text encoder for embedding text descriptions
+- Gradio for the web interface
+
+## Credits
+
+This implementation is based on the official DF-GAN repository: [tobran/DF-GAN](https://github.com/tobran/DF-GAN)
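As a reading aid for app.py below, the components listed under "Implementation Details" combine into this condensed flow (a sketch using the names defined in app.py, not standalone code):

# Condensed view of the generation pipeline implemented in app.py
cap_array, cap_len = tokenize_and_build_captions(text, wordtoix)     # words -> vocabulary indices
sent_emb = encode_caption(cap_array, cap_len, text_encoder, device)  # DAMSM sentence embedding
noise = torch.tensor(truncated_noise(1, 100, 0.88)).to(device)       # truncated noise vector z
image = save_img(netG(noise, sent_emb)[0])                           # generator output -> PIL image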
app.py
ADDED
@@ -0,0 +1,239 @@
+import os
+import sys
+import random
+import torch
+import pickle
+import numpy as np
+from PIL import Image
+import torch.nn.functional as F
+import gradio as gr
+from omegaconf import OmegaConf
+from scipy.stats import truncnorm
+import subprocess
+
+# First run the download_models.py script if the models haven't been downloaded yet
+if not os.path.exists('data/state_epoch_1220.pth') or not os.path.exists('data/text_encoder200.pth'):
+    print("Downloading necessary model files...")
+    try:
+        subprocess.check_call([sys.executable, "download_models.py"])
+    except subprocess.CalledProcessError as e:
+        print(f"Error downloading models: {e}")
+        print("Please run download_models.py manually before starting the app.")
+
+# Add the DF-GAN code directory to the Python path
+sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "DF-GAN/code"))
+
+# Import the necessary modules from the DF-GAN code
+from models.DAMSM import RNN_ENCODER
+from models.GAN import NetG
+
+# Utility functions
+def load_model_weights(model, weights, multi_gpus=False, train=False):
+    """Load model weights, handling the 'module.' prefix left by DataParallel."""
+    pretrained_with_multi_gpu = list(weights.keys())[0].find('module') != -1
+
+    if (multi_gpus == False) or (train == False):
+        if pretrained_with_multi_gpu:
+            # Strip the leading 'module.' (7 characters) from every key
+            state_dict = {key[7:]: value for key, value in weights.items()}
+        else:
+            state_dict = weights
+    else:
+        state_dict = weights
+
+    model.load_state_dict(state_dict)
+    return model
+
+def get_tokenizer():
+    """Get an NLTK word tokenizer."""
+    from nltk.tokenize import RegexpTokenizer
+    tokenizer = RegexpTokenizer(r'\w+')
+    return tokenizer
+
+def truncated_noise(batch_size=1, dim_z=100, truncation=1.0, seed=None):
+    """Sample noise from a truncated normal distribution."""
+    state = None if seed is None else np.random.RandomState(seed)
+    values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state).astype(np.float32)
+    return truncation * values
+
+def tokenize_and_build_captions(input_text, wordtoix):
+    """Tokenize text and convert tokens to indices using the wordtoix mapping."""
+    tokenizer = get_tokenizer()
+    tokens = tokenizer.tokenize(input_text.lower())
+    cap = []
+    for t in tokens:
+        t = t.encode('ascii', 'ignore').decode('ascii')
+        if len(t) > 0 and t in wordtoix:
+            cap.append(wordtoix[t])
+
+    # Create a zero-padded array for the caption
+    max_len = 18  # As defined in bird.yml
+    cap_array = np.zeros(max_len, dtype='int64')
+    cap_len = len(cap)
+    if cap_len <= max_len:
+        cap_array[:cap_len] = cap
+    else:
+        # Truncate if too long
+        cap_array = np.asarray(cap[:max_len], dtype='int64')
+        cap_len = max_len
+
+    return cap_array, cap_len
+
+def encode_caption(caption, caption_len, text_encoder, device):
+    """Encode a caption into a sentence embedding with the DAMSM text encoder."""
+    with torch.no_grad():
+        caption = torch.tensor([caption]).to(device)
+        caption_len = torch.tensor([caption_len]).to(device)
+        hidden = text_encoder.init_hidden(1)
+        _, sent_emb = text_encoder(caption, caption_len, hidden)
+    return sent_emb
+
+def save_img(img_tensor):
+    """Convert an image tensor in [-1, 1] to a PIL Image in [0, 255]."""
+    im = img_tensor.data.cpu().numpy()
+    im = (im + 1.0) * 127.5
+    im = im.astype(np.uint8)
+    im = np.transpose(im, (1, 2, 0))
+    im = Image.fromarray(im)
+    return im
+
+# Model configuration (matches bird.yml in the DF-GAN repo)
+config = {
+    'z_dim': 100,
+    'cond_dim': 256,
+    'imsize': 256,
+    'nf': 32,
+    'ch_size': 3,
+    'truncation': True,
+    'trunc_rate': 0.88,
+}
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+print(f"Using device: {device}")
+
+# Load vocabulary and models
+def load_models():
+    # Load vocabulary
+    with open('data/captions_DAMSM.pickle', 'rb') as f:
+        x = pickle.load(f)
+        wordtoix = x[3]
+        ixtoword = x[2]
+        del x
+
+    # Initialize the text encoder
+    text_encoder = RNN_ENCODER(len(wordtoix), nhidden=config['cond_dim'])
+    text_encoder_path = 'data/text_encoder200.pth'
+    state_dict = torch.load(text_encoder_path, map_location='cpu')
+    text_encoder = load_model_weights(text_encoder, state_dict)
+    text_encoder.to(device)
+    for p in text_encoder.parameters():
+        p.requires_grad = False
+    text_encoder.eval()
+
+    # Initialize the generator
+    netG = NetG(config['nf'], config['z_dim'], config['cond_dim'], config['imsize'], config['ch_size'])
+    netG_path = 'data/state_epoch_1220.pth'
+    state_dict = torch.load(netG_path, map_location='cpu')
+    netG = load_model_weights(netG, state_dict['model']['netG'])
+    netG.to(device)
+    netG.eval()
+
+    return wordtoix, ixtoword, text_encoder, netG
+
+wordtoix, ixtoword, text_encoder, netG = load_models()
+
+def generate_image(text_input, num_images=1, seed=None):
+    """Generate images from a text description."""
+    if not text_input.strip():
+        return []
+
+    cap_array, cap_len = tokenize_and_build_captions(text_input, wordtoix)
+
+    # If no token is in the vocabulary, return red placeholder images
+    if cap_len == 0:
+        return [Image.new('RGB', (256, 256), color='red')] * num_images
+
+    sent_emb = encode_caption(cap_array, cap_len, text_encoder, device)
+
+    # Set the random seed if provided
+    if seed is not None:
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(seed)
+
+    # Generate the requested number of images
+    result_images = []
+    with torch.no_grad():
+        for _ in range(num_images):
+            # Sample noise
+            if config['truncation']:
+                noise = truncated_noise(1, config['z_dim'], config['trunc_rate'])
+                noise = torch.tensor(noise, dtype=torch.float).to(device)
+            else:
+                noise = torch.randn(1, config['z_dim']).to(device)
+
+            # Generate an image conditioned on the sentence embedding
+            fake_img = netG(noise, sent_emb)
+            img = save_img(fake_img[0])
+            result_images.append(img)
+
+    return result_images
+
+# Create the Gradio interface
+def generate_images_interface(text, num_images, random_seed):
+    num_images = int(num_images)  # Gradio sliders may pass floats
+    seed = int(random_seed) if random_seed else None
+    return generate_image(text, num_images, seed)
+
+with gr.Blocks(title="Bird Image Generator") as demo:
+    gr.Markdown("# Bird Image Generator using DF-GAN")
+    gr.Markdown("Enter a description of a bird and the model will generate corresponding images.")
+
+    with gr.Row():
+        with gr.Column():
+            text_input = gr.Textbox(
+                label="Bird Description",
+                placeholder="Enter a description of a bird (e.g., 'a small bird with a red head and black wings')",
+                lines=3
+            )
+            num_images = gr.Slider(minimum=1, maximum=4, value=1, step=1, label="Number of Images")
+            seed = gr.Textbox(label="Random Seed (optional)", placeholder="Leave empty for random results")
+            submit_btn = gr.Button("Generate Image")
+
+        with gr.Column():
+            image_output = gr.Gallery(label="Generated Images", columns=2, height="auto")
+
+    submit_btn.click(
+        fn=generate_images_interface,
+        inputs=[text_input, num_images, seed],
+        outputs=image_output
+    )
+
+    gr.Markdown("## Example Descriptions")
+    example_descriptions = [
+        "this bird has an orange bill, a white belly and white eyebrows",
+        "a small bird with a red head, breast, and belly and black wings",
+        "this bird is yellow with black and has a long, pointy beak",
+        "this bird is white in color, and has an orange beak"
+    ]
+
+    gr.Examples(
+        examples=[[desc, 1, ""] for desc in example_descriptions],
+        inputs=[text_input, num_images, seed],
+        outputs=image_output,
+        fn=generate_images_interface
+    )
+
+# Launch the app with settings appropriate for Hugging Face Spaces
+if __name__ == "__main__":
+    demo.launch(
+        server_name="0.0.0.0",  # Bind to all network interfaces
+        share=False  # Don't use share links
+    )
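For local smoke-testing outside the Gradio UI, generate_image can be called directly once the model files are in data/; a minimal sketch (importing app builds the interface and loads the models, but the __main__ guard keeps the server from launching):

# Run from the repo root; download_models.py must have fetched the weights first
from app import generate_image

images = generate_image("a small yellow bird with black wings", num_images=2, seed=42)
for i, img in enumerate(images):
    img.save(f"sample_{i}.png")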
demo.ipynb
ADDED
@@ -0,0 +1 @@
+
download_models.py
ADDED
@@ -0,0 +1,56 @@
+import os
+import sys
+import subprocess
+import gdown
+import shutil
+import nltk
+from pathlib import Path
+
+# Install NLTK data
+nltk.download('punkt')
+
+# Create directories
+os.makedirs('DF-GAN/code/models', exist_ok=True)
+os.makedirs('data', exist_ok=True)
+
+# Clone the DF-GAN repository if its code isn't already in place
+# (checking for DF-GAN/.git would always fail, since only the code/ subdirectories are kept)
+if not os.path.exists('DF-GAN/code/models/GAN.py'):
+    print("Cloning DF-GAN repository...")
+    subprocess.run(["git", "clone", "https://github.com/tobran/DF-GAN.git", "DF-GAN_temp"], check=True)
+
+    # Move only the necessary files to avoid duplicates
+    shutil.copytree('DF-GAN_temp/code/models', 'DF-GAN/code/models', dirs_exist_ok=True)
+    shutil.copytree('DF-GAN_temp/code/lib', 'DF-GAN/code/lib', dirs_exist_ok=True)
+
+    # Clean up the temporary clone
+    shutil.rmtree('DF-GAN_temp')
+
+    print("Repository cloned and organized.")
+
+# Download model files
+# DF-GAN pretrained bird model
+bird_model_url = 'https://drive.google.com/uc?id=1rzfcCvGwU8vLCrn5reWxmrAMms6WQGA6'
+bird_model_path = 'data/state_epoch_1220.pth'
+
+# Text encoder for birds
+text_encoder_url = 'https://drive.google.com/uc?id=1xwIyLPYtYn9YGPIcRuWXxaxcw_oPGQK4'
+text_encoder_path = 'data/text_encoder200.pth'
+
+# Captions DAMSM pickle file (vocabulary)
+captions_pickle_url = 'https://drive.google.com/uc?id=1FfNMRpOZGaO3mKYyj2VDVEW1ChZ12lJp'
+captions_pickle_path = 'data/captions_DAMSM.pickle'
+
+# Download each file only if it doesn't already exist
+if not os.path.exists(bird_model_path):
+    print(f"Downloading bird model to {bird_model_path}...")
+    gdown.download(bird_model_url, bird_model_path, quiet=False)
+
+if not os.path.exists(text_encoder_path):
+    print(f"Downloading text encoder to {text_encoder_path}...")
+    gdown.download(text_encoder_url, text_encoder_path, quiet=False)
+
+if not os.path.exists(captions_pickle_path):
+    print(f"Downloading captions pickle to {captions_pickle_path}...")
+    gdown.download(captions_pickle_url, captions_pickle_path, quiet=False)
+
+print("All model files downloaded and prepared successfully!")
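A quick sanity check that the downloaded files are loadable before starting the app; a sketch, assuming the paths used above:

import pickle
import torch

# The DAMSM pickle bundles caption data plus the two vocabulary mappings
# (app.py reads ixtoword from x[2] and wordtoix from x[3])
with open('data/captions_DAMSM.pickle', 'rb') as f:
    x = pickle.load(f)
print(f"Vocabulary size: {len(x[3])}")

# The generator checkpoint stores its weights under ['model']['netG']
ckpt = torch.load('data/state_epoch_1220.pth', map_location='cpu')
print(list(ckpt['model'].keys()))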
requirements.txt
ADDED
@@ -0,0 +1,16 @@
+flask==2.0.1
+torch>=1.9.0
+torchvision>=0.10.0
+Pillow>=9.0.0
+nltk>=3.6.0
+gunicorn==20.1.0
+python-dotenv==0.19.0
+requests==2.26.0
+matplotlib==3.5.1
+tqdm>=4.62.0
+numpy>=1.20.0
+scipy>=1.7.0
+omegaconf>=2.1.0
+gradio>=3.50.0
+easydict>=1.9
+gdown>=4.6.0
startup.sh
ADDED
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+# Install NLTK data
+python -c "import nltk; nltk.download('punkt')"
+
+# Run the download_models.py script to get the models
+python download_models.py
+
+# Start the Gradio app
+python app.py