import math

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from scipy.special import softmax
from transformers import AutoFeatureExtractor, ViTForImageClassification


class BugClassifier:
    """Classify insect photos with a ViT model and serve species trivia.

    Attributes:
        model: ``ViTForImageClassification`` loaded from
            ``google/vit-base-patch16-224`` with a 10-way head.
        feature_extractor: Matching HF feature extractor used to turn PIL
            images into ``pixel_values`` tensors.
        labels: Class names, indexed by the model's output logits.
        species_info: Descriptive text keyed by label (partial; missing
            labels fall back to a generic blurb in ``get_species_info``).
    """

    def __init__(self):
        """Load the model/feature extractor and build the label tables.

        Raises:
            RuntimeError: If loading the model or extractor fails; the
                original exception is chained as ``__cause__``.
        """
        try:
            # NOTE(review): with ignore_mismatched_sizes=True the 10-way head
            # does not match the checkpoint's head and is freshly initialised,
            # so predictions are not meaningful until fine-tuned weights are
            # loaded — confirm this is intended.
            self.model = ViTForImageClassification.from_pretrained(
                "google/vit-base-patch16-224",
                num_labels=10,
                ignore_mismatched_sizes=True,
            )
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(
                "google/vit-base-patch16-224"
            )

            # Inference only: disable dropout and similar training behaviour.
            self.model.eval()

            # Order must match the model's output logits.
            self.labels = [
                "Ladybug", "Butterfly", "Ant", "Beetle", "Spider",
                "Grasshopper", "Moth", "Dragonfly", "Bee", "Wasp",
            ]

            # TODO: add entries for the remaining labels; get_species_info
            # returns a generic description for any label missing here.
            self.species_info = {
                "Ladybug": (
                    "Ladybugs are small, round beetles known for their distinctive spotted patterns.\n"
                    "They are beneficial insects that feed on plant-damaging pests like aphids.\n"
                    "Fun fact: The number of spots on a ladybug can indicate its species!"
                ),
                "Butterfly": (
                    "Butterflies are beautiful insects known for their large, colorful wings.\n"
                    "They play a crucial role in pollination and are indicators of ecosystem health.\n"
                    "They undergo complete metamorphosis from caterpillar to adult."
                ),
                "Ant": (
                    "Ants are social insects that live in colonies.\n"
                    "They are incredibly strong for their size and play vital roles in soil health "
                    "and ecosystem maintenance."
                ),
            }
        except Exception as e:
            raise RuntimeError(f"Error initializing BugClassifier: {str(e)}") from e

    def preprocess_image(self, image):
        """Convert a PIL image into the model's ``pixel_values`` tensor.

        Args:
            image: A ``PIL.Image.Image`` in any mode.

        Returns:
            The ``pixel_values`` tensor produced by the feature extractor.

        Raises:
            ValueError: If conversion or feature extraction fails; the
                original exception is chained as ``__cause__``.
        """
        try:
            # The ViT backbone takes 3-channel input, so every non-RGB mode
            # (RGBA, L, P, CMYK, ...) is converted, not only RGBA.
            if image.mode != "RGB":
                image = image.convert("RGB")

            if image.size != (224, 224):
                image = image.resize((224, 224), Image.Resampling.LANCZOS)

            inputs = self.feature_extractor(images=image, return_tensors="pt")
            return inputs.pixel_values
        except Exception as e:
            raise ValueError(f"Error preprocessing image: {str(e)}") from e

    def predict(self, image):
        """Classify ``image`` and report the top label with its confidence.

        Best-effort: any failure (including a non-PIL input) is printed and
        ``(self.labels[0], 0.0)`` is returned instead of raising.

        Args:
            image: A ``PIL.Image.Image``.

        Returns:
            ``(label, confidence)`` where ``confidence`` is the softmax
            probability of that label expressed as a percentage.
        """
        try:
            if not isinstance(image, Image.Image):
                raise ValueError("Input must be a PIL Image")

            image_tensor = self.preprocess_image(image)

            with torch.no_grad():
                outputs = self.model(image_tensor)
                probs = F.softmax(outputs.logits, dim=-1).numpy()[0]

            pred_idx = int(np.argmax(probs))
            # Defensive: never index past the label list.
            if pred_idx >= len(self.labels):
                pred_idx = 0

            return self.labels[pred_idx], float(probs[pred_idx] * 100)
        except Exception as e:
            print(f"Prediction error: {str(e)}")
            return self.labels[0], 0.0

    def get_species_info(self, species):
        """Return the stored description of ``species``, or a generic one."""
        default_info = (
            f"Information about {species}:\n"
            "This species is part of our insect database. While detailed information "
            "is still being compiled, all insects play important roles in their ecosystems."
        )
        return self.species_info.get(species, default_info)

    def compare_species(self, species1, species2):
        """Build a Markdown comparison of two species' descriptions.

        Sections are joined with blank lines and carry no leading
        indentation, which Markdown would otherwise render as a code block.
        """
        info1 = self.get_species_info(species1)
        info2 = self.get_species_info(species2)
        sections = [
            f"**Comparing {species1} and {species2}:**",
            f"{species1}:\n{info1}",
            f"{species2}:\n{info2}",
            "Both species contribute to their ecosystems in unique ways.",
        ]
        return "\n\n".join(sections)

    def get_gradcam(self, image):
        """Overlay an attention heatmap on a 224x224 copy of ``image``.

        Despite the name this is not gradient-based Grad-CAM. It averages
        the last encoder layer's self-attention over heads and query tokens.
        That gives one score per key token. The [CLS] token is dropped, the
        patch scores are laid out on their square grid, upsampled to
        224x224, colour-mapped and blended over the image.

        Args:
            image: A ``PIL.Image.Image`` in any mode.

        Returns:
            A 224x224 RGB ``PIL.Image`` with the heatmap blended in, or the
            original ``image`` unchanged if anything fails (best-effort).
        """
        try:
            image_tensor = self.preprocess_image(image)

            with torch.no_grad():
                outputs = self.model(image_tensor, output_attentions=True)
            if not outputs.attentions:
                raise RuntimeError("model returned no attention weights")

            # Last layer; per HF docs shaped (batch, heads, tokens, tokens).
            attention = outputs.attentions[-1]

            # Mean over heads, then over query tokens -> (batch, tokens).
            # Token 0 is [CLS]; the rest are image patches in row-major order.
            patch_scores = attention.mean(dim=1).mean(dim=1)[0, 1:].numpy()
            num_patches = patch_scores.shape[0]
            grid = math.isqrt(num_patches)
            if grid * grid != num_patches:
                raise ValueError(
                    f"cannot arrange {num_patches} patch tokens on a square grid"
                )
            attention_map = patch_scores.reshape(grid, grid).astype(np.float32)

            attention_map = cv2.resize(
                attention_map, (224, 224), interpolation=cv2.INTER_LINEAR
            )

            # Min-max normalise to [0, 1]; a flat map would divide by zero.
            lo = attention_map.min()
            span = attention_map.max() - lo
            if span > 0:
                attention_map = (attention_map - lo) / span
            else:
                attention_map = np.zeros_like(attention_map)

            # applyColorMap produces BGR; convert so it blends with RGB pixels.
            heatmap = cv2.applyColorMap(np.uint8(255 * attention_map), cv2.COLORMAP_JET)
            heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

            # Force 3-channel RGB so grayscale/RGBA inputs blend correctly.
            original_image = np.array(image.convert("RGB").resize((224, 224)))

            overlay = cv2.addWeighted(original_image, 0.7, heatmap, 0.3, 0)
            return Image.fromarray(overlay)
        except Exception as e:
            print(f"Error generating Grad-CAM: {str(e)}")
            return image


def get_severity_prediction(species):
    """Return the ecological severity/impact level for ``species``.

    Args:
        species: A label such as those in ``BugClassifier.labels``.

    Returns:
        ``"Low"`` or ``"Medium"`` from a fixed table; unknown species
        default to ``"Medium"``.
    """
    severity_map = {
        "Ladybug": "Low",
        "Butterfly": "Low",
        "Ant": "Medium",
        "Beetle": "Medium",
        "Spider": "Low",
        "Grasshopper": "Medium",
        "Moth": "Low",
        "Dragonfly": "Low",
        "Bee": "Low",
        "Wasp": "Medium",
    }
    return severity_map.get(species, "Medium")