simple Gradio app hosted on Hugging Face Spaces
Browse filessimple Gradio app hosted on Hugging Face Spaces (free version) that uses the fancyfeast/joy-caption-beta-one model for image captioning.
app.py
CHANGED
@@ -1,13 +1,55 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
)
|
13 |
-
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
from transformers import LlavaForConditionalGeneration, AutoProcessor
|
5 |
+
|
6 |
+
# Load the model and processor
|
7 |
+
MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
|
8 |
+
processor = AutoProcessor.from_pretrained(MODEL_PATH)
|
9 |
+
model = LlavaForConditionalGeneration.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16, device_map="auto")
|
10 |
+
model.eval()
|
11 |
+
|
12 |
+
# Define the captioning function
|
13 |
+
def generate_caption(input_image: Image.Image, caption_type: str = "descriptive", caption_length: str = "medium") -> str:
|
14 |
+
if input_image is None:
|
15 |
+
return "Please upload an image."
|
16 |
+
|
17 |
+
# Prepare the prompt
|
18 |
+
prompt = f"Write a {caption_length} {caption_type} caption for this image."
|
19 |
+
convo = [
|
20 |
+
{
|
21 |
+
"role": "system",
|
22 |
+
"content": "You are a helpful assistant that generates accurate and relevant image captions."
|
23 |
+
},
|
24 |
+
{
|
25 |
+
"role": "user",
|
26 |
+
"content": prompt.strip()
|
27 |
+
}
|
28 |
+
]
|
29 |
+
|
30 |
+
# Process the image and prompt
|
31 |
+
inputs = processor(images=input_image, text=convo[1]["content"], return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
|
32 |
+
|
33 |
+
# Generate the caption
|
34 |
+
with torch.no_grad():
|
35 |
+
output = model.generate(**inputs, max_new_tokens=100, temperature=0.7, top_p=0.9)
|
36 |
+
|
37 |
+
# Decode the output
|
38 |
+
caption = processor.decode(output[0], skip_special_tokens=True)
|
39 |
+
return caption.strip()
|
40 |
+
|
41 |
+
# Create the Gradio interface
|
42 |
+
interface = gr.Interface(
|
43 |
+
fn=generate_caption,
|
44 |
+
inputs=[
|
45 |
+
gr.Image(label="Upload Image", type="pil"),
|
46 |
+
gr.Dropdown(choices=["descriptive", "casual", "social media"], label="Caption Type", value="descriptive"),
|
47 |
+
gr.Dropdown(choices=["short", "medium", "long"], label="Caption Length", value="medium")
|
48 |
+
],
|
49 |
+
outputs=gr.Textbox(label="Generated Caption"),
|
50 |
+
title="Image Captioning with JoyCaption",
|
51 |
+
description="Upload an image to generate a caption using the fancyfeast/joy-caption-beta-one model."
|
52 |
)
|
53 |
+
|
54 |
+
if __name__ == "__main__":
|
55 |
+
interface.launch()
|