dalybuilds commited on
Commit
f88bb49
·
verified ·
1 Parent(s): 0398bad

Update model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +124 -66
model_utils.py CHANGED
@@ -8,50 +8,73 @@ from scipy.special import softmax
8
 
9
  class BugClassifier:
10
  def __init__(self):
11
- # Initialize model and feature extractor
12
- self.model = ViTForImageClassification.from_pretrained(
13
- "google/vit-base-patch16-224",
14
- num_labels=10, # Match number of classes
15
- ignore_mismatched_sizes=True # Add this to handle size mismatch
16
- )
17
- self.feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
18
-
19
- # Define class labels
20
- self.labels = [
21
- "Ladybug", "Butterfly", "Ant", "Beetle", "Spider",
22
- "Grasshopper", "Moth", "Dragonfly", "Bee", "Wasp"
23
- ]
24
-
25
- # Species information database
26
- self.species_info = {
27
- "Ladybug": """
28
- Ladybugs are small, round beetles known for their distinctive spotted patterns.
29
- They are beneficial insects that feed on plant-damaging pests like aphids.
30
- Fun fact: The number of spots on a ladybug can indicate its species!
31
- """,
32
- "Butterfly": """
33
- Butterflies are beautiful insects known for their large, colorful wings.
34
- They play a crucial role in pollination and are indicators of ecosystem health.
35
- They undergo complete metamorphosis from caterpillar to adult.
36
- """,
37
- # Add more species information as needed
38
- }
39
-
40
- # Set model to evaluation mode
41
- self.model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  def predict(self, image):
44
- """
45
- Make a prediction on the input image
46
- Returns predicted class and confidence score
47
- """
48
  try:
49
- # Preprocess image
50
- if isinstance(image, Image.Image):
51
- image_tensor = self.preprocess_image(image)
52
- else:
53
  raise ValueError("Input must be a PIL Image")
54
 
 
 
 
55
  # Make prediction
56
  with torch.no_grad():
57
  outputs = self.model(image_tensor)
@@ -60,43 +83,25 @@ class BugClassifier:
60
 
61
  # Ensure index is within bounds
62
  if pred_idx >= len(self.labels):
63
- pred_idx = 0 # Default to first class if out of bounds
64
 
65
  return self.labels[pred_idx], float(probs[pred_idx] * 100)
66
- except Exception as e:
67
- print(f"Prediction error: {str(e)}")
68
- return self.labels[0], 0.0 # Return default prediction in case of error
69
-
70
- def preprocess_image(self, image):
71
- """
72
- Preprocess image for model input
73
- """
74
- try:
75
- # Convert RGBA to RGB if necessary
76
- if image.mode == 'RGBA':
77
- image = image.convert('RGB')
78
 
79
- # Process image using feature extractor
80
- inputs = self.feature_extractor(images=image, return_tensors="pt")
81
- return inputs.pixel_values
82
  except Exception as e:
83
- print(f"Preprocessing error: {str(e)}")
84
- raise
85
 
86
  def get_species_info(self, species):
87
- """
88
- Return information about a species
89
- """
90
- return self.species_info.get(species, f"""
91
  Information about {species}:
92
  This species is part of our insect database. While detailed information
93
  is still being compiled, all insects play important roles in their ecosystems.
94
- """)
 
95
 
96
  def compare_species(self, species1, species2):
97
- """
98
- Generate comparison information between two species
99
- """
100
  info1 = self.get_species_info(species1)
101
  info2 = self.get_species_info(species2)
102
 
@@ -110,4 +115,57 @@ class BugClassifier:
110
  {info2}
111
 
112
  Both species contribute to their ecosystems in unique ways.
113
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  class BugClassifier:
10
  def __init__(self):
11
+ try:
12
+ # Initialize model and feature extractor
13
+ self.model = ViTForImageClassification.from_pretrained(
14
+ "google/vit-base-patch16-224",
15
+ num_labels=10,
16
+ ignore_mismatched_sizes=True
17
+ )
18
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
19
+
20
+ # Set model to evaluation mode
21
+ self.model.eval()
22
+
23
+ # Define class labels
24
+ self.labels = [
25
+ "Ladybug", "Butterfly", "Ant", "Beetle", "Spider",
26
+ "Grasshopper", "Moth", "Dragonfly", "Bee", "Wasp"
27
+ ]
28
+
29
+ # Species information database
30
+ self.species_info = {
31
+ "Ladybug": """
32
+ Ladybugs are small, round beetles known for their distinctive spotted patterns.
33
+ They are beneficial insects that feed on plant-damaging pests like aphids.
34
+ Fun fact: The number of spots on a ladybug can indicate its species!
35
+ """,
36
+ "Butterfly": """
37
+ Butterflies are beautiful insects known for their large, colorful wings.
38
+ They play a crucial role in pollination and are indicators of ecosystem health.
39
+ They undergo complete metamorphosis from caterpillar to adult.
40
+ """,
41
+ "Ant": """
42
+ Ants are social insects that live in colonies. They are incredibly strong
43
+ for their size and play vital roles in soil health and ecosystem maintenance.
44
+ """,
45
+ # Add more species information for other classes...
46
+ }
47
+
48
+ except Exception as e:
49
+ raise RuntimeError(f"Error initializing BugClassifier: {str(e)}")
50
+
51
+ def preprocess_image(self, image):
52
+ """Preprocess image for model input"""
53
+ try:
54
+ # Convert RGBA to RGB if necessary
55
+ if image.mode == 'RGBA':
56
+ image = image.convert('RGB')
57
+
58
+ # Resize image if needed
59
+ if image.size != (224, 224):
60
+ image = image.resize((224, 224), Image.Resampling.LANCZOS)
61
+
62
+ # Process image using feature extractor
63
+ inputs = self.feature_extractor(images=image, return_tensors="pt")
64
+ return inputs.pixel_values
65
+
66
+ except Exception as e:
67
+ raise ValueError(f"Error preprocessing image: {str(e)}")
68
 
69
  def predict(self, image):
70
+ """Make a prediction on the input image"""
 
 
 
71
  try:
72
+ if not isinstance(image, Image.Image):
 
 
 
73
  raise ValueError("Input must be a PIL Image")
74
 
75
+ # Preprocess image
76
+ image_tensor = self.preprocess_image(image)
77
+
78
  # Make prediction
79
  with torch.no_grad():
80
  outputs = self.model(image_tensor)
 
83
 
84
  # Ensure index is within bounds
85
  if pred_idx >= len(self.labels):
86
+ pred_idx = 0
87
 
88
  return self.labels[pred_idx], float(probs[pred_idx] * 100)
 
 
 
 
 
 
 
 
 
 
 
 
89
 
 
 
 
90
  except Exception as e:
91
+ print(f"Prediction error: {str(e)}")
92
+ return self.labels[0], 0.0
93
 
94
  def get_species_info(self, species):
95
+ """Return information about a species"""
96
+ default_info = f"""
 
 
97
  Information about {species}:
98
  This species is part of our insect database. While detailed information
99
  is still being compiled, all insects play important roles in their ecosystems.
100
+ """
101
+ return self.species_info.get(species, default_info)
102
 
103
  def compare_species(self, species1, species2):
104
+ """Generate comparison information between two species"""
 
 
105
  info1 = self.get_species_info(species1)
106
  info2 = self.get_species_info(species2)
107
 
 
115
  {info2}
116
 
117
  Both species contribute to their ecosystems in unique ways.
118
+ """
119
+
120
+ def get_gradcam(self, image):
121
+ """Generate Grad-CAM visualization for the image"""
122
+ try:
123
+ # Preprocess image
124
+ image_tensor = self.preprocess_image(image)
125
+
126
+ # Get model attention weights (using last layer's attention)
127
+ with torch.no_grad():
128
+ outputs = self.model(image_tensor, output_attentions=True)
129
+ attention = outputs.attentions[-1] # Get last layer's attention
130
+
131
+ # Convert attention to heatmap
132
+ attention_map = attention.mean(dim=1).mean(dim=1).numpy()[0]
133
+
134
+ # Resize attention map to image size
135
+ attention_map = cv2.resize(attention_map, (224, 224))
136
+
137
+ # Normalize attention map
138
+ attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min())
139
+
140
+ # Convert to heatmap
141
+ heatmap = cv2.applyColorMap(np.uint8(255 * attention_map), cv2.COLORMAP_JET)
142
+
143
+ # Convert original image to RGB numpy array
144
+ original_image = np.array(image.resize((224, 224)))
145
+ if len(original_image.shape) == 2: # Convert grayscale to RGB
146
+ original_image = cv2.cvtColor(original_image, cv2.COLOR_GRAY2RGB)
147
+
148
+ # Overlay heatmap on original image
149
+ overlay = cv2.addWeighted(original_image, 0.7, heatmap, 0.3, 0)
150
+
151
+ return Image.fromarray(overlay)
152
+
153
+ except Exception as e:
154
+ print(f"Error generating Grad-CAM: {str(e)}")
155
+ return image # Return original image if Grad-CAM fails
156
+
157
+ def get_severity_prediction(species):
158
+ """Predict ecological severity/impact based on species"""
159
+ severity_map = {
160
+ "Ladybug": "Low",
161
+ "Butterfly": "Low",
162
+ "Ant": "Medium",
163
+ "Beetle": "Medium",
164
+ "Spider": "Low",
165
+ "Grasshopper": "Medium",
166
+ "Moth": "Low",
167
+ "Dragonfly": "Low",
168
+ "Bee": "Low",
169
+ "Wasp": "Medium"
170
+ }
171
+ return severity_map.get(species, "Medium")