import os
import torch
import torchaudio
import numpy as np
from transformers import Wav2Vec2Processor, HubertModel
from sklearn.preprocessing import StandardScaler
import gradio as gr
import joblib
# ✅ Define the PlantSoundClassifier head (added here)
class PlantSoundClassifier(torch.nn.Module):
    def __init__(self, input_dim=1024, hidden_dim=128, output_dim=2):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
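# Note: input_dim=1024 matches the hidden size of hubert-large-ls960-ft, and the two
# output logits map to the watering classes returned by predict_water_status below.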
# 👇 Load the HuBERT processor and model
processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
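# Note: the HuBERT encoder stays on CPU here; `device` is only used for the classifier head below.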
# 👇 Load the scaler
scaler = joblib.load("scaler.joblib")
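# scaler.joblib is assumed to be a fitted StandardScaler (or similar) trained on the
# same 1024-dim mean-pooled HuBERT embeddings produced by extract_mean_embedding below.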
# ✅ Allowlist the classifier class for torch.load's weights_only unpickler (safe!)
torch.serialization.add_safe_globals([PlantSoundClassifier])
# ✅ Load the classifier's state_dict
model_cls_state_dict = torch.load("classifier.pth", map_location=device)
# ✅ Instantiate the classifier and load the state_dict
model_cls = PlantSoundClassifier().to(device)
model_cls.load_state_dict(model_cls_state_dict)
model_cls.eval()
def extract_mean_embedding(wav_path):
    waveform, sample_rate = torchaudio.load(wav_path)
    # Mix multi-channel audio down to mono before feature extraction
    if waveform.dim() > 1 and waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0)
    waveform = waveform.squeeze()
    # hubert-large-ls960-ft expects 16 kHz input; resample if needed
    if sample_rate != 16000:
        waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
        sample_rate = 16000
    inputs = processor(waveform, sampling_rate=sample_rate, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
        embedding = outputs.last_hidden_state  # (1, frames, 1024)
        mean_embedding = embedding.mean(dim=1).squeeze().cpu().numpy()  # mean-pool over time
    return mean_embedding
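# Quick sanity check (assumes some local sample file, e.g. "example.wav"):
#   emb = extract_mean_embedding("example.wav")
#   print(emb.shape)  # expected: (1024,)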
def predict_water_status(file):
    vec = extract_mean_embedding(file).reshape(1, -1)
    vec_scaled = scaler.transform(vec)
    vec_tensor = torch.tensor(vec_scaled, dtype=torch.float32).to(device)
    with torch.no_grad():
        outputs = model_cls(vec_tensor)
        pred = outputs.argmax(dim=1).item()
    if pred == 0:
        return "🌵 Water-stressed"
    elif pred == 1:
        return "💧 Sufficiently watered"
    else:
        return "⚠️ Unknown status"
with gr.Blocks() as interface:
    gr.Markdown("## 🌱 Plant Sound Classifier (Fine-tuned HuBERT)")
    gr.Markdown("Upload a plant sound recording (.wav) to predict its status: water-stressed or sufficiently watered.")
    audio_input = gr.Audio(type="filepath", label="🎧 Upload plant sound (.wav)")
    output_text = gr.Textbox(label="📋 Prediction", lines=2)
    with gr.Row():
        submit_btn = gr.Button("Submit")
        clear_btn = gr.Button("Clear")
    submit_btn.click(
        fn=predict_water_status,
        inputs=audio_input,
        outputs=output_text
    )
    clear_btn.click(
        fn=lambda: (None, ""),
        inputs=[],
        outputs=[audio_input, output_text]
    )

interface.launch()
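# On Hugging Face Spaces the default launch() is sufficient; when running locally,
# interface.launch(share=True) could be used to get a temporary public link.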