therealsaed's picture
Upload 2 files
2ea5846 verified
raw
history blame
2.46 kB
"""
Hugging Face Spaces App
Deploy this to HF Spaces for free hosting
"""
import gradio as gr
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import torch
# Load both captioning backends once, at import time, so every request
# reuses the same processor/model objects instead of reloading them.
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
_GIT_CHECKPOINT = "microsoft/git-base"

print("Loading models...")
blip_processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
blip_model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT)
git_processor = AutoProcessor.from_pretrained(_GIT_CHECKPOINT)
git_model = AutoModelForCausalLM.from_pretrained(_GIT_CHECKPOINT)
def _run_captioner(label, captioner, image):
    """Run one captioning backend on *image* and format it as a Markdown entry.

    Errors are reported inline (``**label:** Error - <message>``) rather than
    raised, so one failing model does not hide the other model's caption.
    """
    try:
        caption = captioner(image)
    except Exception as e:  # deliberate UI boundary: surface the error to the user
        return f"**{label}:** Error - {e}"
    return f"**{label}:** {caption}"


def _blip_caption(image):
    """Return the BLIP caption for *image* (generation capped at max_length=50)."""
    inputs = blip_processor(image, return_tensors="pt")
    out = blip_model.generate(**inputs, max_length=50)
    return blip_processor.decode(out[0], skip_special_tokens=True)


def _git_caption(image):
    """Return the GIT caption for *image* (generation capped at max_length=50)."""
    inputs = git_processor(images=image, return_tensors="pt")
    generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=50)
    return git_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def generate_captions(image, true_caption=""):
    """Caption *image* with BLIP and GIT and return the results as Markdown.

    Args:
        image: image from the Gradio ``gr.Image(type="pil")`` input, or
            ``None`` when nothing has been uploaded.
        true_caption: optional reference caption typed by the user. Leading
            and trailing whitespace is ignored; a blank value is omitted.

    Returns:
        ``"Please upload an image first."`` when *image* is ``None``.
        Otherwise one ``**Label:** text`` entry per model, separated by blank
        lines, with the reference caption (if any) listed first. A model that
        fails contributes ``**Label:** Error - <message>`` instead of raising.
    """
    if image is None:
        return "Please upload an image first."

    results = [
        _run_captioner("BLIP", _blip_caption, image),
        _run_captioner("GIT", _git_caption, image),
    ]

    # Gradio may hand over None or whitespace; only show a real reference.
    reference = (true_caption or "").strip()
    if reference:
        results.insert(0, f"**True Caption:** {reference}")
    return "\n\n".join(results)
# Build the Gradio UI. The caption function formats each entry with Markdown
# **bold** labels, so the output is rendered with gr.Markdown; a plain
# Textbox would display the asterisks literally.
demo = gr.Interface(
    fn=generate_captions,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(
            label="True Caption (Optional)",
            placeholder="Enter the correct caption for comparison",
        ),
    ],
    outputs=gr.Markdown(label="Generated Captions"),
    title="🤖 AI Image Captioning",
    description="Upload an image and get captions from multiple AI models!",
    # Each example row is [image URL, true caption]. The images are remote
    # URLs, so presumably network access is needed — TODO confirm they resolve.
    examples=[
        ["https://huggingface.co/datasets/mishig/sample_images/resolve/main/cat.jpg", ""],
        ["https://huggingface.co/datasets/mishig/sample_images/resolve/main/dog.jpg", ""],
    ],
)

if __name__ == "__main__":
    demo.launch()