import torch
import os
import glob
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageClassification
import gradio as gr
import pytesseract
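# Note: pytesseract is only a wrapper around the Tesseract OCR binary, which
# must be installed on the system. On Hugging Face Spaces this is typically
# done by listing `tesseract-ocr` in a packages.txt file at the repo root.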
def find_model_files():
    """Find model files in the current directory structure"""
    print("=== Searching for model files ===")

    # Look for key model files
    config_files = glob.glob("**/config.json", recursive=True)
    model_files = glob.glob("**/pytorch_model.bin", recursive=True) + glob.glob("**/model.safetensors", recursive=True)
    preprocessor_files = glob.glob("**/preprocessor_config.json", recursive=True)

    print(f"Found config.json files: {config_files}")
    print(f"Found model weight files: {model_files}")
    print(f"Found preprocessor_config.json files: {preprocessor_files}")

    # Find the directory that contains all necessary files
    for config_file in config_files:
        model_dir = os.path.dirname(config_file)
        if not model_dir:  # If config.json is in root
            model_dir = "."

        # Check if this directory has all required files
        has_model = any(os.path.dirname(f) == model_dir or (not os.path.dirname(f) and model_dir == ".") for f in model_files)
        has_preprocessor = any(os.path.dirname(f) == model_dir or (not os.path.dirname(f) and model_dir == ".") for f in preprocessor_files)

        if has_model and has_preprocessor:
            print(f"Found complete model in directory: {model_dir}")
            return model_dir
        elif has_model:
            print(f"Found model with config but missing preprocessor in: {model_dir}")
            return model_dir  # Try anyway, might work

    print("No complete model directory found")
    return None
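# For reference, find_model_files() handles both a flat upload
# (./config.json, ./model.safetensors, ./preprocessor_config.json)
# and a nested one (./model/config.json, ./model/model.safetensors, ...).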
# Search for model files
MODEL_PATH = find_model_files()
if MODEL_PATH is None:
    MODEL_PATH = "."  # Fallback to current directory
    print("Falling back to current directory")
try:
    # Load model and processor from detected path
    print(f"=== Attempting to load model from: {MODEL_PATH} ===")
    print(f"Current working directory: {os.getcwd()}")

    # List all files in the detected model directory
    if MODEL_PATH == ".":
        print("Files in root directory:")
        for item in os.listdir("."):
            if os.path.isfile(item):
                print(f"  File: {item}")
            else:
                print(f"  Directory: {item}/")
                try:
                    sub_files = os.listdir(item)[:5]  # Show first 5 files
                    print(f"    Contains: {sub_files}{'...' if len(os.listdir(item)) > 5 else ''}")
                except OSError:  # e.g. permission denied; skip unreadable directories
                    pass
    else:
        print(f"Files in {MODEL_PATH}:")
        print(f"  {os.listdir(MODEL_PATH)}")

    # Try to load the model
    print("Loading model...")
    model = AutoModelForImageClassification.from_pretrained(MODEL_PATH, local_files_only=True)
    print("Model loaded successfully!")

    print("Loading processor...")
    try:
        processor = AutoProcessor.from_pretrained(MODEL_PATH, local_files_only=True)
        print("Processor loaded successfully!")
    except Exception as proc_error:
        print(f"Error loading processor from local files: {proc_error}")
        print("Attempting to load processor from base SigLIP model...")
        # Try to load processor from the base SigLIP model
        try:
            processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
            print("Processor loaded from base SigLIP model!")
        except Exception as base_error:
            print(f"Error loading base processor: {base_error}")
            print("Trying alternative processor...")
            # As a last resort, try to create a minimal processor
            from transformers import CLIPImageProcessor
            processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
            print("Using CLIP processor as fallback!")
    # Get labels - handle case where id2label might not exist
    if hasattr(model.config, 'id2label') and model.config.id2label:
        labels = model.config.id2label
    else:
        # Create generic labels if none exist
        num_labels = model.config.num_labels if hasattr(model.config, 'num_labels') else 1000
        labels = {i: f"class_{i}" for i in range(num_labels)}

    print(f"Model loaded successfully. Number of classes: {len(labels)}")
except Exception as e:
    print(f"=== ERROR loading model from {MODEL_PATH} ===")
    print(f"Error: {e}")

    print("\n=== Debugging Information ===")
    print("All files in Space:")

    def list_all_files(directory=".", prefix=""):
        """Recursively list all files"""
        try:
            items = sorted(os.listdir(directory))
            for item in items:
                item_path = os.path.join(directory, item)
                if os.path.isfile(item_path):
                    size = os.path.getsize(item_path)
                    print(f"{prefix}📄 {item} ({size} bytes)")
                elif os.path.isdir(item_path) and not item.startswith('.'):
                    print(f"{prefix}📁 {item}/")
                    if len(prefix) < 6:  # Limit recursion depth
                        list_all_files(item_path, prefix + "  ")
        except PermissionError:
            print(f"{prefix}❌ Permission denied")
        except Exception as ex:
            print(f"{prefix}❌ Error: {ex}")

    list_all_files()

    print("\n=== Required Files for Model ===")
    print("✅ config.json - Model configuration")
    print("✅ pytorch_model.bin OR model.safetensors - Model weights")
    print("✅ preprocessor_config.json - Image processor config")
    print("✅ tokenizer.json (if applicable) - Tokenizer")

    print("\n=== Solutions ===")
    print("1. Make sure all model files are uploaded to your Space")
    print("2. Check that files aren't corrupted during upload")
    print("3. Try uploading to a 'model' subfolder")
    print("4. Verify the model was saved correctly during training")

    raise
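# For reference, a complete, loadable upload typically comes from saving both
# the model and the processor after training (hypothetical output directory):
#
#     model.save_pretrained("siglip_meme_model")
#     processor.save_pretrained("siglip_meme_model")
#
# which writes config.json, model.safetensors, and preprocessor_config.json
# into a single folder that can be uploaded to the Space as-is.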
# Classify meme and extract text
def classify_meme(image: Image.Image):
    """
    Classify meme and extract text using OCR
    """
    try:
        # OCR: extract text from image
        extracted_text = pytesseract.image_to_string(image)

        # Process image with the model
        inputs = processor(images=image, return_tensors="pt")

        # Move inputs to the same device as the model (no-op on CPU)
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # Get top predictions
        top_k = min(10, len(labels))  # Show top 10 or all if fewer
        top_probs, top_indices = torch.topk(probs[0], top_k)

        predictions = {}
        for prob, idx in zip(top_probs, top_indices):
            label = labels.get(idx.item(), f"class_{idx.item()}")
            predictions[label] = float(prob)

        # Debug prints (these will show in the console/logs)
        print("Extracted Text:", extracted_text.strip())
        print("Top Predictions:", predictions)

        return predictions, extracted_text.strip()

    except Exception as e:
        print(f"Error in classification: {e}")
        return {"Error": 1.0}, f"Error processing image: {str(e)}"
# Gradio interface
demo = gr.Interface(
    fn=classify_meme,
    inputs=gr.Image(type="pil", label="Upload Meme Image"),
    outputs=[
        gr.Label(num_top_classes=5, label="Meme Classification"),
        gr.Textbox(label="Extracted Text from OCR", lines=3)
    ],
    title="Meme Classifier with OCR",
    description="""
    Upload a meme image to:
    1. Classify its content using your trained SigLIP2_77 model
    2. Extract text using OCR (Optical Character Recognition)

    Note: Make sure all model files are properly uploaded to your Space.
    """,
    examples=None,
    allow_flagging="never"
)
if __name__ == "__main__":
    print("Starting Gradio interface...")
    demo.launch(
        server_name="0.0.0.0",  # Allow external connections in HF Spaces
        server_port=7860,       # Standard port for HF Spaces
        share=False             # HF Spaces handles sharing
    )
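# On a Gradio-SDK Space the platform launches the app itself and already
# expects port 7860, so these launch arguments mainly matter when running
# locally or in a Docker-based Space.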