dalybuilds commited on
Commit
580daa1
·
verified ·
1 Parent(s): 117f15d

Create model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +127 -0
model_utils.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from transformers import ViTForImageClassification, AutoFeatureExtractor
4
+ import numpy as np
5
+ from PIL import Image
6
+ import cv2
7
+ from scipy.special import softmax
8
+
9
+ class BugClassifier:
10
+ def __init__(self):
11
+ # Initialize model and feature extractor
12
+ self.model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
13
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
14
+
15
+ # Define class labels (these would be replaced with your actual trained classes)
16
+ self.labels = [
17
+ "Ladybug", "Butterfly", "Ant", "Beetle", "Spider",
18
+ "Grasshopper", "Moth", "Dragonfly", "Bee", "Wasp"
19
+ ]
20
+
21
+ # Species information database
22
+ self.species_info = {
23
+ "Ladybug": """
24
+ Ladybugs are small, round beetles known for their distinctive spotted patterns.
25
+ They are beneficial insects that feed on plant-damaging pests like aphids.
26
+ Fun fact: The number of spots on a ladybug can indicate its species!
27
+ """,
28
+ "Butterfly": """
29
+ Butterflies are beautiful insects known for their large, colorful wings.
30
+ They play a crucial role in pollination and are indicators of ecosystem health.
31
+ They undergo complete metamorphosis from caterpillar to adult.
32
+ """,
33
+ # Add more species information as needed
34
+ }
35
+
36
+ def predict(self, image):
37
+ """
38
+ Make a prediction on the input image
39
+ Returns predicted class and confidence score
40
+ """
41
+ # Preprocess image
42
+ if isinstance(image, Image.Image):
43
+ image_tensor = self.preprocess_image(image)
44
+ else:
45
+ raise ValueError("Input must be a PIL Image")
46
+
47
+ # Make prediction
48
+ with torch.no_grad():
49
+ outputs = self.model(image_tensor)
50
+ probs = softmax(outputs.logits.numpy()[0])
51
+ pred_idx = np.argmax(probs)
52
+
53
+ return self.labels[pred_idx], float(probs[pred_idx] * 100)
54
+
55
+ def preprocess_image(self, image):
56
+ """
57
+ Preprocess image for model input
58
+ """
59
+ # Resize image if needed
60
+ if image.size != (224, 224):
61
+ image = image.resize((224, 224))
62
+
63
+ # Convert to tensor using feature extractor
64
+ inputs = self.feature_extractor(images=image, return_tensors="pt")
65
+ return inputs.pixel_values
66
+
67
+ def get_species_info(self, species):
68
+ """
69
+ Return information about a species
70
+ """
71
+ return self.species_info.get(species, "Information not available for this species.")
72
+
73
+ def compare_species(self, species1, species2):
74
+ """
75
+ Generate comparison information between two species
76
+ """
77
+ # This would be expanded with actual comparison logic
78
+ return f"""
79
+ **Comparing {species1} and {species2}:**
80
+
81
+ These species have different characteristics and roles in the ecosystem.
82
+ {self.get_species_info(species1)}
83
+
84
+ {self.get_species_info(species2)}
85
+ """
86
+
87
+ def generate_gradcam(image, model):
88
+ """
89
+ Generate Grad-CAM visualization for the image
90
+ """
91
+ # This is a simplified version - you would need to implement the actual Grad-CAM logic
92
+ # For now, we'll return a simple heatmap overlay
93
+ img_array = np.array(image)
94
+ heatmap = cv2.applyColorMap(
95
+ cv2.resize(np.random.rand(7,7) * 255, (224, 224)).astype(np.uint8),
96
+ cv2.COLORMAP_JET
97
+ )
98
+
99
+ # Overlay heatmap on original image
100
+ overlay = cv2.addWeighted(
101
+ cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR),
102
+ 0.7,
103
+ heatmap,
104
+ 0.3,
105
+ 0
106
+ )
107
+
108
+ return Image.fromarray(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
109
+
110
+ def get_severity_prediction(species):
111
+ """
112
+ Predict ecological severity/impact based on species
113
+ """
114
+ # This would be replaced with actual severity prediction logic
115
+ severity_map = {
116
+ "Ladybug": "Low",
117
+ "Butterfly": "Low",
118
+ "Ant": "Medium",
119
+ "Beetle": "Medium",
120
+ "Spider": "Low",
121
+ "Grasshopper": "Medium",
122
+ "Moth": "Low",
123
+ "Dragonfly": "Low",
124
+ "Bee": "Low",
125
+ "Wasp": "Medium"
126
+ }
127
+ return severity_map.get(species, "Medium")