bangla-disaster / app.py
pr0ximaCent's picture
Update app.py
f64a78f verified
import streamlit as st
import torch
from torchvision import transforms
from PIL import Image
from transformers import AutoTokenizer, AutoModel
import torch.nn as nn
import os
# === Model Path ===
# Trained weights for the classifier; a relative path, so it is resolved
# against the directory the app is launched from.
MODEL_PATH = "bangla_disaster_model.pth"

# Fail fast with a visible message instead of crashing later inside torch.load.
if not os.path.exists(MODEL_PATH):
    st.error("❌ Model file not found. Please ensure bangla_disaster_model.pth is uploaded.")
    st.stop()

# Class codes, indexed by the position of the corresponding output logit
# (the predicted argmax index is looked up in this list).
classes = "HYD MET FD EQ OTHD".split()
# === Model Setup ===
class MultimodalBanglaClassifier(nn.Module):
    """Disaster classifier that fuses a Bangla caption with an image.

    Two pretrained encoders embed the inputs independently:

    * ``text_model`` - a Hugging Face ``AutoModel`` (Bangla BERT by default).
      The last-layer hidden state at position 0 is used as the caption
      embedding. The first six encoder layers are frozen.
    * ``image_model`` - torchvision EfficientNet-B3 initialised with ImageNet
      weights. Its classification head is replaced by ``nn.Identity`` so it
      returns backbone features instead of ImageNet logits.

    The two embeddings are concatenated, projected to 512 dims, passed through
    a 2-layer Transformer encoder, and mapped to ``num_classes`` logits by a
    small MLP head.

    NOTE: the attribute names and the layer order inside ``classifier`` are
    the ``state_dict`` keys, so they must stay in sync with saved checkpoints.
    """

    def __init__(self, text_model_name='sagorsarker/bangla-bert-base', num_classes=5):
        super().__init__()

        # Text branch: keep the lower six encoder layers fixed.
        self.text_model = AutoModel.from_pretrained(text_model_name)
        lower_layers = self.text_model.encoder.layer[:6]
        for p in lower_layers.parameters():
            p.requires_grad = False

        # Image branch: imported here, only needed when the module is built.
        from torchvision.models import efficientnet_b3, EfficientNet_B3_Weights
        backbone = efficientnet_b3(weights=EfficientNet_B3_Weights.IMAGENET1K_V1)
        backbone.classifier = nn.Identity()
        self.image_model = backbone

        # Fusion input is the text vector followed by the image vector
        # (presumably 768 = BERT-base hidden size, 1536 = EfficientNet-B3
        # feature width).
        self.proj = nn.Linear(768 + 1536, 512)
        self.transformer_fusion = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=512, nhead=4, batch_first=True),
            num_layers=2,
        )
        # Indices 0..3 here are part of the checkpoint keys (classifier.0, classifier.3).
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, input_ids, attention_mask, image):
        """Return class logits with ``num_classes`` entries per batch item.

        Args:
            input_ids: Token ids from the matching tokenizer.
            attention_mask: Attention mask from the matching tokenizer.
            image: Normalised image batch accepted by EfficientNet-B3.
        """
        text_out = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        cls_vector = text_out.last_hidden_state[:, 0, :]
        img_vector = self.image_model(image)

        joint = torch.cat((cls_vector, img_vector), dim=1)
        # The fused vector becomes a length-1 sequence, so self-attention in
        # the fusion encoder has only a single position to attend to.
        token = self.proj(joint).unsqueeze(1)
        mixed = self.transformer_fusion(token).squeeze(1)
        return self.classifier(mixed)
# Cache the model and tokenizer together so Streamlit builds them only once.
@st.cache_resource
def load_model_and_tokenizer():
    """Build the classifier, restore its trained weights, and fetch its tokenizer.

    The weights are read from ``MODEL_PATH`` onto the CPU, and the model is
    switched to eval mode (which disables its dropout layers). Because of
    ``st.cache_resource``, Streamlit returns the same objects on every rerun.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    network = MultimodalBanglaClassifier()
    # NOTE(review): depending on the torch version, torch.load may unpickle
    # arbitrary objects -- only point MODEL_PATH at trusted checkpoints.
    checkpoint = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
    network.load_state_dict(checkpoint)
    network.eval()

    # Same checkpoint name as the classifier's default text encoder.
    text_tokenizer = AutoTokenizer.from_pretrained("sagorsarker/bangla-bert-base")
    return network, text_tokenizer
def get_bangla_response(class_name):
    """Translate a predicted class code into a user-facing Bangla message.

    Args:
        class_name: A class code such as 'HYD', 'MET', 'FD', 'EQ' or 'OTHD'.

    Returns:
        str: The emoji-prefixed message for a known code, or a generic
        "could not classify" message for anything else.
    """
    messages_by_code = dict(
        HYD="🌊 এটি একটি জলসম্পর্কিত দুর্যোগ (Hydrological Disaster)। সতর্ক থাকুন!",
        MET="🌪️ এটি একটি আবহাওয়া সংক্রান্ত দুর্যোগ (Meteorological Disaster)। সাবধানে থাকুন!",
        FD="🔥 আগুন লেগেছে! এটি একটি অগ্নিদুর্ঘটনা (Fire Disaster)। দ্রুত ব্যবস্থা নিন!",
        EQ="🌍 ভুমিকম্প শনাক্ত হয়েছে (Earthquake)! নিরাপদ স্থানে যান!",
        OTHD="😌 এটা কোনো দুর্যোগ নয়। চিন্তার কিছু নেই!",
    )
    if class_name not in messages_by_code:
        return "🤔 শ্রেণিবিন্যাস করা যায়নি।"
    return messages_by_code[class_name]
def predict_fast(model, tokenizer, image, caption):
    """Classify an image/caption pair using reduced input sizes for speed.

    This trades some accuracy for latency. The image is resized to 160x160
    instead of 224x224, and the caption is padded or truncated to 64 tokens
    instead of 128.

    Args:
        model: The multimodal classifier (should already be in eval mode).
        tokenizer: Tokenizer matching the model's text encoder.
        image: A PIL image. Non-RGB images (grayscale, RGBA, palette, ...)
            are converted to RGB first.
        caption: Caption text to classify together with the image.

    Returns:
        tuple: ``(label, probabilities)``. ``label`` is the predicted code
        from ``classes``; ``probabilities`` holds one softmax score per class,
        in ``classes`` order.
    """
    # Normalize below has exactly three channel statistics; grayscale/RGBA
    # input would otherwise produce a 1/2/4-channel tensor and raise there.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Smaller image size than predict_full_quality, for faster processing.
    transform = transforms.Compose([
        transforms.Resize((160, 160)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    image = transform(image).unsqueeze(0)  # add a batch dimension of 1

    # Shorter text length; only captions longer than 64 tokens lose content.
    encoded = tokenizer(
        caption,
        padding='max_length',
        truncation=True,
        max_length=64,
        return_tensors='pt',
    )

    with torch.no_grad():
        output = model(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
            image=image,
        )
    pred_class = output.argmax(dim=1).item()
    confidence_scores = output.softmax(dim=1).squeeze().tolist()
    return classes[pred_class], confidence_scores
def predict_full_quality(model, tokenizer, image, caption):
    """Classify an image/caption pair using the original input sizes.

    The image is resized to 224x224, and the caption is padded or truncated
    to 128 tokens.

    Args:
        model: The multimodal classifier (should already be in eval mode).
        tokenizer: Tokenizer matching the model's text encoder.
        image: A PIL image. Non-RGB images (grayscale, RGBA, palette, ...)
            are converted to RGB first.
        caption: Caption text to classify together with the image.

    Returns:
        tuple: ``(label, probabilities)``. ``label`` is the predicted code
        from ``classes``; ``probabilities`` holds one softmax score per class,
        in ``classes`` order.
    """
    # Normalize below has exactly three channel statistics; grayscale/RGBA
    # input would otherwise produce a 1/2/4-channel tensor and raise there.
    if image.mode != "RGB":
        image = image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Original size
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    image = transform(image).unsqueeze(0)  # add a batch dimension of 1

    encoded = tokenizer(
        caption,
        padding='max_length',
        truncation=True,
        max_length=128,  # Original length
        return_tensors='pt',
    )

    with torch.no_grad():
        output = model(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
            image=image,
        )
    pred_class = output.argmax(dim=1).item()
    confidence_scores = output.softmax(dim=1).squeeze().tolist()
    return classes[pred_class], confidence_scores
# === Streamlit UI ===
st.set_page_config(page_title="Bangla Disaster Classifier", layout="centered")
st.title("🌪️🇧🇩 Bangla Disaster Classifier")
st.markdown("এই অ্যাপটি একটি multimodal deep learning মডেল ব্যবহার করে ছবির সাথে বাংলা ক্যাপশন বিশ্লেষণ করে দুর্যোগ শনাক্ত করে।")

# Load model and tokenizer once; st.cache_resource hands back the cached pair on reruns.
with st.spinner("🔄 মডেল লোড হচ্ছে... (Loading model...)"):
    model, tokenizer = load_model_and_tokenizer()

# Streamlit keeps widget values across reruns, so st.rerun() alone cannot clear
# them. Instead, the input widgets' keys embed a counter. The reset callback
# bumps it, and the next run creates brand-new widgets with default values.
if "input_generation" not in st.session_state:
    st.session_state["input_generation"] = 0


def _reset_inputs():
    """Reset-button callback: discard the uploaded image and the caption."""
    st.session_state["input_generation"] += 1


input_generation = st.session_state["input_generation"]
uploaded_file = st.file_uploader(
    "🖼️ একটি দুর্যোগের ছবি আপলোড করুন",
    type=['jpg', 'png', 'jpeg'],
    key=f"uploaded_file_{input_generation}",
)
caption = st.text_area(
    "✍️ বাংলায় একটি ক্যাপশন লিখুন",
    "",
    key=f"caption_{input_generation}",
)

# Prediction mode selection
prediction_mode = st.radio(
    "🎯 পূর্বাভাস মোড নির্বাচন করুন:",
    ["⚡ দ্রুত পূর্বাভাস (Fast Prediction)", "🎯 উচ্চ নির্ভুলতা (High Accuracy)"],
    help="দ্রুত মোডে কম সময় লাগে কিন্তু সামান্য কম নির্ভুল হতে পারে",
)

col1, col2 = st.columns([1, 1])
submit = col1.button("🔍 পূর্বাভাস দিন")
# on_click runs before the rerun triggered by the click, so no st.rerun() is needed.
clear = col2.button("🧹 রিসেট করুন", on_click=_reset_inputs)

# Treat a whitespace-only caption as missing.
has_inputs = uploaded_file is not None and bool(caption.strip())

if submit and not has_inputs:
    # Previously a click with missing input silently did nothing.
    st.warning("⚠️ অনুগ্রহ করে একটি ছবি আপলোড করুন এবং একটি ক্যাপশন লিখুন। (Please upload an image and enter a caption.)")

if submit and has_inputs:
    try:
        img = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # Pillow reports unreadable/corrupt files with OSError subclasses
        # (e.g. UnidentifiedImageError); show a message instead of a traceback.
        st.error("❌ ছবিটি পড়া যায়নি। (Could not read the uploaded image.)")
        st.stop()
    st.image(img, caption="আপলোড করা ছবি", use_container_width=True)

    with st.spinner("🧠 মডেল পূর্বাভাস দিচ্ছে... (Model processing...)"):
        progress_bar = st.progress(0, text="ছবি প্রক্রিয়াকরণ... (Processing image...)")
        # Choose prediction function based on mode
        if "দ্রুত" in prediction_mode:
            progress_bar.progress(50, text="দ্রুত বিশ্লেষণ... (Fast analysis...)")
            prediction, probs = predict_fast(model, tokenizer, img, caption)
            mode_info = "⚡ দ্রুত মোড (Fast Mode)"
        else:
            progress_bar.progress(50, text="উচ্চ নির্ভুলতা বিশ্লেষণ... (High accuracy analysis...)")
            prediction, probs = predict_full_quality(model, tokenizer, img, caption)
            mode_info = "🎯 উচ্চ নির্ভুলতা মোড (High Accuracy Mode)"
        progress_bar.progress(100, text="সম্পূর্ণ! (Complete!)")
    # Clear progress bar
    progress_bar.empty()

    # Display results
    st.markdown(f"### ✅ পূর্বাভাস: {get_bangla_response(prediction)}")
    col1, col2 = st.columns([2, 1])
    with col1:
        st.markdown(f"#### 📊 সম্ভাব্যতা: **{probs[classes.index(prediction)]:.2%}**")
    with col2:
        st.caption(mode_info)

    # Show detailed probabilities
    with st.expander("📈 বিস্তারিত সম্ভাব্যতা (Detailed Probabilities)"):
        class_names = {
            'HYD': 'জলসম্পর্কিত দুর্যোগ',
            'MET': 'আবহাওয়া দুর্যোগ',
            'FD': 'অগ্নিদুর্ঘটনা',
            'EQ': 'ভূমিকম্প',
            'OTHD': 'কোনো দুর্যোগ নয়'
        }
        for i, class_code in enumerate(classes):
            percentage = probs[i] * 100
            st.write(f"**{class_names[class_code]}**: {percentage:.1f}%")
            st.progress(probs[i])