import torch
import os
import glob
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageClassification
import gradio as gr
import pytesseract
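# Note: pytesseract is a thin wrapper around the Tesseract binary, which must be
# installed separately (on Hugging Face Spaces this typically means listing
# "tesseract-ocr" in packages.txt).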

def find_model_files():
    """Find model files in the current directory structure"""
    print("=== Searching for model files ===")
    
    # Look for key model files
    config_files = glob.glob("**/config.json", recursive=True)
    model_files = glob.glob("**/pytorch_model.bin", recursive=True) + glob.glob("**/model.safetensors", recursive=True)
    preprocessor_files = glob.glob("**/preprocessor_config.json", recursive=True)
    
    print(f"Found config.json files: {config_files}")
    print(f"Found model weight files: {model_files}")
    print(f"Found preprocessor_config.json files: {preprocessor_files}")
    
    # Find the directory that contains all necessary files
    for config_file in config_files:
        model_dir = os.path.dirname(config_file)
        if not model_dir:  # If config.json is in root
            model_dir = "."
        
        # Check if this directory has all required files (a path with no
        # dirname lives in the repo root ".")
        def in_dir(f):
            d = os.path.dirname(f)
            return d == model_dir or (not d and model_dir == ".")

        has_model = any(in_dir(f) for f in model_files)
        has_preprocessor = any(in_dir(f) for f in preprocessor_files)
        
        if has_model and has_preprocessor:
            print(f"Found complete model in directory: {model_dir}")
            return model_dir
        elif has_model:
            print(f"Found model with config but missing preprocessor in: {model_dir}")
            return model_dir  # Try anyway, might work
    
    print("No complete model directory found")
    return None
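
# Illustrative layouts the search above resolves (example paths, not requirements):
#   config.json, model.safetensors, preprocessor_config.json in the repo root -> "."
#   the same three files under a "model/" subfolder                           -> "model"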

# Search for model files
MODEL_PATH = find_model_files()
if MODEL_PATH is None:
    MODEL_PATH = "."  # Fallback to current directory
    print("Falling back to current directory")

try:
    # Load model and processor from detected path
    print(f"=== Attempting to load model from: {MODEL_PATH} ===")
    print(f"Current working directory: {os.getcwd()}")
    
    # List all files in the detected model directory
    if MODEL_PATH == ".":
        print("Files in root directory:")
        for item in os.listdir("."):
            if os.path.isfile(item):
                print(f"  File: {item}")
            else:
                print(f"  Directory: {item}/")
                try:
                    contents = os.listdir(item)
                    print(f"    Contains: {contents[:5]}{'...' if len(contents) > 5 else ''}")
                except OSError:
                    pass  # Skip unreadable directories
    else:
        print(f"Files in {MODEL_PATH}:")
        print(f"  {os.listdir(MODEL_PATH)}")
    
    # Try to load the model
    print("Loading model...")
    model = AutoModelForImageClassification.from_pretrained(MODEL_PATH, local_files_only=True)
    print("Model loaded successfully!")
    
    print("Loading processor...")
    try:
        processor = AutoProcessor.from_pretrained(MODEL_PATH, local_files_only=True)
        print("Processor loaded successfully!")
    except Exception as proc_error:
        print(f"Error loading processor from local files: {proc_error}")
        print("Attempting to load just the image processor...")
        
        # Try to load just the image processor from your model
        try:
            from transformers import SiglipImageProcessor
            processor = SiglipImageProcessor.from_pretrained(MODEL_PATH, local_files_only=True)
            print("Image processor loaded successfully from local files!")
        except Exception as img_proc_error:
            print(f"Error loading local image processor: {img_proc_error}")
            print("Loading image processor from base SigLIP model...")
            
            # Try to load processor from the base SigLIP model
            try:
                from transformers import SiglipImageProcessor
                processor = SiglipImageProcessor.from_pretrained("google/siglip-base-patch16-224")
                print("Image processor loaded from base SigLIP model!")
            except Exception as base_error:
                print(f"Error loading base processor: {base_error}")
                print("Using CLIP processor as fallback...")
                
                # As a last resort, fall back to a general-purpose CLIP image processor
                from transformers import CLIPImageProcessor
                processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
                print("Using CLIP processor as fallback!")
    
    # Get labels - handle case where id2label might not exist
    if hasattr(model.config, 'id2label') and model.config.id2label:
        labels = model.config.id2label
    else:
        # Create generic labels if none exist
        num_labels = getattr(model.config, 'num_labels', 1000)
        labels = {i: f"class_{i}" for i in range(num_labels)}
    
    print(f"Model loaded successfully. Number of classes: {len(labels)}")
    
except Exception as e:
    print(f"=== ERROR loading model from {MODEL_PATH} ===")
    print(f"Error: {e}")
    print("\n=== Debugging Information ===")
    print("All files in Space:")
    
    def list_all_files(directory=".", prefix=""):
        """Recursively list all files"""
        try:
            items = sorted(os.listdir(directory))
            for item in items:
                item_path = os.path.join(directory, item)
                if os.path.isfile(item_path):
                    size = os.path.getsize(item_path)
                    print(f"{prefix}πŸ“„ {item} ({size} bytes)")
                elif os.path.isdir(item_path) and not item.startswith('.'):
                    print(f"{prefix}πŸ“ {item}/")
                    if len(prefix) < 6:  # Limit recursion depth
                        list_all_files(item_path, prefix + "  ")
        except PermissionError:
            print(f"{prefix}❌ Permission denied")
        except Exception as ex:
            print(f"{prefix}❌ Error: {ex}")
    
    list_all_files()
    
    print("\n=== Required Files for Model ===")
    print("βœ… config.json - Model configuration")
    print("βœ… pytorch_model.bin OR model.safetensors - Model weights") 
    print("βœ… preprocessor_config.json - Image processor config")
    print("βœ… tokenizer.json (if applicable) - Tokenizer")
    
    print("\n=== Solutions ===")
    print("1. Make sure all model files are uploaded to your Space")
    print("2. Check that files aren't corrupted during upload")
    print("3. Try uploading to a 'model' subfolder")
    print("4. Verify the model was saved correctly during training")
    
    raise

# Classify meme and extract text
def classify_meme(image: Image.Image):
    """
    Classify meme and extract text using OCR
    """
    try:
        # OCR: extract text from image
        extracted_text = pytesseract.image_to_string(image)
        
        # Process image with the model
        inputs = processor(images=image, return_tensors="pt")
        
        # Move inputs to the same device as the model (a no-op on CPU; the
        # original check could never trigger because the model was never
        # explicitly moved to CUDA)
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        
        # Get top predictions
        top_k = min(10, len(labels))  # Show top 10 or all if fewer
        top_probs, top_indices = torch.topk(probs[0], top_k)
        
        predictions = {}
        for prob, idx in zip(top_probs, top_indices):
            label = labels.get(idx.item(), f"class_{idx.item()}")
            predictions[label] = float(prob)
        
        # Debug prints (these will show in the console/logs)
        print("Extracted Text:", extracted_text.strip())
        print("Top Predictions:", predictions)
        
        return predictions, extracted_text.strip()
        
    except Exception as e:
        print(f"Error in classification: {e}")
        return {"Error": 1.0}, f"Error processing image: {str(e)}"

# Gradio interface
demo = gr.Interface(
    fn=classify_meme,
    inputs=gr.Image(type="pil", label="Upload Meme Image"),
    outputs=[
        gr.Label(num_top_classes=5, label="Meme Classification"),
        gr.Textbox(label="Extracted Text from OCR", lines=3)
    ],
    title="Meme Classifier with OCR",
    description="""
    Upload a meme image to:
    1. Classify its content using your trained SigLIP2_77 model
    2. Extract text using OCR (Optical Character Recognition)
    
    Note: Make sure all model files are properly uploaded to your Space.
    """,
    examples=None,
    allow_flagging="never"
)
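
# Optional quick sanity check without the UI ("meme.png" is a hypothetical
# placeholder; point it at a real image before uncommenting):
# preds, text = classify_meme(Image.open("meme.png"))
# print(preds, text)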

if __name__ == "__main__":
    print("Starting Gradio interface...")
    demo.launch(
        server_name="0.0.0.0",  # Allow external connections in HF Spaces
        server_port=7860,       # Standard port for HF Spaces
        share=False             # HF Spaces handles sharing
    )