File size: 5,291 Bytes
f2c2a4e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import gradio as gr
import spaces
import torch
from PIL import Image
# Set random seeds for reproducibility
torch.manual_seed(0)
if torch.cuda.is_available():
    # Seed every visible CUDA device, not just the current one.
    torch.cuda.manual_seed_all(0)
from models.vision_language_model import VisionLanguageModel
from data.processors import get_tokenizer, get_image_processor
# Cache for the heavyweight objects (model, tokenizer, image processor) so
# they are loaded once per process instead of on every button click.
_CACHE = {}


def _best_device():
    """Pick the best available torch device: CUDA > MPS > CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


@spaces.GPU
def generate_outputs(image, query):
    """Generate four candidate answers for *query* about *image*.

    Parameters:
        image: PIL image from the Gradio Image widget (may be None if cleared).
        query: free-text question string.

    Returns:
        A 5-tuple ``(error, out1, out2, out3, out4)`` — ``error`` is None on
        success; on failure it holds a message and the four outputs are None.
    """
    device = _best_device()
    hf_model = "lusxvr/nanoVLM-222M"

    # Load the model once and reuse it — from_pretrained() downloads weights
    # and is far too expensive to repeat per request.
    try:
        if "model" not in _CACHE:
            model = VisionLanguageModel.from_pretrained(hf_model).to(device)
            model.eval()
            _CACHE["model"] = model
        model = _CACHE["model"]
    except Exception as e:
        return f"Error loading model: {str(e)}", None, None, None, None

    # Tokenizer / image processor are likewise cached across calls.
    try:
        if "tokenizer" not in _CACHE:
            _CACHE["tokenizer"] = get_tokenizer(model.cfg.lm_tokenizer)
            _CACHE["image_processor"] = get_image_processor(model.cfg.vit_img_size)
        tokenizer = _CACHE["tokenizer"]
        image_processor = _CACHE["image_processor"]
    except Exception as e:
        return f"Error loading tokenizer or image processor: {str(e)}", None, None, None, None

    # Prepare text input in the model's expected prompt template.
    template = f"Question: {query} Answer:"
    encoded = tokenizer.batch_encode_plus([template], return_tensors="pt")
    tokens = encoded["input_ids"].to(device)

    # Process image (image may be None if the user cleared the widget;
    # .convert raises and is reported through the same error channel).
    try:
        img = image.convert("RGB")
        img_t = image_processor(img).unsqueeze(0).to(device)
    except Exception as e:
        return f"Error processing image: {str(e)}", None, None, None, None

    # Generate four outputs; inference only, so skip autograd bookkeeping.
    outputs = []
    max_new_tokens = 50  # Fixed value from provided script
    try:
        with torch.no_grad():
            for _ in range(4):
                gen = model.generate(tokens, img_t, max_new_tokens=max_new_tokens)
                out = tokenizer.batch_decode(gen, skip_special_tokens=True)[0]
                outputs.append(out)
    except Exception as e:
        return f"Error during generation: {str(e)}", None, None, None, None

    return None, outputs[0], outputs[1], outputs[2], outputs[3]
def main():
    """Assemble the Gradio Blocks UI for the nanoVLM demo and launch it."""
    # Minimal CSS for subtle aesthetic enhancements.
    css = """
.gradio-container {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
padding: 20px;
}
h1 {
color: #333;
text-align: center;
margin-bottom: 20px;
}
.description {
margin-bottom: 20px;
line-height: 1.6;
}
.gr-button {
padding: 10px 20px;
}
"""
    with gr.Blocks(css=css, title="nanoVLM Image-to-Text Generator") as app:
        # Page header and intro text.
        gr.Markdown("# nanoVLM Image-to-Text Generator")
        gr.Markdown(
            """
<div class="description">
This demo showcases <b>nanoVLM</b>, a lightweight vision-language model by HuggingFace.
Upload an image and provide a query to generate four text descriptions.
The model is based on the <a href="https://github.com/huggingface/nanoVLM/" target="_blank">nanoVLM repository</a>
and uses the pretrained model <a href="https://huggingface.co/lusxvr/nanoVLM-222M" target="_blank">lusxvr/nanoVLM-222M</a>.
nanoVLM is designed for efficient image-to-text generation, ideal for resource-constrained environments.
</div>
"""
        )

        with gr.Row():
            # Left column: inputs and the trigger button.
            with gr.Column():
                image_input = gr.Image(
                    type="pil",
                    label="Upload Image",
                    value="cat.jpg",  # Set example image
                )
                query_input = gr.Textbox(
                    label="Query",
                    value="What is this?",
                    placeholder="Enter your query here",
                    lines=2,
                )
                submit_button = gr.Button("Generate")

            # Right column: error channel plus the four generations.
            with gr.Column():
                error_output = gr.Textbox(
                    label="Errors (if any)",
                    placeholder="No errors",
                    visible=True,
                    interactive=False,
                )
                generation_boxes = [
                    gr.Textbox(
                        label=f"Generation {idx}",
                        placeholder=f"Output {idx} will appear here...",
                        lines=3,
                    )
                    for idx in range(1, 5)
                ]

        # Wire the button: one error slot followed by the four generations.
        submit_button.click(
            fn=generate_outputs,
            inputs=[image_input, query_input],
            outputs=[error_output] + generation_boxes,
        )

    # Launch the app
    app.launch()
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()