# Thesis_CLIP / app.py
# (Page header captured along with the source: author avatar "Chanlefe";
#  last commit 23ad95f "Update app.py" (verified); file size 8.48 kB.)
import torch
import os
import glob
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageClassification
import gradio as gr
import pytesseract
def find_model_files():
    """Find a local Hugging Face model directory under the working directory.

    Recursively searches (relative to the current working directory) for
    ``config.json``, model weight files (single-file or sharded-checkpoint
    index files) and ``preprocessor_config.json``. A directory qualifies when
    it holds both a ``config.json`` and weights.

    Selection order (candidates are ordered shallowest first, then by path,
    so the result does not depend on filesystem listing order):
      1. the first qualifying directory that also has a preprocessor config;
      2. otherwise the first qualifying directory, even without one.

    Returns:
        The directory path relative to the working directory (``"."`` for the
        root itself), or ``None`` when no directory has both config and weights.
    """
    print("=== Searching for model files ===")

    def _find(filename):
        # Deterministic order: shallower paths first, then alphabetical.
        matches = glob.glob(f"**/{filename}", recursive=True)
        return sorted(matches, key=lambda path: (path.count(os.sep), path))

    def _parent(path):
        # glob returns bare names for files in the root; map those to ".".
        return os.path.dirname(path) or "."

    weight_filenames = (
        "pytorch_model.bin",
        "model.safetensors",
        # Index files written for sharded (multi-file) checkpoints.
        "pytorch_model.bin.index.json",
        "model.safetensors.index.json",
    )

    config_files = _find("config.json")
    model_files = [path for name in weight_filenames for path in _find(name)]
    preprocessor_files = _find("preprocessor_config.json")
    print(f"Found config.json files: {config_files}")
    print(f"Found model weight files: {model_files}")
    print(f"Found preprocessor_config.json files: {preprocessor_files}")

    weight_dirs = {_parent(path) for path in model_files}
    preprocessor_dirs = {_parent(path) for path in preprocessor_files}
    # Directories that have both a config and weights, in search order.
    candidates = [d for d in map(_parent, config_files) if d in weight_dirs]

    # Prefer a complete model anywhere over an incomplete one found earlier.
    for model_dir in candidates:
        if model_dir in preprocessor_dirs:
            print(f"Found complete model in directory: {model_dir}")
            return model_dir

    if candidates:
        model_dir = candidates[0]
        print(f"Found model with config but missing preprocessor in: {model_dir}")
        return model_dir  # Try anyway; a processor may be obtainable elsewhere.

    print("No complete model directory found")
    return None
# Resolve where the checkpoint lives; when the search finds nothing usable,
# fall back to trying the working directory itself.
_detected_model_dir = find_model_files()
MODEL_PATH = "." if _detected_model_dir is None else _detected_model_dir
if _detected_model_dir is None:
    print("Falling back to current directory")
del _detected_model_dir
def _print_model_dir_contents(model_path):
    """Print the contents of *model_path* to help diagnose missing files.

    For the root directory, each entry is listed, and for non-file entries up
    to five of their children are shown. Unreadable entries are skipped.
    """
    if model_path != ".":
        print(f"Files in {model_path}:")
        print(f" {os.listdir(model_path)}")
        return
    print("Files in root directory:")
    for item in os.listdir("."):
        if os.path.isfile(item):
            print(f" File: {item}")
            continue
        print(f" Directory: {item}/")
        try:
            sub_files = os.listdir(item)  # list once, reuse for count and preview
        except OSError:
            # Unreadable, vanished or dangling entry: best-effort listing, skip it.
            continue
        print(f" Contains: {sub_files[:5]}{'...' if len(sub_files) > 5 else ''}")


def _load_processor(model_path):
    """Load an image processor for the model, trying progressively weaker sources.

    Order: ``AutoProcessor`` from *model_path*; ``AutoImageProcessor`` from
    *model_path* (it needs only ``preprocessor_config.json``, no tokenizer
    files); the image processor of ``google/siglip-base-patch16-224`` from the
    Hub; finally the CLIP image processor of ``openai/clip-vit-base-patch32``.
    Exceptions from the last resort propagate to the caller.
    """
    try:
        processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
        print("Processor loaded successfully!")
        return processor
    except Exception as proc_error:
        print(f"Error loading processor from local files: {proc_error}")

    from transformers import AutoImageProcessor

    # The classifier only consumes images, so a local image-only processor is
    # the closest match to the training-time preprocessing.
    print("Attempting to load image processor from local files...")
    try:
        processor = AutoImageProcessor.from_pretrained(model_path, local_files_only=True)
        print("Image processor loaded from local files!")
        return processor
    except Exception as image_proc_error:
        print(f"Error loading local image processor: {image_proc_error}")

    # NOTE: the remaining fallbacks come from other checkpoints; if their
    # resize/normalisation settings differ from what the model was trained
    # with, predictions can silently degrade.
    print("Attempting to load processor from base SigLIP model...")
    try:
        processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
        print("Processor loaded from base SigLIP model!")
        return processor
    except Exception as base_error:
        print(f"Error loading base processor: {base_error}")

    print("Trying alternative processor...")
    from transformers import CLIPImageProcessor

    processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
    print("Using CLIP processor as fallback!")
    return processor


def _resolve_labels(config):
    """Return the ``{class_index: label_name}`` mapping for the model outputs.

    Uses ``config.id2label`` when present and non-empty; otherwise synthesises
    ``class_<i>`` names for ``config.num_labels`` classes (1000 if unset).
    """
    id2label = getattr(config, "id2label", None)
    if id2label:
        return id2label
    num_labels = getattr(config, "num_labels", 1000)
    return {i: f"class_{i}" for i in range(num_labels)}


def list_all_files(directory=".", prefix=""):
    """Recursively print files (with sizes) and non-hidden subdirectories.

    Recursion stops once *prefix* reaches 6 characters. Errors are printed
    rather than raised, because this runs while reporting another failure.
    """
    try:
        items = sorted(os.listdir(directory))
        for item in items:
            item_path = os.path.join(directory, item)
            if os.path.isfile(item_path):
                size = os.path.getsize(item_path)
                print(f"{prefix}📄 {item} ({size} bytes)")
            elif os.path.isdir(item_path) and not item.startswith('.'):
                print(f"{prefix}📁 {item}/")
                if len(prefix) < 6:  # Limit recursion depth
                    list_all_files(item_path, prefix + " ")
    except PermissionError:
        print(f"{prefix}❌ Permission denied")
    except Exception as ex:
        # Best-effort diagnostics: never let the listing mask the original error.
        print(f"{prefix}❌ Error: {ex}")


def _print_load_failure_help():
    """Print the files a model directory needs and common remedies."""
    print("\n=== Required Files for Model ===")
    print("✅ config.json - Model configuration")
    print("✅ pytorch_model.bin OR model.safetensors - Model weights")
    print("✅ preprocessor_config.json - Image processor config")
    print("✅ tokenizer.json (if applicable) - Tokenizer")
    print("\n=== Solutions ===")
    print("1. Make sure all model files are uploaded to your Space")
    print("2. Check that files aren't corrupted during upload")
    print("3. Try uploading to a 'model' subfolder")
    print("4. Verify the model was saved correctly during training")


# Load model, processor and label names at import time; on failure, dump
# diagnostics and re-raise so the app does not start half-initialised.
try:
    print(f"=== Attempting to load model from: {MODEL_PATH} ===")
    print(f"Current working directory: {os.getcwd()}")
    _print_model_dir_contents(MODEL_PATH)

    print("Loading model...")
    model = AutoModelForImageClassification.from_pretrained(MODEL_PATH, local_files_only=True)
    print("Model loaded successfully!")

    print("Loading processor...")
    processor = _load_processor(MODEL_PATH)

    labels = _resolve_labels(model.config)
    print(f"Model loaded successfully. Number of classes: {len(labels)}")
except Exception as e:
    print(f"=== ERROR loading model from {MODEL_PATH} ===")
    print(f"Error: {e}")
    print("\n=== Debugging Information ===")
    print("All files in Space:")
    list_all_files()
    _print_load_failure_help()
    raise
# Classify meme and extract text
def classify_meme(image: Image.Image):
    """Classify a meme image with the loaded model and OCR its text.

    Args:
        image: PIL image from the Gradio ``Image`` input, or ``None`` when the
            form is submitted without an upload.

    Returns:
        ``(predictions, text)``. ``predictions`` maps up to 10 top-scoring label
        names to scores (suitable for ``gr.Label``). ``text`` is the stripped
        OCR output. On any failure, returns ``({"Error": 1.0}, message)``
        instead of raising, so the problem is shown in the UI.
    """
    if image is None:
        return {"Error": 1.0}, "No image provided. Please upload a meme image."
    try:
        # OCR: extract text from the image as uploaded.
        extracted_text = pytesseract.image_to_string(image).strip()

        # Memes are often RGBA/palette PNGs or grayscale. Not every image
        # processor converts to RGB itself, and 3-channel normalisation fails
        # on other channel counts, so feed the model an RGB copy.
        model_image = image if image.mode == "RGB" else image.convert("RGB")
        inputs = processor(images=model_image, return_tensors="pt")
        # Put tensors wherever the model's weights live (CPU or GPU).
        device = next(model.parameters()).device
        inputs = {
            name: value.to(device) if torch.is_tensor(value) else value
            for name, value in inputs.items()
        }

        with torch.no_grad():
            logits = model(**inputs).logits
        if getattr(model.config, "problem_type", None) == "multi_label_classification":
            # Multi-label heads score each class independently.
            probs = torch.sigmoid(logits)
        else:
            probs = torch.nn.functional.softmax(logits, dim=-1)

        # Bound by the real number of logits; a synthesised label map may differ.
        top_k = min(10, probs.shape[-1])
        top_probs, top_indices = torch.topk(probs[0], top_k)
        predictions = {
            labels.get(idx, f"class_{idx}"): prob
            for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
        }

        # Debug prints (these will show in the console/logs)
        print("Extracted Text:", extracted_text)
        print("Top Predictions:", predictions)
        return predictions, extracted_text
    except Exception as e:
        # UI boundary: report the failure in the output widgets instead of crashing.
        print(f"Error in classification: {e}")
        return {"Error": 1.0}, f"Error processing image: {str(e)}"
# Gradio UI: one uploaded image goes in; the classifier's label scores and the
# OCR text come out. Components are created in the same order as the
# Interface arguments (input first, then outputs).
_meme_input = gr.Image(type="pil", label="Upload Meme Image")
_meme_outputs = [
    gr.Label(num_top_classes=5, label="Meme Classification"),
    gr.Textbox(label="Extracted Text from OCR", lines=3),
]
_MEME_DESCRIPTION = """
Upload a meme image to:
1. Classify its content using your trained SigLIP2_77 model
2. Extract text using OCR (Optical Character Recognition)
Note: Make sure all model files are properly uploaded to your Space.
"""

demo = gr.Interface(
    fn=classify_meme,
    inputs=_meme_input,
    outputs=_meme_outputs,
    title="Meme Classifier with OCR",
    description=_MEME_DESCRIPTION,
    examples=None,
    allow_flagging="never",
)
if __name__ == "__main__":
    print("Starting Gradio interface...")
    # Hugging Face Spaces routes traffic to 0.0.0.0:7860 and serves the public
    # URL itself, so Gradio's own share tunnel stays disabled.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)