"""Gradio app that generates a caption for an uploaded image using BLIP."""

import os
from typing import Optional

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, BlipForConditionalGeneration, BlipProcessor

# Check if we're running on CPU or GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load processor first
try:
    # Try to load the custom processor from the local ./processor directory
    processor = AutoProcessor.from_pretrained("./processor")
    print("Loaded custom processor")
except Exception as e:
    print(f"Failed to load custom processor: {e}")
    # Fall back to the processor published with the base BLIP checkpoint
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    print("Using fallback processor")

# Load base model
try:
    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base",
        torch_dtype=torch.float32,  # Use float32 for CPU compatibility
    )
    print("Loaded base BLIP model")
except Exception as e:
    print(f"Error loading model: {e}")
    # If that fails, retry the same checkpoint with low memory usage
    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base",
        low_cpu_mem_usage=True,
    )
    print("Loaded fallback model with low memory settings")

# Move model to the selected device
model = model.to(device)
print("Model loaded and ready")


def generate_caption(image: Optional[Image.Image]) -> str:
    """Generate a caption for ``image`` with the module-level BLIP model.

    Args:
        image: The PIL image supplied by the Gradio ``Image`` input, or
            ``None`` when the form is submitted without an image.

    Returns:
        The decoded caption. If no image is given or any step fails, a
        message starting with ``"Error generating caption: "`` is returned
        instead of raising, so the text output always has something to show.
    """
    # Guard explicitly so the user gets a readable message rather than the
    # AttributeError text produced by dereferencing None below.
    if image is None:
        return "Error generating caption: no image provided"

    try:
        # Convert image to RGB if needed
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Process the image into tensors on the same device as the model
        inputs = processor(images=image, return_tensors="pt").to(device)

        # Generate caption token ids (max_length=25 is passed to generate)
        generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=25)

        # Decode the first (only) sequence in the batch
        caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return caption
    except Exception as e:
        # UI boundary: report the failure as text instead of raising
        return f"Error generating caption: {str(e)}"


# Create Gradio interface
iface = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Image Caption Generator",
    description="Upload an image to generate a caption.",
    examples=["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"],
)

# Launch
iface.launch()