# Spaces: Sleeping
import streamlit as st | |
import torch | |
from torchvision import transforms | |
from PIL import Image | |
from transformers import AutoTokenizer, AutoModel | |
import torch.nn as nn | |
import os | |
# --- Checkpoint location ---
MODEL_PATH = "bangla_disaster_model.pth"

# Halt the script with a visible error when the checkpoint is absent, so that
# nothing further down attempts to load weights from a missing file.
_checkpoint_missing = not os.path.exists(MODEL_PATH)
if _checkpoint_missing:
    st.error("❌ Model file not found. Please ensure bangla_disaster_model.pth is uploaded.")
    st.stop()

# Class codes shared by the prediction functions and the UI; the prediction
# functions index this list with the argmax of the model output.
classes = ['HYD', 'MET', 'FD', 'EQ', 'OTHD']
# === Model definition ===
class MultimodalBanglaClassifier(nn.Module):
    """Image + caption classifier built from a text encoder and EfficientNet-B3.

    The first-token text feature and the image feature are concatenated,
    projected to 512 dimensions, passed through a two-layer transformer
    encoder as a length-one sequence, and classified by a small MLP head.

    Args:
        text_model_name: Checkpoint name handed to ``AutoModel.from_pretrained``.
        num_classes: Width of the final linear layer (number of output logits).
    """

    def __init__(self, text_model_name='sagorsarker/bangla-bert-base', num_classes=5):
        super().__init__()

        # Text branch: the first six encoder layers are frozen.
        self.text_model = AutoModel.from_pretrained(text_model_name)
        for frozen_layer in self.text_model.encoder.layer[:6]:
            frozen_layer.requires_grad_(False)

        # Image branch: EfficientNet-B3 whose classification head is replaced
        # by an identity, so the backbone returns features instead of logits.
        from torchvision.models import efficientnet_b3, EfficientNet_B3_Weights
        backbone = efficientnet_b3(weights=EfficientNet_B3_Weights.IMAGENET1K_V1)
        backbone.classifier = nn.Identity()
        self.image_model = backbone

        # Concatenated text (768) + image (1536) features -> 512.
        # NOTE(review): the widths are hard-coded; they presumably match the
        # default text checkpoint and the B3 backbone - confirm before
        # swapping either model.
        self.proj = nn.Linear(768 + 1536, 512)

        fusion_layer = nn.TransformerEncoderLayer(d_model=512, nhead=4, batch_first=True)
        self.transformer_fusion = nn.TransformerEncoder(fusion_layer, num_layers=2)

        head_layers = [
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask, image):
        """Return class logits for a batch of (caption tokens, image) pairs.

        Args:
            input_ids: Token ids forwarded to the text model.
            attention_mask: Attention mask forwarded to the text model.
            image: Image batch forwarded to the image model.

        Returns:
            The output of the classification head for each batch element.
        """
        text_output = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state at sequence position 0 as the caption summary.
        caption_vec = text_output.last_hidden_state[:, 0, :]
        picture_vec = self.image_model(image)

        joint = torch.cat((caption_vec, picture_vec), dim=1)
        # The fusion encoder runs on a sequence of length one per sample.
        fusion_in = self.proj(joint).unsqueeze(1)
        fusion_out = self.transformer_fusion(fusion_in).squeeze(1)
        return self.classifier(fusion_out)
# 🚀 OPTIMIZATION 1: Cache both model and tokenizer together (No accuracy impact)
# Streamlit re-executes the whole script on every widget interaction; without
# st.cache_resource the model would be rebuilt and the checkpoint re-read each
# time. show_spinner=False because the call site shows its own spinner.
@st.cache_resource(show_spinner=False)
def load_model_and_tokenizer():
    """Load the classifier and its tokenizer, cached across script reruns.

    Returns:
        A ``(model, tokenizer)`` tuple. The model has the weights from
        ``MODEL_PATH`` loaded onto the CPU and is switched to eval mode.
        Because the result is cached as a shared resource, callers must
        treat both objects as read-only.
    """
    model = MultimodalBanglaClassifier()
    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    state_dict = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained("sagorsarker/bangla-bert-base")
    return model, tokenizer
# User-facing message for each class code. Built once at import time instead
# of on every call.
_BANGLA_RESPONSES = {
    'HYD': "🌊 এটি একটি জলসম্পর্কিত দুর্যোগ (Hydrological Disaster)। সতর্ক থাকুন!",
    'MET': "🌪️ এটি একটি আবহাওয়া সংক্রান্ত দুর্যোগ (Meteorological Disaster)। সাবধানে থাকুন!",
    'FD': "🔥 আগুন লেগেছে! এটি একটি অগ্নিদুর্ঘটনা (Fire Disaster)। দ্রুত ব্যবস্থা নিন!",
    # Spelling corrected from "ভুমিকম্প" to "ভূমিকম্প", matching the label
    # used in the detailed-probabilities view.
    'EQ': "🌍 ভূমিকম্প শনাক্ত হয়েছে (Earthquake)! নিরাপদ স্থানে যান!",
    'OTHD': "😌 এটা কোনো দুর্যোগ নয়। চিন্তার কিছু নেই!",
}

# Shown when the class code is not one of the known keys.
_UNKNOWN_CLASS_RESPONSE = "🤔 শ্রেণিবিন্যাস করা যায়নি।"


def get_bangla_response(class_name):
    """Return the Bangla user-facing message for a predicted class code.

    Args:
        class_name: One of the class codes ('HYD', 'MET', 'FD', 'EQ', 'OTHD').

    Returns:
        The message for ``class_name``, or a generic "could not classify"
        message when the code is not recognised.
    """
    return _BANGLA_RESPONSES.get(class_name, _UNKNOWN_CLASS_RESPONSE)
# Per-channel mean/std passed to transforms.Normalize in both prediction modes.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]


def _run_prediction(model, tokenizer, image, caption, image_size, max_length):
    """Run one image/caption pair through the model.

    Shared implementation behind ``predict_fast`` and ``predict_full_quality``;
    the two modes differ only in ``image_size`` and ``max_length``.

    Args:
        model: Callable accepting ``input_ids``, ``attention_mask`` and
            ``image`` keyword arguments and returning one row of logits.
        tokenizer: Callable tokenizer returning a mapping with ``input_ids``
            and ``attention_mask`` tensors.
        image: PIL image to classify.
        caption: Caption text accompanying the image.
        image_size: Side length, in pixels, of the square the image is
            resized to.
        max_length: Token count the caption is padded/truncated to.

    Returns:
        ``(class_code, probabilities)`` where ``class_code`` is the entry of
        ``classes`` with the highest score and ``probabilities`` is the list
        of softmax scores in ``classes`` order.
    """
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
    ])
    # Normalize is configured for three channels; converting here keeps
    # RGBA, greyscale or palette images from failing inside the transform.
    pixel_batch = preprocess(image.convert("RGB")).unsqueeze(0)

    encoded = tokenizer(
        caption,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )

    with torch.no_grad():
        logits = model(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
            image=pixel_batch,
        )

    predicted_index = logits.argmax(dim=1).item()
    # squeeze(0) drops only the batch axis, so the result is always a list.
    probabilities = logits.softmax(dim=1).squeeze(0).tolist()
    return classes[predicted_index], probabilities


def predict_fast(model, tokenizer, image, caption):
    """Predict with reduced input sizes for speed.

    Uses a 160x160 image (instead of 224x224) and a 64-token caption
    (instead of 128). See ``_run_prediction`` for arguments and return value.
    """
    # 🚀 OPTIMIZATION 2 & 3: smaller image and shorter text for faster processing
    return _run_prediction(model, tokenizer, image, caption, image_size=160, max_length=64)


def predict_full_quality(model, tokenizer, image, caption):
    """Predict with the original settings: 224x224 image, 128-token caption.

    See ``_run_prediction`` for arguments and return value.
    """
    return _run_prediction(model, tokenizer, image, caption, image_size=224, max_length=128)
# === Streamlit UI ===
st.set_page_config(page_title="Bangla Disaster Classifier", layout="centered")
st.title("🌪️🇧🇩 Bangla Disaster Classifier")
st.markdown("এই অ্যাপটি একটি multimodal deep learning মডেল ব্যবহার করে ছবির সাথে বাংলা ক্যাপশন বিশ্লেষণ করে দুর্যোগ শনাক্ত করে।")

# 🚀 OPTIMIZATION 4: obtain the model and tokenizer; the spinner covers the wait.
with st.spinner("🔄 মডেল লোড হচ্ছে... (Loading model...)"):
    model, tokenizer = load_model_and_tokenizer()


def _reset_inputs():
    """Clear the caption and swap in a fresh, empty file uploader.

    Runs as a button callback, i.e. before the rerun triggered by the click,
    which is when Streamlit allows widget state to be overwritten. A bare
    st.rerun() would not reset anything because widget state survives reruns.
    """
    st.session_state["caption_input"] = ""
    # A file uploader cannot be emptied through session state; changing its
    # key makes Streamlit create a new widget with no file attached.
    st.session_state["uploader_version"] = st.session_state.get("uploader_version", 0) + 1


uploaded_file = st.file_uploader(
    "🖼️ একটি দুর্যোগের ছবি আপলোড করুন",
    type=['jpg', 'png', 'jpeg'],
    key=f"uploader_{st.session_state.get('uploader_version', 0)}",
)
caption = st.text_area("✍️ বাংলায় একটি ক্যাপশন লিখুন", key="caption_input")

# Prediction mode selection
prediction_mode = st.radio(
    "🎯 পূর্বাভাস মোড নির্বাচন করুন:",
    ["⚡ দ্রুত পূর্বাভাস (Fast Prediction)", "🎯 উচ্চ নির্ভুলতা (High Accuracy)"],
    help="দ্রুত মোডে কম সময় লাগে কিন্তু সামান্য কম নির্ভুল হতে পারে"
)

col1, col2 = st.columns([1, 1])
submit = col1.button("🔍 পূর্বাভাস দিন")
# The click itself triggers a rerun after the callback, so no explicit
# st.rerun() is needed.
clear = col2.button("🧹 রিসেট করুন", on_click=_reset_inputs)

inputs_ready = uploaded_file is not None and bool(caption.strip())

if submit and not inputs_ready:
    # Tell the user why nothing happened instead of silently ignoring the click.
    st.warning("⚠️ অনুগ্রহ করে একটি ছবি আপলোড করুন এবং একটি ক্যাপশন লিখুন। (Please upload an image and enter a caption.)")
elif submit:
    try:
        img = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # Raised by Pillow for unidentifiable, corrupt or truncated image data.
        st.error("❌ ছবিটি পড়া যায়নি। অনুগ্রহ করে একটি বৈধ ছবি আপলোড করুন। (Could not read the image file.)")
        st.stop()

    st.image(img, caption="আপলোড করা ছবি", use_container_width=True)

    # 🚀 OPTIMIZATION 5: Enhanced progress indicators
    with st.spinner("🧠 মডেল পূর্বাভাস দিচ্ছে... (Model processing...)"):
        progress_bar = st.progress(0, text="ছবি প্রক্রিয়াকরণ... (Processing image...)")

        # Choose prediction function based on mode
        if "দ্রুত" in prediction_mode:
            progress_bar.progress(50, text="দ্রুত বিশ্লেষণ... (Fast analysis...)")
            prediction, probs = predict_fast(model, tokenizer, img, caption)
            mode_info = "⚡ দ্রুত মোড (Fast Mode)"
        else:
            progress_bar.progress(50, text="উচ্চ নির্ভুলতা বিশ্লেষণ... (High accuracy analysis...)")
            prediction, probs = predict_full_quality(model, tokenizer, img, caption)
            mode_info = "🎯 উচ্চ নির্ভুলতা মোড (High Accuracy Mode)"

        progress_bar.progress(100, text="সম্পূর্ণ! (Complete!)")

    # Clear progress bar
    progress_bar.empty()

    # Display results
    st.markdown(f"### ✅ পূর্বাভাস: {get_bangla_response(prediction)}")

    col1, col2 = st.columns([2, 1])
    with col1:
        st.markdown(f"#### 📊 সম্ভাব্যতা: **{probs[classes.index(prediction)]:.2%}**")
    with col2:
        st.caption(mode_info)

    # Show detailed probabilities
    with st.expander("📈 বিস্তারিত সম্ভাব্যতা (Detailed Probabilities)"):
        class_names = {
            'HYD': 'জলসম্পর্কিত দুর্যোগ',
            'MET': 'আবহাওয়া দুর্যোগ',
            'FD': 'অগ্নিদুর্ঘটনা',
            'EQ': 'ভূমিকম্প',
            'OTHD': 'কোনো দুর্যোগ নয়'
        }
        for class_code, probability in zip(classes, probs):
            st.write(f"**{class_names[class_code]}**: {probability * 100:.1f}%")
            # st.progress rejects floats outside [0.0, 1.0]; clamp to guard
            # against floating-point rounding.
            st.progress(min(max(probability, 0.0), 1.0))