dalybuilds commited on
Commit
02e998a
Β·
verified Β·
1 Parent(s): e586365

Update model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +162 -144
model_utils.py CHANGED
@@ -1,158 +1,176 @@
1
- import torch
2
- from transformers import ViTImageProcessor, ViTForImageClassification
3
- import numpy as np
4
  from PIL import Image
5
- import cv2
6
 
7
- class BugClassifier:
8
- def __init__(self):
9
- """Initialize the bug classifier with a simple ViT model"""
10
- try:
11
- # Load the model and processor
12
- model_name = "google/vit-base-patch16-224"
13
- self.processor = ViTImageProcessor.from_pretrained(model_name)
14
- self.model = ViTForImageClassification.from_pretrained(
15
- model_name,
16
- num_labels=10,
17
- ignore_mismatched_sizes=True
18
- )
19
-
20
- # Set model to evaluation mode
21
- self.model.eval()
22
-
23
- # Define class labels
24
- self.labels = [
25
- "Seven-spotted Ladybug", "Monarch Butterfly", "Carpenter Ant",
26
- "Japanese Beetle", "Garden Spider", "Green Grasshopper",
27
- "Luna Moth", "Common Dragonfly", "Honey Bee", "Paper Wasp"
28
- ]
29
-
30
- # Species information
31
- self.species_info = {
32
- "Seven-spotted Ladybug": """
33
- The Seven-spotted Ladybug (Coccinella septempunctata) is a beneficial garden insect.
34
- Key characteristics:
35
- - Red wing covers with seven black spots
36
- - Natural pest controller, eating aphids and other small insects
37
- - Typically 7-8mm in length
38
- - Can eat up to 5,000 aphids in their lifetime
39
- """,
40
- "Monarch Butterfly": """
41
- The Monarch Butterfly (Danaus plexippus) is known for its migration patterns.
42
- Key characteristics:
43
- - Orange wings with black veins and white spots
44
- - Wingspan of 93-105mm
45
- - Feeds on milkweed as caterpillars
46
- - Makes annual migrations of up to 3,000 miles
47
- """,
48
- # Add more species info as needed
49
- }
50
-
51
- print("Model initialized successfully")
52
-
53
- except Exception as e:
54
- print(f"Error initializing model: {str(e)}")
55
- raise RuntimeError(f"Failed to initialize BugClassifier: {str(e)}")
56
 
57
- def preprocess_image(self, image):
58
- """Preprocess the image for model input"""
59
- try:
60
- # Ensure image is in RGB format
61
- if image.mode != 'RGB':
62
- image = image.convert('RGB')
63
-
64
- # Process image using the ViT processor
65
- inputs = self.processor(images=image, return_tensors="pt")
66
- return inputs.pixel_values
67
-
68
- except Exception as e:
69
- print(f"Error preprocessing image: {str(e)}")
70
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- def predict(self, image):
73
- """Predict the insect species from an image"""
 
 
 
 
 
 
 
 
 
74
  try:
75
- print("Starting prediction...")
 
 
76
 
77
- # Preprocess image
78
- inputs = self.preprocess_image(image)
79
- print("Image preprocessed successfully")
80
 
81
- # Make prediction
82
- with torch.no_grad():
83
- outputs = self.model(inputs)
84
- probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
85
- confidence, predicted = torch.max(probabilities, 1)
86
-
87
- confidence_score = confidence.item() * 100
88
- predicted_label = self.labels[predicted.item()]
89
-
90
- print(f"Prediction complete: {predicted_label} with confidence {confidence_score}%")
91
-
92
- # Return Unknown if confidence is too low
93
- if confidence_score < 40:
94
- return "Unknown Insect", confidence_score
95
-
96
- return predicted_label, confidence_score
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
 
 
 
 
 
98
  except Exception as e:
99
- print(f"Prediction error: {str(e)}")
100
- return "Error Processing Image", 0.0
101
 
102
- def get_species_info(self, species):
103
- """Get information about a species"""
104
- default_info = f"""
105
- This appears to be a {species}.
106
- While we don't have detailed information about this specific species,
107
- all insects play important roles in their ecosystems.
108
- Consider taking another photo from a different angle or in better lighting
109
- for more accurate identification.
110
- """
111
- return self.species_info.get(species, default_info)
112
-
113
- def get_gradcam(self, image):
114
- """Generate attention visualization"""
115
- try:
116
- # Basic attention visualization
117
- inputs = self.preprocess_image(image)
118
 
119
- with torch.no_grad():
120
- outputs = self.model(inputs, output_attentions=True)
121
- # Get attention from last layer
122
- attention = outputs.attentions[-1].mean(dim=1).mean(dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
- # Process attention map
125
- attention_map = attention.numpy()[0]
126
- attention_map = cv2.resize(attention_map, (224, 224))
127
- attention_map = np.uint8(255 * attention_map)
128
- heatmap = cv2.applyColorMap(attention_map, cv2.COLORMAP_JET)
129
-
130
- # Prepare original image
131
- img_array = np.array(image.resize((224, 224)))
132
-
133
- # Combine heatmap with original image
134
- output = cv2.addWeighted(img_array, 0.7, heatmap, 0.3, 0)
135
-
136
- return Image.fromarray(output)
137
-
138
  except Exception as e:
139
- print(f"Visualization error: {str(e)}")
140
- return image
141
 
142
- def get_severity_prediction(species):
143
- """Get ecological severity prediction"""
144
- severity_map = {
145
- "Seven-spotted Ladybug": "Low",
146
- "Monarch Butterfly": "Low",
147
- "Carpenter Ant": "Medium",
148
- "Japanese Beetle": "High",
149
- "Garden Spider": "Low",
150
- "Green Grasshopper": "Medium",
151
- "Luna Moth": "Low",
152
- "Common Dragonfly": "Low",
153
- "Honey Bee": "Low",
154
- "Paper Wasp": "Medium",
155
- "Unknown Insect": "Unknown",
156
- "Error Processing Image": "Unknown"
157
- }
158
- return severity_map.get(species, "Unknown")
 
1
+ import streamlit as st
 
 
2
  from PIL import Image
3
+ from model_utils import BugClassifier, get_severity_prediction
4
 
5
+ # Page configuration
6
+ st.set_page_config(
7
+ page_title="Bug-O-Scope πŸ”πŸž",
8
+ page_icon="πŸ”",
9
+ layout="wide"
10
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ # Initialize session state for model
13
+ @st.cache_resource
14
+ def load_model():
15
+ try:
16
+ print("Loading model...")
17
+ model = BugClassifier()
18
+ print("Model loaded successfully")
19
+ return model
20
+ except Exception as e:
21
+ print(f"Error loading model: {str(e)}")
22
+ return None
23
+
24
+ # Ensure model is loaded
25
+ if 'model' not in st.session_state:
26
+ st.session_state.model = load_model()
27
+
28
+ def main():
29
+ # Header
30
+ st.title("Bug-O-Scope πŸ”πŸž")
31
+ st.markdown("""
32
+ Welcome to Bug-O-Scope! Upload a picture of an insect to learn more about it.
33
+ This educational tool helps you identify bugs and understand their role in our ecosystem.
34
+ """)
35
+
36
+ # Sidebar
37
+ st.sidebar.header("About Bug-O-Scope")
38
+ st.sidebar.markdown("""
39
+ Bug-O-Scope is an AI-powered tool that helps you:
40
+ * πŸ” Identify insects from photos
41
+ * πŸ“š Learn about different species
42
+ * 🌍 Understand their ecological impact
43
+ * πŸ”¬ Compare different insects
44
+ """)
45
+
46
+ # Check if model loaded successfully
47
+ if st.session_state.model is None:
48
+ st.error("Error: Model failed to load. Please try refreshing the page.")
49
+ return
50
+
51
+ # Main content
52
+ tab1, tab2 = st.tabs(["Single Bug Analysis", "Bug Comparison"])
53
 
54
+ with tab1:
55
+ single_bug_analysis()
56
+
57
+ with tab2:
58
+ compare_bugs()
59
+
60
+ def single_bug_analysis():
61
+ """Handle single bug analysis"""
62
+ uploaded_file = st.file_uploader("Upload a bug photo", type=['png', 'jpg', 'jpeg'], key="single")
63
+
64
+ if uploaded_file:
65
  try:
66
+ # Load and display image
67
+ image = Image.open(uploaded_file)
68
+ col1, col2 = st.columns(2)
69
 
70
+ with col1:
71
+ st.image(image, caption="Uploaded Image", use_container_width=True)
 
72
 
73
+ with col2:
74
+ with st.spinner("Analyzing your bug..."):
75
+ # Get predictions
76
+ prediction, confidence = st.session_state.model.predict(image)
77
+ print(f"Prediction: {prediction}, Confidence: {confidence}")
78
+
79
+ st.success("Analysis Complete!")
80
+ st.markdown("### Identified Species")
81
+ st.markdown(f"**{prediction}**")
82
+ st.markdown(f"Confidence: {confidence:.2f}%")
83
+
84
+ # Only show ecological impact for known insects
85
+ if prediction != "Unknown Insect" and prediction != "Error Processing Image":
86
+ severity = get_severity_prediction(prediction)
87
+ st.markdown("### Ecological Impact")
88
+ severity_color = {
89
+ "Low": "green",
90
+ "Medium": "orange",
91
+ "High": "red",
92
+ "Unknown": "gray"
93
+ }
94
+ st.markdown(
95
+ f"Severity: <span style='color: {severity_color[severity]}'>{severity}</span>",
96
+ unsafe_allow_html=True
97
+ )
98
+
99
+ # Display species information
100
+ if prediction != "Unknown Insect" and prediction != "Error Processing Image":
101
+ st.markdown("### About This Species")
102
+ species_info = st.session_state.model.get_species_info(prediction)
103
+ st.markdown(species_info)
104
 
105
+ # Display visualization
106
+ st.markdown("### Feature Highlights")
107
+ gradcam = st.session_state.model.get_gradcam(image)
108
+ st.image(gradcam, caption="Important Features", use_container_width=True)
109
+
110
  except Exception as e:
111
+ st.error(f"Error processing image: {str(e)}")
112
+ st.info("Please try uploading a different image.")
113
 
114
+ def compare_bugs():
115
+ """Handle bug comparison"""
116
+ col1, col2 = st.columns(2)
117
+
118
+ with col1:
119
+ file1 = st.file_uploader("Upload first bug photo", type=['png', 'jpg', 'jpeg'], key="compare1")
120
+ if file1:
121
+ try:
122
+ image1 = Image.open(file1)
123
+ st.image(image1, caption="First Bug", use_container_width=True)
124
+ except Exception as e:
125
+ st.error(f"Error loading first image: {str(e)}")
126
+ return
 
 
 
127
 
128
+ with col2:
129
+ file2 = st.file_uploader("Upload second bug photo", type=['png', 'jpg', 'jpeg'], key="compare2")
130
+ if file2:
131
+ try:
132
+ image2 = Image.open(file2)
133
+ st.image(image2, caption="Second Bug", use_container_width=True)
134
+ except Exception as e:
135
+ st.error(f"Error loading second image: {str(e)}")
136
+ return
137
+
138
+ if file1 and file2:
139
+ try:
140
+ with st.spinner("Generating comparison..."):
141
+ # Get predictions
142
+ pred1, conf1 = st.session_state.model.predict(image1)
143
+ pred2, conf2 = st.session_state.model.predict(image2)
144
+
145
+ if pred1 not in ["Unknown Insect", "Error Processing Image"] and \
146
+ pred2 not in ["Unknown Insect", "Error Processing Image"]:
147
+
148
+ # Display results
149
+ st.markdown("### Comparison Results")
150
+ comp_col1, comp_col2 = st.columns(2)
151
+
152
+ with comp_col1:
153
+ st.markdown(f"**Species 1**: {pred1}")
154
+ st.markdown(f"Confidence: {conf1:.2f}%")
155
+ gradcam1 = st.session_state.model.get_gradcam(image1)
156
+ st.image(gradcam1, caption="Feature Highlights - Bug 1", use_container_width=True)
157
+
158
+ with comp_col2:
159
+ st.markdown(f"**Species 2**: {pred2}")
160
+ st.markdown(f"Confidence: {conf2:.2f}%")
161
+ gradcam2 = st.session_state.model.get_gradcam(image2)
162
+ st.image(gradcam2, caption="Feature Highlights - Bug 2", use_container_width=True)
163
+
164
+ # Display comparison
165
+ st.markdown("### Key Differences")
166
+ st.markdown(st.session_state.model.get_species_info(pred1))
167
+ st.markdown(st.session_state.model.get_species_info(pred2))
168
+ else:
169
+ st.warning("Unable to generate meaningful comparison due to low confidence predictions.")
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  except Exception as e:
172
+ st.error(f"Error comparing images: {str(e)}")
173
+ st.info("Please try uploading different images or try again.")
174
 
175
+ if __name__ == "__main__":
176
+ main()