# df-gan-text-to-image / download_models.py
# Uploaded by sohanAI — commit 78cabf4 ("Upload 7 files", verified)
import os
import sys
import subprocess
import shutil
import nltk
from pathlib import Path
import urllib.request
import zipfile
import torch
import time
# Install NLTK data. Since NLTK 3.8.2 the Punkt sentence tokenizer (used by
# word_tokenize / sent_tokenize) is loaded from the "punkt_tab" resource rather
# than the pickled "punkt" one, so fetch both to work with old and new NLTK.
for _nltk_resource in ('punkt', 'punkt_tab'):
    nltk.download(_nltk_resource)
# Create directories
os.makedirs('DF-GAN/code/models', exist_ok=True)
os.makedirs('data', exist_ok=True)
# Clone the DF-GAN repository into a temporary directory and copy only the
# code/models and code/lib packages into DF-GAN/.
# DF-GAN/.git is never created by this step (only two sub-directories are
# copied), so code/lib is also checked as the "already fetched" marker.
# code/models cannot serve as the marker because it may exist as an empty
# directory created up front.
_DFGAN_TEMP_DIR = 'DF-GAN_temp'
_already_fetched = os.path.exists('DF-GAN/.git') or os.path.isdir('DF-GAN/code/lib')
if not _already_fetched:
    print("Cloning DF-GAN repository...")
    # A temp dir left behind by an interrupted run would make `git clone` fail.
    if os.path.exists(_DFGAN_TEMP_DIR):
        shutil.rmtree(_DFGAN_TEMP_DIR)
    try:
        # Shallow clone: only the current tree is copied, history is unused.
        # check=True reports a failed clone here instead of as a confusing
        # "file not found" error from copytree.
        subprocess.run(
            ["git", "clone", "--depth", "1",
             "https://github.com/tobran/DF-GAN.git", _DFGAN_TEMP_DIR],
            check=True,
        )
        # Move only necessary files to avoid duplicates
        shutil.copytree(f'{_DFGAN_TEMP_DIR}/code/models', 'DF-GAN/code/models', dirs_exist_ok=True)
        shutil.copytree(f'{_DFGAN_TEMP_DIR}/code/lib', 'DF-GAN/code/lib', dirs_exist_ok=True)
    finally:
        # Clean up the temporary clone even if cloning or copying failed.
        shutil.rmtree(_DFGAN_TEMP_DIR, ignore_errors=True)
    print("Repository cloned and organized.")
# Function to download files with retries
def download_file(url, dest_path, max_retries=3, retry_delay=2):
    """Download ``url`` to ``dest_path``, retrying on failure.

    Data is written to ``<dest_path>.part`` and moved onto ``dest_path`` only
    after the transfer completes. A failed or interrupted download therefore
    never leaves a truncated file at ``dest_path``. That matters because callers
    use the mere existence of ``dest_path`` to decide whether to download again.

    Args:
        url: Source URL (any scheme ``urllib.request`` can open).
        dest_path: Final location of the downloaded file.
        max_retries: Maximum number of attempts; 0 means "do not try".
        retry_delay: Seconds to wait between attempts (not after the last one).

    Returns:
        True if the file was downloaded, False if every attempt failed.
    """
    tmp_path = f"{dest_path}.part"
    for attempt in range(1, max_retries + 1):
        try:
            print(f"Downloading from {url} to {dest_path} (attempt {attempt})")
            urllib.request.urlretrieve(url, tmp_path)
            # Atomic on the same filesystem: dest_path is either absent or complete.
            os.replace(tmp_path, dest_path)
        except Exception as e:  # best effort: log, clean up, retry
            print(f"Download attempt {attempt} failed: {e}")
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # nothing (or nothing removable) was written
            if attempt < max_retries:
                time.sleep(retry_delay)  # Wait before retrying
        else:
            print(f"Successfully downloaded {dest_path}")
            return True
    return False
# Direct-download ("resolve/main") URLs of the pretrained assets, all served
# from the same Hugging Face Space.
_ASSET_BASE_URL = "https://huggingface.co/spaces/sayakpaul/df-gan-bird/resolve/main"
BIRD_MODEL_URL = f"{_ASSET_BASE_URL}/state_epoch_1220.pth"
TEXT_ENCODER_URL = f"{_ASSET_BASE_URL}/text_encoder200.pth"
CAPTIONS_URL = f"{_ASSET_BASE_URL}/captions_DAMSM.pickle"
# Local destinations; every asset lives under the data/ directory.
_DATA_DIR = 'data'
bird_model_path = f'{_DATA_DIR}/state_epoch_1220.pth'
text_encoder_path = f'{_DATA_DIR}/text_encoder200.pth'
captions_pickle_path = f'{_DATA_DIR}/captions_DAMSM.pickle'
# Make sure every model asset exists locally: download it, or, if the download
# fails, write a placeholder so the app can still start for testing.
# Placeholders are tracked so the final report does not claim success for them.


def _write_dummy_bird_model(path):
    """Write a stand-in generator checkpoint with dummy netG/netD/netC entries."""
    print("Creating a dummy model for testing purposes...")
    dummy_state = {
        'model': {
            'netG': {'dummy': torch.zeros(1)},
            'netD': {'dummy': torch.zeros(1)},
            'netC': {'dummy': torch.zeros(1)},
        }
    }
    torch.save(dummy_state, path)
    print("Dummy model created as fallback")


def _write_dummy_text_encoder(path):
    """Write a stand-in text-encoder state dict with a single dummy tensor."""
    print("Creating a dummy text encoder for testing purposes...")
    dummy_encoder = {'dummy': torch.zeros(1)}
    torch.save(dummy_encoder, path)
    print("Dummy text encoder created as fallback")


def _write_placeholder_captions(path):
    """Pickle a tiny vocabulary as ``[None, None, ixtoword, wordtoix]``."""
    import pickle
    print("Creating a placeholder captions file...")
    wordtoix = {"the": 1, "bird": 2, "is": 3, "a": 4, "with": 5, "and": 6, "red": 7, "black": 8, "yellow": 9}
    ixtoword = {v: k for k, v in wordtoix.items()}
    test_data = [None, None, ixtoword, wordtoix]
    with open(path, 'wb') as f:
        pickle.dump(test_data, f)
    print("Placeholder captions file created as fallback")


def _ensure_asset(path, url, label, write_placeholder):
    """Ensure ``path`` exists by downloading ``url`` or writing a placeholder.

    Args:
        path: Local destination; if it already exists nothing is done.
        url: Download source passed to ``download_file``.
        label: Human-readable asset name used in log messages.
        write_placeholder: Callable taking ``path`` that writes a fallback file.

    Returns:
        True if a placeholder was written, False if the real file is present.
    """
    if os.path.exists(path):
        return False
    print(f"Downloading {label} to {path}...")
    if download_file(url, path):
        return False
    print(f"Failed to download {label} after multiple attempts")
    # A failed download can leave a truncated file behind. It must not be
    # mistaken for a usable asset, so discard it and write the placeholder.
    if os.path.exists(path):
        os.remove(path)
    write_placeholder(path)
    return True


# (local path, source URL, log label, placeholder writer), in download order.
_ASSETS = (
    (bird_model_path, BIRD_MODEL_URL, "bird model", _write_dummy_bird_model),
    (text_encoder_path, TEXT_ENCODER_URL, "text encoder", _write_dummy_text_encoder),
    (captions_pickle_path, CAPTIONS_URL, "captions pickle", _write_placeholder_captions),
)

placeholder_files = []
for _path, _url, _label, _writer in _ASSETS:
    if _ensure_asset(_path, _url, _label, _writer):
        placeholder_files.append(_path)

# Verify downloads
missing_files = [entry[0] for entry in _ASSETS if not os.path.exists(entry[0])]
all_files_exist = not missing_files
if not all_files_exist:
    print(f"Warning: The following files could not be downloaded: {', '.join(missing_files)}")
    print("The application may not function correctly.")
elif placeholder_files:
    print(f"Warning: Placeholder files were created because the downloads failed: {', '.join(placeholder_files)}")
    print("The application may not function correctly.")
else:
    print("All model files downloaded and prepared successfully!")