dalybuilds commited on
Commit
a4102d1
Β·
verified Β·
1 Parent(s): 8c6d601

Update model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +37 -172
model_utils.py CHANGED
@@ -1,176 +1,41 @@
1
- import streamlit as st
 
 
2
  from PIL import Image
3
- from model_utils import BugClassifier, get_severity_prediction
4
-
5
- # Page configuration
6
- st.set_page_config(
7
- page_title="Bug-O-Scope πŸ”πŸž",
8
- page_icon="πŸ”",
9
- layout="wide"
10
- )
11
-
12
- # Initialize session state for model
13
- @st.cache_resource
14
- def load_model():
15
- try:
16
- print("Loading model...")
17
- model = BugClassifier()
18
- print("Model loaded successfully")
19
- return model
20
- except Exception as e:
21
- print(f"Error loading model: {str(e)}")
22
- return None
23
-
24
- # Ensure model is loaded
25
- if 'model' not in st.session_state:
26
- st.session_state.model = load_model()
27
-
28
- def main():
29
- # Header
30
- st.title("Bug-O-Scope πŸ”πŸž")
31
- st.markdown("""
32
- Welcome to Bug-O-Scope! Upload a picture of an insect to learn more about it.
33
- This educational tool helps you identify bugs and understand their role in our ecosystem.
34
- """)
35
-
36
- # Sidebar
37
- st.sidebar.header("About Bug-O-Scope")
38
- st.sidebar.markdown("""
39
- Bug-O-Scope is an AI-powered tool that helps you:
40
- * πŸ” Identify insects from photos
41
- * πŸ“š Learn about different species
42
- * 🌍 Understand their ecological impact
43
- * πŸ”¬ Compare different insects
44
- """)
45
-
46
- # Check if model loaded successfully
47
- if st.session_state.model is None:
48
- st.error("Error: Model failed to load. Please try refreshing the page.")
49
- return
50
-
51
- # Main content
52
- tab1, tab2 = st.tabs(["Single Bug Analysis", "Bug Comparison"])
53
-
54
- with tab1:
55
- single_bug_analysis()
56
-
57
- with tab2:
58
- compare_bugs()
59
-
60
- def single_bug_analysis():
61
- """Handle single bug analysis"""
62
- uploaded_file = st.file_uploader("Upload a bug photo", type=['png', 'jpg', 'jpeg'], key="single")
63
-
64
- if uploaded_file:
65
- try:
66
- # Load and display image
67
- image = Image.open(uploaded_file)
68
- col1, col2 = st.columns(2)
69
-
70
- with col1:
71
- st.image(image, caption="Uploaded Image", use_container_width=True)
72
-
73
- with col2:
74
- with st.spinner("Analyzing your bug..."):
75
- # Get predictions
76
- prediction, confidence = st.session_state.model.predict(image)
77
- print(f"Prediction: {prediction}, Confidence: {confidence}")
78
-
79
- st.success("Analysis Complete!")
80
- st.markdown("### Identified Species")
81
- st.markdown(f"**{prediction}**")
82
- st.markdown(f"Confidence: {confidence:.2f}%")
83
-
84
- # Only show ecological impact for known insects
85
- if prediction != "Unknown Insect" and prediction != "Error Processing Image":
86
- severity = get_severity_prediction(prediction)
87
- st.markdown("### Ecological Impact")
88
- severity_color = {
89
- "Low": "green",
90
- "Medium": "orange",
91
- "High": "red",
92
- "Unknown": "gray"
93
- }
94
- st.markdown(
95
- f"Severity: <span style='color: {severity_color[severity]}'>{severity}</span>",
96
- unsafe_allow_html=True
97
- )
98
-
99
- # Display species information
100
- if prediction != "Unknown Insect" and prediction != "Error Processing Image":
101
- st.markdown("### About This Species")
102
- species_info = st.session_state.model.get_species_info(prediction)
103
- st.markdown(species_info)
104
-
105
- # Display visualization
106
- st.markdown("### Feature Highlights")
107
- gradcam = st.session_state.model.get_gradcam(image)
108
- st.image(gradcam, caption="Important Features", use_container_width=True)
109
-
110
- except Exception as e:
111
- st.error(f"Error processing image: {str(e)}")
112
- st.info("Please try uploading a different image.")
113
-
114
- def compare_bugs():
115
- """Handle bug comparison"""
116
- col1, col2 = st.columns(2)
117
-
118
- with col1:
119
- file1 = st.file_uploader("Upload first bug photo", type=['png', 'jpg', 'jpeg'], key="compare1")
120
- if file1:
121
- try:
122
- image1 = Image.open(file1)
123
- st.image(image1, caption="First Bug", use_container_width=True)
124
- except Exception as e:
125
- st.error(f"Error loading first image: {str(e)}")
126
- return
127
-
128
- with col2:
129
- file2 = st.file_uploader("Upload second bug photo", type=['png', 'jpg', 'jpeg'], key="compare2")
130
- if file2:
131
- try:
132
- image2 = Image.open(file2)
133
- st.image(image2, caption="Second Bug", use_container_width=True)
134
- except Exception as e:
135
- st.error(f"Error loading second image: {str(e)}")
136
- return
137
-
138
- if file1 and file2:
139
  try:
140
- with st.spinner("Generating comparison..."):
141
- # Get predictions
142
- pred1, conf1 = st.session_state.model.predict(image1)
143
- pred2, conf2 = st.session_state.model.predict(image2)
144
-
145
- if pred1 not in ["Unknown Insect", "Error Processing Image"] and \
146
- pred2 not in ["Unknown Insect", "Error Processing Image"]:
147
-
148
- # Display results
149
- st.markdown("### Comparison Results")
150
- comp_col1, comp_col2 = st.columns(2)
151
-
152
- with comp_col1:
153
- st.markdown(f"**Species 1**: {pred1}")
154
- st.markdown(f"Confidence: {conf1:.2f}%")
155
- gradcam1 = st.session_state.model.get_gradcam(image1)
156
- st.image(gradcam1, caption="Feature Highlights - Bug 1", use_container_width=True)
157
-
158
- with comp_col2:
159
- st.markdown(f"**Species 2**: {pred2}")
160
- st.markdown(f"Confidence: {conf2:.2f}%")
161
- gradcam2 = st.session_state.model.get_gradcam(image2)
162
- st.image(gradcam2, caption="Feature Highlights - Bug 2", use_container_width=True)
163
-
164
- # Display comparison
165
- st.markdown("### Key Differences")
166
- st.markdown(st.session_state.model.get_species_info(pred1))
167
- st.markdown(st.session_state.model.get_species_info(pred2))
168
- else:
169
- st.warning("Unable to generate meaningful comparison due to low confidence predictions.")
170
-
171
  except Exception as e:
172
- st.error(f"Error comparing images: {str(e)}")
173
- st.info("Please try uploading different images or try again.")
174
 
175
- if __name__ == "__main__":
176
- main()
 
 
 
1
+ # 3. model_utils.py
2
+ # Model management (loading, prediction, and species information)
3
+ from transformers import ViTForImageClassification
4
  from PIL import Image
5
+ import torch
6
+ from dataset_utils import load_species_descriptions
7
+
8
+ class BugClassifier:
9
+ def __init__(self, model_path="google/vit-base-patch16-224"):
10
+ self.model = ViTForImageClassification.from_pretrained(model_path)
11
+ self.model.eval()
12
+ self.labels = [
13
+ "Seven-spotted Ladybug", "Monarch Butterfly", "Carpenter Ant",
14
+ "Japanese Beetle", "Garden Spider", "Green Grasshopper",
15
+ "Luna Moth", "Common Dragonfly", "Honey Bee", "Paper Wasp"
16
+ ]
17
+ # Dynamically load species descriptions
18
+ self.species_descriptions = load_species_descriptions()
19
+
20
+ def predict(self, image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  try:
22
+ processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
23
+ inputs = processor(images=image, return_tensors="pt")
24
+ with torch.no_grad():
25
+ outputs = self.model(**inputs)
26
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
27
+ confidence, predicted_idx = probabilities.max(dim=1)
28
+ confidence = confidence.item() * 100
29
+ predicted_label = self.labels[predicted_idx.item()]
30
+
31
+ if confidence < 30:
32
+ return "Unknown Insect", confidence
33
+
34
+ return predicted_label, confidence
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  except Exception as e:
36
+ return "Error Processing Image", 0.0
 
37
 
38
+ def get_species_info(self, species):
39
+ return self.species_descriptions.get(
40
+ species, "Information not available. Consider updating your dataset for this species."
41
+ )