dalybuilds commited on
Commit
37f5146
Β·
verified Β·
1 Parent(s): f88bb49

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -11
app.py CHANGED
@@ -1,25 +1,28 @@
1
  import streamlit as st
2
- import torch
3
  from PIL import Image
4
- import io
5
  import numpy as np
6
- from model_utils import BugClassifier, generate_gradcam, get_severity_prediction
7
  from transformers import AutoFeatureExtractor
8
 
9
  # Page configuration
10
  st.set_page_config(
11
  page_title="Bug-O-Scope πŸ”πŸž",
12
  page_icon="πŸ”",
13
- layout="wide"
 
14
  )
15
 
16
  # Initialize session state
17
- if 'model' not in st.session_state:
 
18
  try:
19
- st.session_state.model = BugClassifier()
20
- st.session_state.feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
21
  except Exception as e:
22
  st.error(f"Error loading model: {str(e)}")
 
 
 
 
23
 
24
  def main():
25
  # Header
@@ -49,6 +52,7 @@ def main():
49
  compare_bugs()
50
 
51
  def single_bug_analysis():
 
52
  uploaded_file = st.file_uploader("Upload a bug photo", type=['png', 'jpg', 'jpeg'], key="single")
53
 
54
  if uploaded_file:
@@ -76,19 +80,27 @@ def single_bug_analysis():
76
  "Medium": "orange",
77
  "High": "red"
78
  }
79
- st.markdown(f"Severity: <span style='color: {severity_color[severity]}'>{severity}</span>",
80
- unsafe_allow_html=True)
 
 
81
 
82
  # Generate and display species information
83
  st.markdown("### About This Species")
84
  species_info = st.session_state.model.get_species_info(prediction)
85
  st.markdown(species_info)
86
 
 
 
 
 
 
87
  except Exception as e:
88
  st.error(f"Error processing image: {str(e)}")
89
  st.info("Please try uploading a different image.")
90
 
91
  def compare_bugs():
 
92
  col1, col2 = st.columns(2)
93
 
94
  with col1:
@@ -119,8 +131,8 @@ def compare_bugs():
119
  pred2, conf2 = st.session_state.model.predict(image2)
120
 
121
  # Generate Grad-CAM visualizations
122
- gradcam1 = generate_gradcam(image1, st.session_state.model)
123
- gradcam2 = generate_gradcam(image2, st.session_state.model)
124
 
125
  # Display results
126
  st.markdown("### Comparison Results")
 
1
  import streamlit as st
 
2
  from PIL import Image
 
3
  import numpy as np
4
+ from model_utils import BugClassifier, get_severity_prediction
5
  from transformers import AutoFeatureExtractor
6
 
7
  # Page configuration
8
  st.set_page_config(
9
  page_title="Bug-O-Scope πŸ”πŸž",
10
  page_icon="πŸ”",
11
+ layout="wide",
12
+ initial_sidebar_state="expanded"
13
  )
14
 
15
  # Initialize session state
16
+ @st.cache_resource
17
+ def load_model():
18
  try:
19
+ return BugClassifier(), AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
 
20
  except Exception as e:
21
  st.error(f"Error loading model: {str(e)}")
22
+ return None, None
23
+
24
+ if 'model' not in st.session_state:
25
+ st.session_state.model, st.session_state.feature_extractor = load_model()
26
 
27
  def main():
28
  # Header
 
52
  compare_bugs()
53
 
54
  def single_bug_analysis():
55
+ """Handle single bug analysis"""
56
  uploaded_file = st.file_uploader("Upload a bug photo", type=['png', 'jpg', 'jpeg'], key="single")
57
 
58
  if uploaded_file:
 
80
  "Medium": "orange",
81
  "High": "red"
82
  }
83
+ st.markdown(
84
+ f"Severity: <span style='color: {severity_color[severity]}'>{severity}</span>",
85
+ unsafe_allow_html=True
86
+ )
87
 
88
  # Generate and display species information
89
  st.markdown("### About This Species")
90
  species_info = st.session_state.model.get_species_info(prediction)
91
  st.markdown(species_info)
92
 
93
+ # Display Grad-CAM visualization
94
+ st.markdown("### Feature Highlights")
95
+ gradcam = st.session_state.model.get_gradcam(image)
96
+ st.image(gradcam, caption="Important Features", use_container_width=True)
97
+
98
  except Exception as e:
99
  st.error(f"Error processing image: {str(e)}")
100
  st.info("Please try uploading a different image.")
101
 
102
  def compare_bugs():
103
+ """Handle bug comparison"""
104
  col1, col2 = st.columns(2)
105
 
106
  with col1:
 
131
  pred2, conf2 = st.session_state.model.predict(image2)
132
 
133
  # Generate Grad-CAM visualizations
134
+ gradcam1 = st.session_state.model.get_gradcam(image1)
135
+ gradcam2 = st.session_state.model.get_gradcam(image2)
136
 
137
  # Display results
138
  st.markdown("### Comparison Results")