Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -48,20 +48,15 @@ class MultimodalBanglaClassifier(nn.Module):
|
|
48 |
fused = self.transformer_fusion(fused).squeeze(1)
|
49 |
return self.classifier(fused)
|
50 |
|
51 |
-
# Cache model and tokenizer
|
52 |
-
@st.cache_resource
|
53 |
def load_model_and_tokenizer():
|
54 |
"""Load model and tokenizer once and cache them"""
|
55 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
56 |
model = MultimodalBanglaClassifier()
|
57 |
-
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
|
58 |
-
# Apply dynamic quantization for CPU
|
59 |
-
if device == torch.device("cpu"):
|
60 |
-
model = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
|
61 |
-
model.to(device)
|
62 |
model.eval()
|
63 |
tokenizer = AutoTokenizer.from_pretrained("sagorsarker/bangla-bert-base")
|
64 |
-
return model, tokenizer
|
65 |
|
66 |
def get_bangla_response(class_name):
|
67 |
responses = {
|
@@ -73,68 +68,70 @@ def get_bangla_response(class_name):
|
|
73 |
}
|
74 |
return responses.get(class_name, "🤔 শ্রেণিবিন্যাস করা যায়নি।")
|
75 |
|
76 |
-
def predict_fast(model, tokenizer, image, caption
|
77 |
-
"""Optimized prediction with smaller image size and shorter text"""
|
|
|
78 |
transform = transforms.Compose([
|
79 |
-
transforms.Resize((
|
80 |
transforms.ToTensor(),
|
81 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
|
|
82 |
])
|
83 |
-
image = transform(image).unsqueeze(0)
|
84 |
|
|
|
85 |
encoded = tokenizer(
|
86 |
caption,
|
87 |
padding='max_length',
|
88 |
truncation=True,
|
89 |
-
max_length=
|
90 |
return_tensors='pt'
|
91 |
)
|
92 |
-
input_ids = encoded['input_ids'].to(device)
|
93 |
-
attention_mask = encoded['attention_mask'].to(device)
|
94 |
|
95 |
with torch.no_grad():
|
96 |
-
output = model(
|
|
|
|
|
|
|
|
|
97 |
pred_class = output.argmax(dim=1).item()
|
98 |
confidence_scores = output.softmax(dim=1).squeeze().tolist()
|
99 |
-
|
100 |
|
101 |
-
def predict_full_quality(model, tokenizer, image, caption
|
102 |
"""Full quality prediction with original settings"""
|
103 |
transform = transforms.Compose([
|
104 |
-
transforms.Resize((224, 224)),
|
105 |
transforms.ToTensor(),
|
106 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
|
|
107 |
])
|
108 |
-
image = transform(image).unsqueeze(0)
|
109 |
encoded = tokenizer(
|
110 |
caption,
|
111 |
padding='max_length',
|
112 |
truncation=True,
|
113 |
-
max_length=128,
|
114 |
return_tensors='pt'
|
115 |
)
|
116 |
-
input_ids = encoded['input_ids'].to(device)
|
117 |
-
attention_mask = encoded['attention_mask'].to(device)
|
118 |
-
|
119 |
with torch.no_grad():
|
120 |
-
output = model(
|
|
|
|
|
|
|
|
|
121 |
pred_class = output.argmax(dim=1).item()
|
122 |
confidence_scores = output.softmax(dim=1).squeeze().tolist()
|
123 |
-
|
124 |
|
125 |
# === Streamlit UI ===
|
126 |
st.set_page_config(page_title="Bangla Disaster Classifier", layout="centered")
|
127 |
st.title("🌪️🇧🇩 Bangla Disaster Classifier")
|
128 |
st.markdown("এই অ্যাপটি একটি multimodal deep learning মডেল ব্যবহার করে ছবির সাথে বাংলা ক্যাপশন বিশ্লেষণ করে দুর্যোগ শনাক্ত করে।")
|
129 |
|
130 |
-
#
|
131 |
-
if 'prediction' not in st.session_state:
|
132 |
-
st.session_state.prediction = None
|
133 |
-
st.session_state.probs = None
|
134 |
-
|
135 |
-
# Load model, tokenizer, and device
|
136 |
with st.spinner("🔄 মডেল লোড হচ্ছে... (Loading model...)"):
|
137 |
-
model, tokenizer
|
138 |
|
139 |
uploaded_file = st.file_uploader("🖼️ একটি দুর্যোগের ছবি আপলোড করুন", type=['jpg', 'png', 'jpeg'])
|
140 |
caption = st.text_area("✍️ বাংলায় একটি ক্যাপশন লিখুন", "")
|
@@ -143,7 +140,7 @@ caption = st.text_area("✍️ বাংলায় একটি ক্যা
|
|
143 |
prediction_mode = st.radio(
|
144 |
"🎯 পূর্বাভাস মোড নির্বাচন করুন:",
|
145 |
["⚡ দ্রুত পূর্বাভাস (Fast Prediction)", "🎯 উচ্চ নির্ভুলতা (High Accuracy)"],
|
146 |
-
help="দ্রুত মোডে কম সময় লাগে কিন্তু সামান্য কম নির্ভুল হতে পারে
|
147 |
)
|
148 |
|
149 |
col1, col2 = st.columns([1, 1])
|
@@ -151,31 +148,41 @@ submit = col1.button("🔍 পূর্বাভাস দিন")
|
|
151 |
clear = col2.button("🧹 রিসেট করুন")
|
152 |
|
153 |
if clear:
|
154 |
-
st.
|
155 |
-
st.session_state.probs = None
|
156 |
-
st.rerun()
|
157 |
|
158 |
if submit and uploaded_file and caption:
|
159 |
img = Image.open(uploaded_file).convert("RGB")
|
160 |
-
st.image(img, caption="আপলোড করা ছবি",
|
161 |
|
|
|
162 |
with st.spinner("🧠 মডেল পূর্বাভাস দিচ্ছে... (Model processing...)"):
|
|
|
|
|
|
|
163 |
if "দ্রুত" in prediction_mode:
|
164 |
-
|
|
|
165 |
mode_info = "⚡ দ্রুত মোড (Fast Mode)"
|
166 |
else:
|
167 |
-
|
|
|
168 |
mode_info = "🎯 উচ্চ নির্ভুলতা মোড (High Accuracy Mode)"
|
169 |
-
|
170 |
-
|
171 |
-
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
col1, col2 = st.columns([2, 1])
|
174 |
with col1:
|
175 |
-
st.markdown(f"#### 📊 সম্ভাব্যতা: **{
|
176 |
with col2:
|
177 |
st.caption(mode_info)
|
178 |
|
|
|
179 |
with st.expander("📈 বিস্তারিত সম্ভাব্যতা (Detailed Probabilities)"):
|
180 |
class_names = {
|
181 |
'HYD': 'জলসম্পর্কিত দুর্যোগ',
|
@@ -184,8 +191,8 @@ if st.session_state.prediction:
|
|
184 |
'EQ': 'ভূমিকম্প',
|
185 |
'OTHD': 'কোনো দুর্যোগ নয়'
|
186 |
}
|
|
|
187 |
for i, class_code in enumerate(classes):
|
188 |
-
percentage =
|
189 |
st.write(f"**{class_names[class_code]}**: {percentage:.1f}%")
|
190 |
-
st.progress(
|
191 |
-
|
|
|
48 |
fused = self.transformer_fusion(fused).squeeze(1)
|
49 |
return self.classifier(fused)
|
50 |
|
51 |
+
# 🚀 OPTIMIZATION 1: Cache both model and tokenizer together (No accuracy impact)
|
52 |
+
@st.cache_resource
|
53 |
def load_model_and_tokenizer():
|
54 |
"""Load model and tokenizer once and cache them"""
|
|
|
55 |
model = MultimodalBanglaClassifier()
|
56 |
+
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
|
|
|
|
|
|
|
|
|
57 |
model.eval()
|
58 |
tokenizer = AutoTokenizer.from_pretrained("sagorsarker/bangla-bert-base")
|
59 |
+
return model, tokenizer
|
60 |
|
61 |
def get_bangla_response(class_name):
|
62 |
responses = {
|
|
|
68 |
}
|
69 |
return responses.get(class_name, "🤔 শ্রেণিবিন্যাস করা যায়নি।")
|
70 |
|
71 |
+
def predict_fast(model, tokenizer, image, caption):
|
72 |
+
"""Optimized prediction function with smaller image size and shorter text"""
|
73 |
+
# 🚀 OPTIMIZATION 2: Smaller image size (Minimal accuracy impact: ~1-3%)
|
74 |
transform = transforms.Compose([
|
75 |
+
transforms.Resize((160, 160)), # Reduced from 224x224 for faster processing
|
76 |
transforms.ToTensor(),
|
77 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
78 |
+
std=[0.229, 0.224, 0.225])
|
79 |
])
|
80 |
+
image = transform(image).unsqueeze(0)
|
81 |
|
82 |
+
# 🚀 OPTIMIZATION 3: Shorter text length (Only affects very long captions)
|
83 |
encoded = tokenizer(
|
84 |
caption,
|
85 |
padding='max_length',
|
86 |
truncation=True,
|
87 |
+
max_length=64, # Reduced from 128 for faster processing
|
88 |
return_tensors='pt'
|
89 |
)
|
|
|
|
|
90 |
|
91 |
with torch.no_grad():
|
92 |
+
output = model(
|
93 |
+
input_ids=encoded['input_ids'],
|
94 |
+
attention_mask=encoded['attention_mask'],
|
95 |
+
image=image
|
96 |
+
)
|
97 |
pred_class = output.argmax(dim=1).item()
|
98 |
confidence_scores = output.softmax(dim=1).squeeze().tolist()
|
99 |
+
return classes[pred_class], confidence_scores
|
100 |
|
101 |
+
def predict_full_quality(model, tokenizer, image, caption):
|
102 |
"""Full quality prediction with original settings"""
|
103 |
transform = transforms.Compose([
|
104 |
+
transforms.Resize((224, 224)), # Original size
|
105 |
transforms.ToTensor(),
|
106 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
107 |
+
std=[0.229, 0.224, 0.225])
|
108 |
])
|
109 |
+
image = transform(image).unsqueeze(0)
|
110 |
encoded = tokenizer(
|
111 |
caption,
|
112 |
padding='max_length',
|
113 |
truncation=True,
|
114 |
+
max_length=128, # Original length
|
115 |
return_tensors='pt'
|
116 |
)
|
|
|
|
|
|
|
117 |
with torch.no_grad():
|
118 |
+
output = model(
|
119 |
+
input_ids=encoded['input_ids'],
|
120 |
+
attention_mask=encoded['attention_mask'],
|
121 |
+
image=image
|
122 |
+
)
|
123 |
pred_class = output.argmax(dim=1).item()
|
124 |
confidence_scores = output.softmax(dim=1).squeeze().tolist()
|
125 |
+
return classes[pred_class], confidence_scores
|
126 |
|
127 |
# === Streamlit UI ===
|
128 |
st.set_page_config(page_title="Bangla Disaster Classifier", layout="centered")
|
129 |
st.title("🌪️🇧🇩 Bangla Disaster Classifier")
|
130 |
st.markdown("এই অ্যাপটি একটি multimodal deep learning মডেল ব্যবহার করে ছবির সাথে বাংলা ক্যাপশন বিশ্লেষণ করে দুর্যোগ শনাক্ত করে।")
|
131 |
|
132 |
+
# 🚀 OPTIMIZATION 4: Load model and tokenizer once at startup
|
|
|
|
|
|
|
|
|
|
|
133 |
with st.spinner("🔄 মডেল লোড হচ্ছে... (Loading model...)"):
|
134 |
+
model, tokenizer = load_model_and_tokenizer()
|
135 |
|
136 |
uploaded_file = st.file_uploader("🖼️ একটি দুর্যোগের ছবি আপলোড করুন", type=['jpg', 'png', 'jpeg'])
|
137 |
caption = st.text_area("✍️ বাংলায় একটি ক্যাপশন লিখুন", "")
|
|
|
140 |
prediction_mode = st.radio(
|
141 |
"🎯 পূর্বাভাস মোড নির্বাচন করুন:",
|
142 |
["⚡ দ্রুত পূর্বাভাস (Fast Prediction)", "🎯 উচ্চ নির্ভুলতা (High Accuracy)"],
|
143 |
+
help="দ্রুত মোডে কম সময় লাগে কিন্তু সামান্য কম নির্ভুল হতে পারে"
|
144 |
)
|
145 |
|
146 |
col1, col2 = st.columns([1, 1])
|
|
|
148 |
clear = col2.button("🧹 রিসেট করুন")
|
149 |
|
150 |
if clear:
|
151 |
+
st.rerun() # Fixed deprecated function
|
|
|
|
|
152 |
|
153 |
if submit and uploaded_file and caption:
|
154 |
img = Image.open(uploaded_file).convert("RGB")
|
155 |
+
st.image(img, caption="আপলোড করা ছবি", use_container_width=True) # Fixed deprecated parameter
|
156 |
|
157 |
+
# 🚀 OPTIMIZATION 5: Enhanced progress indicators
|
158 |
with st.spinner("🧠 মডেল পূর্বাভাস দিচ্ছে... (Model processing...)"):
|
159 |
+
progress_bar = st.progress(0, text="ছবি প্রক্রিয়াকরণ... (Processing image...)")
|
160 |
+
|
161 |
+
# Choose prediction function based on mode
|
162 |
if "দ্রুত" in prediction_mode:
|
163 |
+
progress_bar.progress(50, text="দ্রুত বিশ্লেষণ... (Fast analysis...)")
|
164 |
+
prediction, probs = predict_fast(model, tokenizer, img, caption)
|
165 |
mode_info = "⚡ দ্রুত মোড (Fast Mode)"
|
166 |
else:
|
167 |
+
progress_bar.progress(50, text="উচ্চ নির্ভুলতা বিশ্লেষণ... (High accuracy analysis...)")
|
168 |
+
prediction, probs = predict_full_quality(model, tokenizer, img, caption)
|
169 |
mode_info = "🎯 উচ্চ নির্ভুলতা মোড (High Accuracy Mode)"
|
170 |
+
|
171 |
+
progress_bar.progress(100, text="সম্পূর্ণ! (Complete!)")
|
172 |
+
|
173 |
+
# Clear progress bar
|
174 |
+
progress_bar.empty()
|
175 |
+
|
176 |
+
# Display results
|
177 |
+
st.markdown(f"### ✅ পূর্বাভাস: {get_bangla_response(prediction)}")
|
178 |
|
179 |
col1, col2 = st.columns([2, 1])
|
180 |
with col1:
|
181 |
+
st.markdown(f"#### 📊 সম্ভাব্যতা: **{probs[classes.index(prediction)]:.2%}**")
|
182 |
with col2:
|
183 |
st.caption(mode_info)
|
184 |
|
185 |
+
# Show detailed probabilities
|
186 |
with st.expander("📈 বিস্তারিত সম্ভাব্যতা (Detailed Probabilities)"):
|
187 |
class_names = {
|
188 |
'HYD': 'জলসম্পর্কিত দুর্যোগ',
|
|
|
191 |
'EQ': 'ভূমিকম্প',
|
192 |
'OTHD': 'কোনো দুর্যোগ নয়'
|
193 |
}
|
194 |
+
|
195 |
for i, class_code in enumerate(classes):
|
196 |
+
percentage = probs[i] * 100
|
197 |
st.write(f"**{class_names[class_code]}**: {percentage:.1f}%")
|
198 |
+
st.progress(probs[i])
|
|