import sys
import os

# Add the cloned nanoVLM directory to Python's system path
# This allows us to import from the 'models' directory within nanoVLM
NANOVLM_REPO_PATH = "/app/nanoVLM" # Path where we cloned it in Dockerfile
if NANOVLM_REPO_PATH not in sys.path:
    sys.path.insert(0, NANOVLM_REPO_PATH)

import gradio as gr
from PIL import Image
import torch
from transformers import AutoProcessor  # AutoProcessor may work if the repo ships a standard processor config

# Now import the custom classes from the cloned nanoVLM repository
try:
    from models.vision_language_model import VisionLanguageModel
    from models.configurations import VisionLanguageConfig # Or the specific config class used by nanoVLM
    print("Successfully imported VisionLanguageModel and VisionLanguageConfig from nanoVLM clone.")
except ImportError as e:
    print(f"Error importing from nanoVLM clone: {e}. Check NANOVLM_REPO_PATH and ensure nanoVLM cloned correctly.")
    VisionLanguageModel = None
    VisionLanguageConfig = None


# Determine the device to use
device_choice = os.environ.get("DEVICE", "auto")
if device_choice == "auto":
    device = "cuda" if torch.cuda.is_available() else "cpu"
else:
    device = device_choice
print(f"Using device: {device}")

# Load the model and processor
model_id = "lusxvr/nanoVLM-222M"
processor = None
model = None

if VisionLanguageModel and VisionLanguageConfig:
    try:
        print(f"Attempting to load processor for {model_id}")
        # AutoProcessor may still work here, since processor_config.json is usually standard.
        # trust_remote_code=True may be needed if the processor ships custom code as well.
        processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        print("Processor loaded.")

        print(f"Attempting to load model config for {model_id} using VisionLanguageConfig")
        # Load the configuration with the custom config class, pointing at model_id.
        # trust_remote_code=True lets it use any custom code shipped with that repo, if the config needs it.
        config = VisionLanguageConfig.from_pretrained(model_id, trust_remote_code=True)
        print("Model config loaded.")
        
        print(f"Attempting to load model weights for {model_id} using VisionLanguageModel")
        # Load the model weights using the custom model class and the loaded config
        model = VisionLanguageModel.from_pretrained(model_id, config=config, trust_remote_code=True).to(device)
        print("Model weights loaded successfully.")
        model.eval() # Set to evaluation mode

    except Exception as e:
        print(f"Error loading model, processor, or config: {e}")
        # Fallback if any step fails
        processor = None
        model = None
else:
    print("Custom nanoVLM classes not imported, cannot load model.")


def generate_text_for_image(image_input, prompt_input):
    if model is None or processor is None or not hasattr(model, 'generate'):  # Make sure the model exposes generate()
        return "Error: Model or processor not loaded correctly, or the model has no 'generate' method. Check the logs."

    if image_input is None:
        return "Please upload an image."
    if not prompt_input:
        return "Please provide a prompt."

    try:
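        # gr.Image(type="pil") should already hand over a PIL image; the fromarray branch is a fallback for numpy input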
        if not isinstance(image_input, Image.Image):
            pil_image = Image.fromarray(image_input)
        else:
            pil_image = image_input
        
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")

        # Prepare inputs for the model using the processor
        # The exact format for nanoVLM's custom model might require specific handling.
        # The processor from AutoProcessor should generally work.
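        # .to(device) moves every tensor in the returned batch onto the same device as the model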
        inputs = processor(text=[prompt_input], images=[pil_image], return_tensors="pt").to(device)
        
        # Generate text using the model's generate method
        # Common parameters for generation:
        generated_ids = model.generate(
            inputs['pixel_values'], # Assuming processor output has 'pixel_values'
            inputs['input_ids'],    # Assuming processor output has 'input_ids'
            attention_mask=inputs.get('attention_mask'), # Optional, but good to include
            max_new_tokens=150,
            num_beams=3,
            no_repeat_ngram_size=2,
            early_stopping=True
            # Check nanoVLM's VisionLanguageModel.generate() for specific parameters
        )
        
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        
        # The decoded output often echoes the prompt; strip it so only the newly generated text is returned
        if prompt_input and generated_text.startswith(prompt_input):
            cleaned_text = generated_text[len(prompt_input):].lstrip(" ,.:")
        else:
            cleaned_text = generated_text

        return cleaned_text.strip()

    except Exception as e:
        print(f"Error during generation: {e}")
        return f"An error occurred during text generation: {str(e)}"

description = "Interactive demo for lusxvr/nanoVLM-222M."
example_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Gradio reads GRADIO_TEMP_DIR from the environment for its temporary/cache files;
# default to a path that is writable inside the container.
os.environ.setdefault("GRADIO_TEMP_DIR", "/tmp/gradio_tmp")

iface = gr.Interface(
    fn=generate_text_for_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Your Prompt/Question")
    ],
    outputs=gr.Textbox(label="Generated Text", show_copy_button=True),
    title="Interactive nanoVLM-222M Demo",
    description=description,
    examples=[
        [example_image_url, "a photo of a"],
        [example_image_url, "Describe the image in detail."],
    ],
    cache_examples=True,
    allow_flagging="never"
)

if __name__ == "__main__":
    if model is None or processor is None:
        print("CRITICAL: Model or processor failed to load. Gradio interface may not function correctly.")
    else:
        print("Launching Gradio interface...")
    iface.launch(server_name="0.0.0.0", server_port=7860)