File size: 2,464 Bytes
2ea5846
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""

Hugging Face Spaces App

Deploy this to HF Spaces for free hosting

"""

import gradio as gr
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import torch

# Checkpoint identifiers pinned once so each processor/model pair stays in sync.
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
_GIT_CHECKPOINT = "microsoft/git-base"

print("Loading models...")

# BLIP image-captioning processor + model.
blip_processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
blip_model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT)

# GIT processor + causal-LM captioning model.
git_processor = AutoProcessor.from_pretrained(_GIT_CHECKPOINT)
git_model = AutoModelForCausalLM.from_pretrained(_GIT_CHECKPOINT)

def generate_captions(image, true_caption=""):
    """Caption *image* with both loaded models and format the results as Markdown.

    Args:
        image: PIL image to caption, or None (returns a prompt message instead).
        true_caption: Optional ground-truth caption; when non-empty it is
            prepended to the output for side-by-side comparison.

    Returns:
        A Markdown string with one "**Model:** caption" paragraph per model,
        or a per-model error line when that model's generation fails.
    """
    if image is None:
        return "Please upload an image first."

    results = []

    # BLIP model
    try:
        # Inference only — disable autograd so generate() doesn't waste
        # memory/compute tracking gradients.
        with torch.no_grad():
            inputs = blip_processor(image, return_tensors="pt")
            out = blip_model.generate(**inputs, max_length=50)
        blip_caption = blip_processor.decode(out[0], skip_special_tokens=True)
        results.append(f"**BLIP:** {blip_caption}")
    except Exception as e:
        # Best-effort: report the failure in the output rather than crash the UI.
        results.append(f"**BLIP:** Error - {str(e)}")

    # GIT model
    try:
        with torch.no_grad():
            inputs = git_processor(images=image, return_tensors="pt")
            generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=50)
        git_caption = git_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        results.append(f"**GIT:** {git_caption}")
    except Exception as e:
        results.append(f"**GIT:** Error - {str(e)}")

    if true_caption:
        results.insert(0, f"**True Caption:** {true_caption}")

    return "\n\n".join(results)

# Assemble the Gradio UI: an image plus an optional reference caption in,
# formatted caption text out.
_image_input = gr.Image(type="pil", label="Upload Image")
_caption_input = gr.Textbox(
    label="True Caption (Optional)",
    placeholder="Enter the correct caption for comparison",
)
_output_box = gr.Textbox(label="Generated Captions", lines=10)

# Hosted sample images let visitors try the demo without uploading anything.
_examples = [
    ["https://huggingface.co/datasets/mishig/sample_images/resolve/main/cat.jpg", ""],
    ["https://huggingface.co/datasets/mishig/sample_images/resolve/main/dog.jpg", ""],
]

demo = gr.Interface(
    fn=generate_captions,
    inputs=[_image_input, _caption_input],
    outputs=_output_box,
    title="🤖 AI Image Captioning",
    description="Upload an image and get captions from multiple AI models!",
    examples=_examples,
)

# Launch the server only when run as a script (importers just get `demo`).
if __name__ == "__main__":
    demo.launch()