# Hugging Face Spaces app (Space status when captured: Sleeping)
import gradio as gr | |
from PIL import Image | |
import torch | |
from transformers import AutoProcessor, BlipForConditionalGeneration, BlipProcessor | |
import os | |
# Choose where the model will run: a CUDA GPU if PyTorch reports one is
# available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")
# Load the input processor. Prefer one saved locally under ./processor; if
# that cannot be loaded for any reason, use the processor published with the
# Salesforce/blip-image-captioning-base checkpoint instead.
# NOTE(review): the model loaded below is always the base checkpoint, so a
# custom ./processor is presumably expected to be compatible with it - confirm.
try:
    processor = AutoProcessor.from_pretrained("./processor")
except Exception as err:
    print(f"Failed to load custom processor: {err}")
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    print("Using fallback processor")
else:
    print("Loaded custom processor")
# Load the BLIP captioning model from the base checkpoint, explicitly in
# float32. If that attempt raises, retry the same checkpoint with
# low_cpu_mem_usage=True; an error from the retry is not caught here.
try:
    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base",
        torch_dtype=torch.float32,  # full precision, usable on CPU
    )
except Exception as err:
    print(f"Error loading model: {err}")
    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base",
        low_cpu_mem_usage=True,
    )
    print("Loaded fallback model with low memory settings")
else:
    print("Loaded base BLIP model")

# Place the weights on the device chosen earlier in the file.
model = model.to(device)
print("Model loaded and ready")
# Define the function to generate caption
def generate_caption(image, *, max_length=25):
    """Generate a text caption for an uploaded image.

    Args:
        image: A PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the form was submitted without an image.
        max_length: Value forwarded as ``max_length`` to ``model.generate``
            (default 25, the value this app has always used).

    Returns:
        The decoded caption string. On failure, a short human-readable error
        message is returned instead of raising, so the UI always shows text.
    """
    # gr.Image hands the callback None when nothing was uploaded; reply with
    # a prompt rather than an AttributeError-derived error string.
    if image is None:
        return "Please upload an image."
    try:
        # Normalise non-RGB uploads (e.g. RGBA, grayscale, palette) to RGB.
        if image.mode != "RGB":
            image = image.convert("RGB")
        inputs = processor(images=image, return_tensors="pt").to(device)
        generated_ids = model.generate(
            pixel_values=inputs.pixel_values, max_length=max_length
        )
        # A single image was submitted, so keep the first decoded sequence.
        return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    except Exception as e:
        # UI boundary: keep the full traceback in the server logs for
        # debugging, but show the user a short message instead of crashing.
        import traceback

        traceback.print_exc()
        return f"Error generating caption: {e}"
# Create Gradio interface: one PIL image in, one text caption out.
iface = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Image Caption Generator",
    description="Upload an image to generate a caption.",
    examples=[
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
    ],
)

# Launch only when executed as a script, so importing this module (tests,
# notebooks, other tooling) does not start a web server as a side effect.
if __name__ == "__main__":
    iface.launch()