import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import SiglipVisionModel, AutoTokenizer, AutoImageProcessor, AutoModel
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
from tqdm import tqdm
import os
import numpy as np
from PIL import Image
import argparse


def siglip_loss(image_embeddings, text_embeddings, temperature=0.07):
    """Contrastive loss over all image/text pairs, treating the diagonal as positives."""
    # Normalize both sets of embeddings
    image_embeddings = F.normalize(image_embeddings, dim=-1)
    text_embeddings = F.normalize(text_embeddings, dim=-1)
    # Compute pairwise similarities
    logits = image_embeddings @ text_embeddings.T  # [batch_size, batch_size]
    logits = logits / temperature
    # Ground truth: 1.0 for matching pairs (diagonal), 0.0 for all others
    batch_size = logits.size(0)
    targets = torch.eye(batch_size, device=logits.device)
    # Binary cross-entropy with logits over every image/text pair
    loss = F.binary_cross_entropy_with_logits(logits, targets)
    return loss
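

# The simplified loss above applies plain BCE with a fixed temperature. For
# reference, the original SigLIP formulation uses a learnable temperature and
# bias and sums the per-pair sigmoid cross-entropy. The module below is a
# hedged sketch of that formulation; it is defined for illustration only and
# is not used by the training loop in this file.
class SigmoidPairwiseLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # Learnable log-temperature and bias, initialised as suggested in the SigLIP paper
        self.t_prime = nn.Parameter(torch.log(torch.tensor(10.0)))
        self.b = nn.Parameter(torch.tensor(-10.0))

    def forward(self, image_embeddings, text_embeddings):
        image_embeddings = F.normalize(image_embeddings, dim=-1)
        text_embeddings = F.normalize(text_embeddings, dim=-1)
        logits = image_embeddings @ text_embeddings.T * self.t_prime.exp() + self.b
        n = logits.size(0)
        # +1 on the diagonal (matching pairs), -1 everywhere else
        labels = 2 * torch.eye(n, device=logits.device) - 1
        # -log sigmoid(label * logit), summed over all pairs, averaged per image
        return -F.logsigmoid(labels * logits).sum() / n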


class LinearProjection(nn.Module):
    """Single linear layer mapping SigLIP image features to the text-embedding dimension."""
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)


def get_text_embedding(text, tokenizer, device, max_length=128):
    """Tokenize `text` into a fixed-length id sequence and return it as a float tensor."""
    # Fall back to a placeholder so empty strings still produce a valid sequence
    if not text or len(text.strip()) == 0:
        text = "This is a placeholder description."
    # Pad/truncate to a fixed length so every text maps to a tensor of the same size
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding='max_length',
        truncation=True,
        max_length=max_length
    )
    # Move inputs to the device and cast to float for the loss calculation
    inputs = {k: v.to(device).float() for k, v in inputs.items()}
    # The raw token ids serve as the "text embedding"
    return inputs['input_ids']
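

# The helper above uses raw token ids (cast to float) as the text representation,
# so the projection target is a max_length-dimensional id vector rather than a
# dense embedding. A hedged alternative sketch, if dense vectors were wanted
# instead: mean-pool the frozen input-embedding table of the same language model.
# `get_pooled_text_embedding` and `embed_layer` are hypothetical names, the Phi-3
# weights would have to be downloaded, and nothing below calls this function.
def get_pooled_text_embedding(text, tokenizer, embed_layer, device, max_length=128):
    inputs = tokenizer(text, return_tensors="pt", padding='max_length',
                       truncation=True, max_length=max_length)
    input_ids = inputs['input_ids'].to(device)
    mask = inputs['attention_mask'].to(device).unsqueeze(-1).float()
    with torch.no_grad():
        token_embeds = embed_layer(input_ids)  # [1, max_length, hidden_dim]
    # Masked mean over token positions -> one vector per text
    return (token_embeds * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
# Obtaining the frozen embedding layer (assumption: AutoModel resolves the Phi-3
# backbone and exposes get_input_embeddings(), as transformers models generally do):
# embed_layer = AutoModel.from_pretrained("microsoft/Phi-3-mini-4k-instruct").get_input_embeddings().to(device)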


def main(num_images=100, batch_size=32, num_epochs=50, learning_rate=1e-4, load_checkpoint=True, checkpoint_path='linear_projection.pth'):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # Load models and processors
    siglip_model = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384")
    siglip_processor = AutoImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
    # Set padding token if not set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Freeze SigLIP model
    for param in siglip_model.parameters():
        param.requires_grad = False
    siglip_model.to(device)
    # Get SigLIP output dimension and text embedding dimension
    # Create a proper dummy image (black image)
    dummy_image = Image.new('RGB', (384, 384), color='black')
    with torch.no_grad():
        siglip_inputs = siglip_processor(dummy_image, return_tensors="pt").to(device)
        siglip_outputs = siglip_model(**siglip_inputs)
        siglip_output_dim = siglip_outputs.pooler_output.shape[-1]
    # Get a sample text to determine embedding dimension
    dummy_text = "This is a test."
    dummy_embedding = get_text_embedding(dummy_text, tokenizer, device)
    text_embedding_dim = dummy_embedding.shape[-1]
    print(f"SigLIP output dimension: {siglip_output_dim}")
    print(f"Text embedding dimension: {text_embedding_dim}")
    # Create linear projection layer
    linear_proj = LinearProjection(siglip_output_dim, text_embedding_dim).to(device)
    # Load checkpoint if requested
    if load_checkpoint:
        try:
            checkpoint = torch.load(checkpoint_path, map_location=device)
            linear_proj.load_state_dict(checkpoint)
            print(f"Successfully loaded checkpoint from {checkpoint_path}")
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            print("Starting training from scratch instead.")
    # Load CIFAR10 test dataset (ToTensor scales pixel values to [0, 1])
    transform = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.ToTensor(),
    ])
    test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
    subset_indices = list(range(num_images))
    subset_dataset = Subset(test_dataset, subset_indices)
    dataloader = DataLoader(subset_dataset, batch_size=batch_size, shuffle=False)
    # Create text files directory if it doesn't exist
    os.makedirs('qa_outputs', exist_ok=True)
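    # The training loop below expects one text file per image at
    # qa_outputs/image_{i}_extr.txt, each holding at least five lines of
    # description (one per caption variant); missing files or lines fall back
    # to placeholder text. A hedged, optional snippet for dry runs that creates
    # empty placeholders (an assumption; the original pipeline presumably
    # produces these files elsewhere), left commented out:
    # for i in range(num_images):
    #     placeholder_path = f'qa_outputs/image_{i}_extr.txt'
    #     if not os.path.exists(placeholder_path):
    #         with open(placeholder_path, 'w') as f:
    #             f.write("\n".join(["No description available"] * 5))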
    # Optimizer
    optimizer = AdamW(linear_proj.parameters(), lr=learning_rate)
    # Training loop
    for epoch in range(num_epochs):
        total_loss = 0
        linear_proj.train()
        progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}')
        for batch_idx, (images, labels) in enumerate(progress_bar):
            images = images.to(device)
            # Size of the current batch (the last batch may be smaller than batch_size)
            current_batch_size = images.size(0)
            # Get image embeddings; the tensors are already in [0, 1], so skip the
            # processor's default 1/255 rescaling to avoid scaling them twice
            with torch.no_grad():
                siglip_inputs = siglip_processor(images, return_tensors="pt", do_rescale=False).to(device)
                siglip_outputs = siglip_model(**siglip_inputs)
                image_features = siglip_outputs.pooler_output
            # Project image features
            projected_image_features = linear_proj(image_features)
            # Process text for each description line (lines 1 to 5 of each file)
            total_batch_loss = 0
            for line_num in range(5):
                text_embeddings_list = []
                # Read text from files for the current batch
                for idx in range(current_batch_size):
                    # Index into the full subset; use the dataloader batch_size here,
                    # not the (possibly smaller) size of the current batch
                    global_idx = batch_idx * batch_size + idx
                    if global_idx < num_images:
                        file_path = f'qa_outputs/image_{global_idx}_extr.txt'
                        try:
                            with open(file_path, 'r') as f:
                                lines = f.readlines()
                                text = lines[line_num].strip() if line_num < len(lines) else ""
                        except OSError:
                            text = "No description available"
                        # Get text embeddings directly from the tokenizer
                        text_embedding = get_text_embedding(text, tokenizer, device)
                        text_embeddings_list.append(text_embedding)
                if text_embeddings_list:
                    # Stack (rather than cat) since all embeddings now have the same size
                    text_embeddings = torch.stack(text_embeddings_list, dim=0).squeeze(1)
                    loss = siglip_loss(projected_image_features, text_embeddings)
                    total_batch_loss += loss
            # Average loss over all text lines
            avg_batch_loss = total_batch_loss / 5
            # Backpropagation
            optimizer.zero_grad()
            avg_batch_loss.backward()
            optimizer.step()
            total_loss += avg_batch_loss.item()
            progress_bar.set_postfix({'loss': avg_batch_loss.item()})
        avg_epoch_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss:.4f}')
        # Save checkpoint after each epoch
        # checkpoint_dir = 'checkpoints'
        # os.makedirs(checkpoint_dir, exist_ok=True)
        # checkpoint_file = os.path.join(checkpoint_dir, f'linear_projection_epoch_{epoch+1}.pth')
        # torch.save(linear_proj.state_dict(), checkpoint_file)
        # print(f"Saved checkpoint to {checkpoint_file}")
    # Save final model
    torch.save(linear_proj.state_dict(), 'linear_projection_final.pth')
    print("Training completed. Final model saved as 'linear_projection_final.pth'")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Train or continue training the linear projection layer')
parser.add_argument('--num_images', type=int, default=100, help='Number of images to train on')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
parser.add_argument('--num_epochs', type=int, default=50, help='Number of epochs to train')
parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--load_checkpoint', action='store_true', help='Whether to load from checkpoint')
parser.add_argument('--checkpoint_path', type=str, default='linear_projection.pth', help='Path to checkpoint file')
args = parser.parse_args()
main(
num_images=args.num_images,
batch_size=args.batch_size,
num_epochs=args.num_epochs,
learning_rate=args.learning_rate,
load_checkpoint=args.load_checkpoint,
checkpoint_path=args.checkpoint_path
) |
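
# Example invocations (the script filename below is an assumption, not fixed by the code):
#   python train_projection.py --num_images 100 --batch_size 32 --num_epochs 50
#   python train_projection.py --load_checkpoint --checkpoint_path linear_projection.pth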