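"""Gradio demo: visual question answering on CIFAR10 test images.

Pipeline: SigLIP vision encoder -> trained linear projection -> Phi-2 with LoRA
adapters, which generates an answer conditioned on the projected image embedding
and a fixed question prompt. Expects the trained artifacts 'phi_model_trained',
'linear_projection_final.pth', and 'image_text_proj.pth' next to this script.
"""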
import gradio as gr
import torch
import torchvision
import torchvision.transforms as transforms
import random
import numpy as np
from transformers import (
    SiglipVisionModel,
    AutoTokenizer,
    AutoImageProcessor,
    AutoModelForCausalLM
)
from peft import PeftModel
from PIL import Image
# Initialize device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load models and processors
def load_models():
    # Load SigLIP
    print("Loading SigLIP model...")
    siglip_model = SiglipVisionModel.from_pretrained(
        "google/siglip-so400m-patch14-384",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True
    ).to(device)
    siglip_processor = AutoImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
    # Load base Phi-2 model
    print("Loading Phi-2 model...")
    base_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True
    ).to(device)
    # Load the trained LoRA weights
    print("Loading trained LoRA weights...")
    phi_model = PeftModel.from_pretrained(
        base_model,
        "phi_model_trained"
    )
    phi_tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
    if phi_tokenizer.pad_token is None:
        phi_tokenizer.pad_token = phi_tokenizer.eos_token
    # Load trained projections (saved as full nn.Module objects, so they are callable)
    print("Loading projection layers...")
    linear_proj = torch.load('linear_projection_final.pth', map_location=device)
    image_text_proj = torch.load('image_text_proj.pth', map_location=device)
    return (siglip_model, siglip_processor, phi_model, phi_tokenizer, linear_proj, image_text_proj)
# Load all models at startup
print("Loading models...")
models = load_models()
siglip_model, siglip_processor, phi_model, phi_tokenizer, linear_proj, image_text_proj = models
print("Models loaded successfully!")
# Load CIFAR10 test dataset
transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
])
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# Keep only the first 100 images (index the dataset lazily instead of
# materializing all 10,000 resized test images in memory)
first_100_images = [testset[i] for i in range(100)]
# Questions list
questions = [
    "Give a description of the image?",
    "How does the main object in the image look like?",
    "How can the main object in the image be useful to humans?",
    "What is the color of the main object in the image?",
    "Describe the setting of the image?"
]
def get_image_embedding(image, siglip_model, siglip_processor, linear_proj, device):
    with torch.no_grad():
        # The CIFAR tensor arrives as a float image in [0, 1]; convert it to PIL so
        # the SigLIP processor's resizing and rescaling assumptions hold
        if isinstance(image, torch.Tensor):
            image = transforms.ToPILImage()(image.cpu())
        # Process image through SigLIP
        inputs = siglip_processor(images=image, return_tensors="pt")
        # Move pixel values to the model's device and dtype (the model is in float16)
        inputs = {k: v.to(device=device, dtype=siglip_model.dtype) if isinstance(v, torch.Tensor) else v
                  for k, v in inputs.items()}
        outputs = siglip_model(**inputs)
        image_features = outputs.pooler_output
        # Cast to the projection layer's dtype (it may have been saved in float32)
        proj_dtype = next(linear_proj.parameters()).dtype
        projected_features = linear_proj(image_features.to(proj_dtype))
        return projected_features
def get_random_images():
    # Select 10 random images from the first 100
    selected_indices = random.sample(range(100), 10)
    selected_images = [first_100_images[i][0] for i in selected_indices]
    # Convert to numpy arrays and transpose to (H, W, C) for display
    images_np = [img.permute(1, 2, 0).numpy() for img in selected_images]
    return images_np, selected_indices
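# Answer generation: the projected image embedding is prepended to the question's
# token embeddings and the combined sequence is fed to Phi-2 via inputs_embeds,
# so the LoRA-tuned model attends to the image as a single prefix "token".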
def generate_answer(image_tensor, question_index):
    if image_tensor is None:
        return "Please select an image first!"
    try:
        # Get image embedding
        image_embedding = get_image_embedding(
            image_tensor,
            siglip_model,
            siglip_processor,
            linear_proj,
            device
        )
        # Get question (Gradio Number inputs arrive as floats, so cast to int)
        question = questions[int(question_index)]
        # Tokenize question
        question_tokens = phi_tokenizer(
            question,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        ).to(device)
        # Get question embeddings
        question_embeds = phi_model.get_input_embeddings()(question_tokens['input_ids'])
        # Project image embedding into the text embedding space and add a sequence dim
        image_embeds = image_text_proj(image_embedding)
        image_embeds = image_embeds.unsqueeze(1)
        # Align dtypes before concatenating (the language model runs in float16)
        image_embeds = image_embeds.to(question_embeds.dtype)
        # Combine embeddings: image prefix followed by the question tokens
        combined_embedding = torch.cat([
            image_embeds,
            question_embeds
        ], dim=1)
        # Create attention mask
        attention_mask = torch.ones(
            (1, combined_embedding.size(1)),
            dtype=torch.long,
            device=device
        )
        # Generate answer
        with torch.no_grad():
            outputs = phi_model.generate(
                inputs_embeds=combined_embedding,
                attention_mask=attention_mask,
                max_new_tokens=100,
                num_beams=4,
                temperature=0.7,
                do_sample=True,
                pad_token_id=phi_tokenizer.pad_token_id,
                eos_token_id=phi_tokenizer.eos_token_id
            )
        # Decode the generated answer
        answer = phi_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return answer
    except Exception as e:
        return f"Error generating answer: {str(e)}"
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# CIFAR10 Image Question Answering System")
    # State variables
    selected_image_tensor = gr.State(None)
    image_indices = gr.State([])
    with gr.Row():
        with gr.Column():
            random_btn = gr.Button("Get Random Images")
            gallery = gr.Gallery(
                label="Click an image to select it",
                show_label=True,
                elem_id="gallery",
                columns=[5],
                rows=[2],
                height="auto",
                allow_preview=False
            )
        with gr.Column():
            selected_img = gr.Image(label="Selected Image", height=200)
            q_buttons = []
            for i, q in enumerate(questions):
                btn = gr.Button(f"Q{i+1}: {q}")
                q_buttons.append(btn)
            answer_box = gr.Textbox(label="Answer", lines=3)
    def on_random_click():
        images, indices = get_random_images()
        return {
            gallery: images,
            image_indices: indices,
            selected_image_tensor: None,
            selected_img: None,
            answer_box: ""
        }

    random_btn.click(
        on_random_click,
        outputs=[gallery, image_indices, selected_image_tensor, selected_img, answer_box]
    )
    def on_image_select(evt: gr.SelectData, indices):
        if not indices or evt.index >= len(indices):
            return None, None, ""
        selected_idx = indices[evt.index]
        selected_tensor = first_100_images[selected_idx][0]
        # Rebuild the preview from the stored tensor rather than from the gallery
        # value, whose format depends on the Gradio version
        selected_np = selected_tensor.permute(1, 2, 0).numpy()
        return selected_tensor, selected_np, ""

    gallery.select(
        on_image_select,
        inputs=[image_indices],
        outputs=[selected_image_tensor, selected_img, answer_box]
    )
    # A hidden Number per question passes that question's index to generate_answer
    for i, btn in enumerate(q_buttons):
        btn.click(
            generate_answer,
            inputs=[selected_image_tensor, gr.Number(value=i, visible=False)],
            outputs=answer_box
        )
# Launch with minimal settings
demo.queue(max_size=1).launch(show_error=True)