"""Gradio app: classify a meme image with a fine-tuned SigLIP model and OCR its text."""
import os

import gradio as gr
import pytesseract
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, SiglipImageProcessor

# Model path: local directory holding the fine-tuned checkpoint.
MODEL_PATH = "./model"

# Hub id of the processor used when MODEL_PATH ships no preprocessing config.
BASE_PROCESSOR_ID = "google/siglip-base-patch16-224"


def _list_model_dir(path: str) -> str:
    """Return a printable listing of *path*, or the reason it cannot be listed.

    Used for diagnostics only. It never raises, so a missing directory cannot
    trigger a second exception inside the loader's error handler and mask the
    original failure.
    """
    try:
        return str(os.listdir(path))
    except OSError as err:
        return f"<unable to list directory: {err}>"


def _load_processor(model_path: str) -> SiglipImageProcessor:
    """Load the image processor from *model_path*, falling back to the base SigLIP one."""
    try:
        # Try to load the image processor from the local checkpoint files.
        processor = SiglipImageProcessor.from_pretrained(model_path, local_files_only=True)
        print("✅ Image processor loaded from local files!")
    except Exception as e:
        print(f"⚠️ Could not load local processor: {e}")
        print("Loading image processor from base SigLIP model...")
        # Fallback requires network access.
        # NOTE(review): this assumes the checkpoint was fine-tuned with the
        # base model's preprocessing (resize/normalisation) — confirm.
        processor = SiglipImageProcessor.from_pretrained(BASE_PROCESSOR_ID)
        print("✅ Image processor loaded from base model!")
    return processor


def _resolve_labels(config) -> dict:
    """Return an index -> class-name mapping taken from *config*, or generic names."""
    id2label = getattr(config, "id2label", None)
    if id2label:
        print(f"✅ Found {len(id2label)} labels in model config")
        return id2label
    # Create generic labels if the config defines none.
    num_labels = getattr(config, "num_labels", 2)
    generic = {i: f"class_{i}" for i in range(num_labels)}
    print(f"✅ Created {len(generic)} generic labels")
    return generic


try:
    print(f"=== Loading model from: {MODEL_PATH} ===")
    print(f"Available files: {_list_model_dir(MODEL_PATH)}")

    print("Loading model...")
    model = AutoModelForImageClassification.from_pretrained(MODEL_PATH, local_files_only=True)
    print("✅ Model loaded successfully!")

    # Load just the image processor (not the full AutoProcessor).
    print("Loading image processor...")
    processor = _load_processor(MODEL_PATH)

    labels = _resolve_labels(model.config)

    print("🎉 Model setup complete!")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    print("\n=== Debug Information ===")
    print(f"Files in model directory: {_list_model_dir(MODEL_PATH)}")
    raise


def _extract_text(image: Image.Image) -> str:
    """Run Tesseract OCR on *image* and return the stripped text.

    Failures (e.g. the tesseract binary is not installed) are reported as a
    message instead of raised. This way a broken OCR setup does not also
    discard the model's classification result.
    """
    try:
        return pytesseract.image_to_string(image).strip()
    except Exception as e:
        print(f"⚠️ OCR failed: {e}")
        return f"OCR failed: {e}"


def _predict(image: Image.Image) -> dict:
    """Return {label: probability} for *image*, sorted by descending probability."""
    # The processor normalises with 3-channel mean/std, so RGBA, palette or
    # grayscale inputs must be converted to RGB first.
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    inputs = processor(images=rgb_image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)
    # Single image in the batch -> take row 0.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

    # Iterate over the model's actual output width rather than len(labels).
    # Then a config/head mismatch cannot raise IndexError or silently drop classes.
    predictions = {
        labels.get(i, f"class_{i}"): p for i, p in enumerate(probs.tolist())
    }
    return dict(sorted(predictions.items(), key=lambda kv: kv[1], reverse=True))


def classify_meme(image: Image.Image):
    """
    Classify meme and extract text using OCR.

    Args:
        image: The uploaded image. Gradio passes ``None`` when nothing was uploaded.

    Returns:
        ``(predictions, text)`` matching the Gradio outputs.
        ``predictions`` is a ``{label: probability}`` dict sorted by descending
        probability, shown by ``gr.Label``. ``text`` is the OCR result, shown by
        ``gr.Textbox``. On failure it returns ``({"Error": 1.0}, message)``.
    """
    if image is None:
        error_msg = "Error processing image: no image was provided"
        print(f"❌ {error_msg}")
        return {"Error": 1.0}, error_msg

    try:
        extracted_text = _extract_text(image)
        sorted_predictions = _predict(image)

        # Debug prints
        print("=== Classification Results ===")
        print(f"Extracted Text: '{extracted_text}'")
        print("Top 3 Predictions:")
        for rank, (label, prob) in enumerate(list(sorted_predictions.items())[:3], start=1):
            print(f"  {rank}. {label}: {prob:.4f}")

        return sorted_predictions, extracted_text
    except Exception as e:
        error_msg = f"Error processing image: {str(e)}"
        print(f"❌ {error_msg}")
        return {"Error": 1.0}, error_msg


# Create Gradio interface
demo = gr.Interface(
    fn=classify_meme,
    inputs=gr.Image(type="pil", label="Upload Meme Image"),
    outputs=[
        gr.Label(num_top_classes=5, label="Meme Classification"),
        gr.Textbox(label="Extracted Text", lines=3),
    ],
    title="🎭 Meme Classifier with OCR",
    description="""
    Upload a meme image to:
    1. **Classify** its content using your trained SigLIP2_77 model
    2. **Extract text** using OCR (Optical Character Recognition)

    Your model was trained on meme data and will predict the category/sentiment of the uploaded meme.
    """,
    examples=None,
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # update if the pinned Gradio version warns about or rejects it.
    allow_flagging="never",
)

if __name__ == "__main__":
    print("🚀 Starting Gradio interface...")
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )