File size: 5,291 Bytes
f2c2a4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import gradio as gr
import spaces  # HuggingFace Spaces helper; supplies the @spaces.GPU decorator
import torch
from PIL import Image  # NOTE(review): appears unused below — Gradio already delivers PIL images; confirm before removing

# Set random seeds for reproducibility
# Seeds both the CPU RNG and, when CUDA is present, every CUDA device RNG,
# so repeated runs produce the same generations.
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(0)

# Project-local modules: the nanoVLM model plus factories for its tokenizer
# and image preprocessor (imported after seeding, preserving original order).
from models.vision_language_model import VisionLanguageModel
from data.processors import get_tokenizer, get_image_processor


# Per-device cache so the model, tokenizer and image processor are loaded
# once per process instead of being re-fetched on every button click.
_ASSETS = {}


def _load_assets(device):
    """Load (and cache) the pretrained model, tokenizer and image processor.

    Returns a (model, tokenizer, image_processor) tuple for *device*; the
    first call per device does the actual download/load, later calls hit
    the cache.
    """
    key = str(device)
    if key not in _ASSETS:
        model = VisionLanguageModel.from_pretrained("lusxvr/nanoVLM-222M").to(device)
        model.eval()
        tokenizer = get_tokenizer(model.cfg.lm_tokenizer)
        image_processor = get_image_processor(model.cfg.vit_img_size)
        _ASSETS[key] = (model, tokenizer, image_processor)
    return _ASSETS[key]


@spaces.GPU
def generate_outputs(image, query):
    """Generate four text completions for (image, query) with nanoVLM.

    Args:
        image: PIL image from the Gradio image component (may be None if
            the user cleared it).
        query: question string typed by the user.

    Returns:
        A 5-tuple (error_message, out1, out2, out3, out4). On success
        error_message is None and the four outputs are decoded strings;
        on failure error_message is set and the outputs are None.
    """
    # Guard up front so a cleared image component yields a clear message
    # instead of a low-level AttributeError from .convert().
    if image is None:
        return "Error: please upload an image first.", None, None, None, None

    # Pick the best available device: CUDA > Apple MPS > CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # Load (or fetch from cache) model + processors.
    try:
        model, tokenizer, image_processor = _load_assets(device)
    except Exception as e:
        return f"Error loading model: {str(e)}", None, None, None, None

    # Prepare text input in the model's QA prompt format.
    template = f"Question: {query} Answer:"
    encoded = tokenizer.batch_encode_plus([template], return_tensors="pt")
    tokens = encoded["input_ids"].to(device)

    # Process image into a batched tensor.
    try:
        img_t = image_processor(image.convert("RGB")).unsqueeze(0).to(device)
    except Exception as e:
        return f"Error processing image: {str(e)}", None, None, None, None

    # Generate four outputs; no_grad avoids building autograd graphs
    # during pure inference.
    outputs = []
    max_new_tokens = 50  # Fixed value from provided script
    try:
        with torch.no_grad():
            for _ in range(4):
                gen = model.generate(tokens, img_t, max_new_tokens=max_new_tokens)
                outputs.append(tokenizer.batch_decode(gen, skip_special_tokens=True)[0])
    except Exception as e:
        return f"Error during generation: {str(e)}", None, None, None, None

    return None, outputs[0], outputs[1], outputs[2], outputs[3]


def main():
    """Build and launch the Gradio demo UI for nanoVLM."""
    # Cosmetic-only CSS tweaks applied to the Blocks container.
    custom_css = """
    .gradio-container {
        font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
        padding: 20px;
    }
    h1 {
        color: #333;
        text-align: center;
        margin-bottom: 20px;
    }
    .description {
        margin-bottom: 20px;
        line-height: 1.6;
    }
    .gr-button {
        padding: 10px 20px;
    }
    """

    with gr.Blocks(css=custom_css, title="nanoVLM Image-to-Text Generator") as demo:
        # Header and project description.
        gr.Markdown("# nanoVLM Image-to-Text Generator")
        gr.Markdown(
            """
            <div class="description">
                This demo showcases <b>nanoVLM</b>, a lightweight vision-language model by HuggingFace. 
                Upload an image and provide a query to generate four text descriptions. 
                The model is based on the <a href="https://github.com/huggingface/nanoVLM/" target="_blank">nanoVLM repository</a> 
                and uses the pretrained model <a href="https://huggingface.co/lusxvr/nanoVLM-222M" target="_blank">lusxvr/nanoVLM-222M</a>. 
                nanoVLM is designed for efficient image-to-text generation, ideal for resource-constrained environments.
            </div>
            """
        )

        with gr.Row():
            # Left column: inputs and the trigger button.
            with gr.Column():
                image_in = gr.Image(
                    type="pil",
                    label="Upload Image",
                    value="cat.jpg",  # pre-filled example image
                )
                query_in = gr.Textbox(
                    label="Query",
                    value="What is this?",
                    placeholder="Enter your query here",
                    lines=2,
                )
                run_btn = gr.Button("Generate")

            # Right column: error display plus the four generation slots.
            with gr.Column():
                error_box = gr.Textbox(
                    label="Errors (if any)",
                    placeholder="No errors",
                    visible=True,
                    interactive=False,
                )
                gen_boxes = [
                    gr.Textbox(
                        label=f"Generation {i}",
                        placeholder=f"Output {i} will appear here...",
                        lines=3,
                    )
                    for i in range(1, 5)
                ]

        # Wire the button: generate_outputs returns (error, out1..out4).
        run_btn.click(
            fn=generate_outputs,
            inputs=[image_in, query_in],
            outputs=[error_box, *gen_boxes],
        )

    demo.launch()


# Launch the demo only when executed as a script (not on import).
if __name__ == "__main__":
    main()