Spaces:
Sleeping
Sleeping
File size: 9,110 Bytes
b1779fd 6a6e076 b1779fd 131383f 6a6e076 cee92ec 6a6e076 cee92ec 6a6e076 23ad95f f70931b 23ad95f f70931b 23ad95f f70931b 23ad95f f70931b 6a6e076 b1779fd 6554f18 b1779fd 6a6e076 6554f18 b1779fd 6a6e076 3711151 6a6e076 3711151 6a6e076 b1779fd 6a6e076 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 |
import torch
import os
import glob
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageClassification
import gradio as gr
import pytesseract
def find_model_files():
"""Find model files in the current directory structure"""
print("=== Searching for model files ===")
# Look for key model files
config_files = glob.glob("**/config.json", recursive=True)
model_files = glob.glob("**/pytorch_model.bin", recursive=True) + glob.glob("**/model.safetensors", recursive=True)
preprocessor_files = glob.glob("**/preprocessor_config.json", recursive=True)
print(f"Found config.json files: {config_files}")
print(f"Found model weight files: {model_files}")
print(f"Found preprocessor_config.json files: {preprocessor_files}")
# Find the directory that contains all necessary files
for config_file in config_files:
model_dir = os.path.dirname(config_file)
if not model_dir: # If config.json is in root
model_dir = "."
# Check if this directory has all required files
has_model = any(os.path.dirname(f) == model_dir or (not os.path.dirname(f) and model_dir == ".") for f in model_files)
has_preprocessor = any(os.path.dirname(f) == model_dir or (not os.path.dirname(f) and model_dir == ".") for f in preprocessor_files)
if has_model and has_preprocessor:
print(f"Found complete model in directory: {model_dir}")
return model_dir
elif has_model:
print(f"Found model with config but missing preprocessor in: {model_dir}")
return model_dir # Try anyway, might work
print("No complete model directory found")
return None
# Search for model files
MODEL_PATH = find_model_files()
if MODEL_PATH is None:
MODEL_PATH = "." # Fallback to current directory
print("Falling back to current directory")
try:
# Load model and processor from detected path
print(f"=== Attempting to load model from: {MODEL_PATH} ===")
print(f"Current working directory: {os.getcwd()}")
# List all files in the detected model directory
if MODEL_PATH == ".":
print("Files in root directory:")
for item in os.listdir("."):
if os.path.isfile(item):
print(f" File: {item}")
else:
print(f" Directory: {item}/")
try:
sub_files = os.listdir(item)[:5] # Show first 5 files
print(f" Contains: {sub_files}{'...' if len(os.listdir(item)) > 5 else ''}")
except:
pass
else:
print(f"Files in {MODEL_PATH}:")
print(f" {os.listdir(MODEL_PATH)}")
# Try to load the model
print("Loading model...")
model = AutoModelForImageClassification.from_pretrained(MODEL_PATH, local_files_only=True)
print("Model loaded successfully!")
print("Loading processor...")
try:
processor = AutoProcessor.from_pretrained(MODEL_PATH, local_files_only=True)
print("Processor loaded successfully!")
except Exception as proc_error:
print(f"Error loading processor from local files: {proc_error}")
print("Attempting to load just the image processor...")
# Try to load just the image processor from your model
try:
from transformers import SiglipImageProcessor
processor = SiglipImageProcessor.from_pretrained(MODEL_PATH, local_files_only=True)
print("Image processor loaded successfully from local files!")
except Exception as img_proc_error:
print(f"Error loading local image processor: {img_proc_error}")
print("Loading image processor from base SigLIP model...")
# Try to load processor from the base SigLIP model
try:
from transformers import SiglipImageProcessor
processor = SiglipImageProcessor.from_pretrained("google/siglip-base-patch16-224")
print("Image processor loaded from base SigLIP model!")
except Exception as base_error:
print(f"Error loading base processor: {base_error}")
print("Using CLIP processor as fallback...")
# As a last resort, try to create a minimal processor
from transformers import CLIPImageProcessor
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
print("Using CLIP processor as fallback!")
# Get labels - handle case where id2label might not exist
if hasattr(model.config, 'id2label') and model.config.id2label:
labels = model.config.id2label
else:
# Create generic labels if none exist
num_labels = model.config.num_labels if hasattr(model.config, 'num_labels') else 1000
labels = {i: f"class_{i}" for i in range(num_labels)}
print(f"Model loaded successfully. Number of classes: {len(labels)}")
except Exception as e:
print(f"=== ERROR loading model from {MODEL_PATH} ===")
print(f"Error: {e}")
print("\n=== Debugging Information ===")
print("All files in Space:")
def list_all_files(directory=".", prefix=""):
"""Recursively list all files"""
try:
items = sorted(os.listdir(directory))
for item in items:
item_path = os.path.join(directory, item)
if os.path.isfile(item_path):
size = os.path.getsize(item_path)
print(f"{prefix}π {item} ({size} bytes)")
elif os.path.isdir(item_path) and not item.startswith('.'):
print(f"{prefix}π {item}/")
if len(prefix) < 6: # Limit recursion depth
list_all_files(item_path, prefix + " ")
except PermissionError:
print(f"{prefix}β Permission denied")
except Exception as ex:
print(f"{prefix}β Error: {ex}")
list_all_files()
print("\n=== Required Files for Model ===")
print("β
config.json - Model configuration")
print("β
pytorch_model.bin OR model.safetensors - Model weights")
print("β
preprocessor_config.json - Image processor config")
print("β
tokenizer.json (if applicable) - Tokenizer")
print("\n=== Solutions ===")
print("1. Make sure all model files are uploaded to your Space")
print("2. Check that files aren't corrupted during upload")
print("3. Try uploading to a 'model' subfolder")
print("4. Verify the model was saved correctly during training")
raise
# Classify meme and extract text
def classify_meme(image: Image.Image):
"""
Classify meme and extract text using OCR
"""
try:
# OCR: extract text from image
extracted_text = pytesseract.image_to_string(image)
# Process image with the model
inputs = processor(images=image, return_tensors="pt")
# Move inputs to same device as model if needed
if torch.cuda.is_available() and next(model.parameters()).is_cuda:
inputs = {k: v.to('cuda') for k, v in inputs.items()}
with torch.no_grad():
outputs = model(**inputs)
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
# Get top predictions
top_k = min(10, len(labels)) # Show top 10 or all if fewer
top_probs, top_indices = torch.topk(probs[0], top_k)
predictions = {}
for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
label = labels.get(idx.item(), f"class_{idx.item()}")
predictions[label] = float(prob)
# Debug prints (these will show in the console/logs)
print("Extracted Text:", extracted_text.strip())
print("Top Predictions:", predictions)
return predictions, extracted_text.strip()
except Exception as e:
print(f"Error in classification: {e}")
return {"Error": 1.0}, f"Error processing image: {str(e)}"
# Gradio interface
demo = gr.Interface(
fn=classify_meme,
inputs=gr.Image(type="pil", label="Upload Meme Image"),
outputs=[
gr.Label(num_top_classes=5, label="Meme Classification"),
gr.Textbox(label="Extracted Text from OCR", lines=3)
],
title="Meme Classifier with OCR",
description="""
Upload a meme image to:
1. Classify its content using your trained SigLIP2_77 model
2. Extract text using OCR (Optical Character Recognition)
Note: Make sure all model files are properly uploaded to your Space.
""",
examples=None,
allow_flagging="never"
)
if __name__ == "__main__":
print("Starting Gradio interface...")
demo.launch(
server_name="0.0.0.0", # Allow external connections in HF Spaces
server_port=7860, # Standard port for HF Spaces
share=False # HF Spaces handles sharing
) |