# Residue from the Hugging Face Spaces page viewer: "Spaces: Sleeping".
import os

import gradio as gr
import joblib
import numpy as np
import torch
import torchaudio
from sklearn.preprocessing import StandardScaler
from transformers import Wav2Vec2Processor, HubertModel
# Declare PlantSoundClassifier here so it exists before the checkpoint is loaded.
class PlantSoundClassifier(torch.nn.Module):
    """Two-layer feed-forward classifier: Linear -> ReLU -> Linear.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of output logits.
    """

    def __init__(self, input_dim: int = 1024, hidden_dim: int = 128, output_dim: int = 2):
        super().__init__()
        # Attribute names double as state_dict key prefixes; renaming them
        # would break loading of existing checkpoints.
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return the raw output of the final linear layer for ``x``."""
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
# Load the HuBERT processor and backbone model.
processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
model.eval()
# NOTE: `device` is applied to the classifier below; the HuBERT backbone is
# not moved here and stays on the device from_pretrained placed it on.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the scaler.
# NOTE(security): joblib.load unpickles the file, which can execute arbitrary
# code. Only load artifacts from a trusted source.
scaler = joblib.load("scaler.joblib")

# Allowlist the classifier class for torch's weights-only unpickler.
# add_safe_globals takes a list of classes/callables; a {name: class} dict
# would register only its string keys, not the class.
torch.serialization.add_safe_globals([PlantSoundClassifier])

# Load the classifier's state_dict.
# NOTE(security): whether torch.load restricts unpickling here depends on the
# installed torch version's `weights_only` default — TODO confirm and pin.
model_cls_state_dict = torch.load("classifier.pth", map_location=device)

# Build the classifier instance and load the weights into it.
model_cls = PlantSoundClassifier().to(device)
model_cls.load_state_dict(model_cls_state_dict)
model_cls.eval()
def extract_mean_embedding(wav_path):
    """Return the time-averaged HuBERT embedding of an audio file.

    Args:
        wav_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        NumPy array holding ``last_hidden_state`` averaged over dim 1 with
        singleton dimensions squeezed out.
    """
    waveform, sample_rate = torchaudio.load(wav_path)

    # torchaudio.load returns (channels, frames) by default. Averaging over
    # the channel axis gives one mono signal for any channel count; for a
    # single-channel file this yields the same values as squeeze().
    waveform = waveform.mean(dim=0)

    # The feature extractor validates the sampling rate it is given against
    # the rate the model expects, so resample instead of passing the file's
    # native rate through.
    target_rate = processor.feature_extractor.sampling_rate
    if sample_rate != target_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, target_rate)

    inputs = processor(
        waveform.numpy(),
        sampling_rate=target_rate,
        return_tensors="pt",
        padding=True,
    )
    # Keep inputs on the same device as the backbone, wherever it lives.
    inputs = inputs.to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    embedding = outputs.last_hidden_state
    mean_embedding = embedding.mean(dim=1).squeeze().cpu().numpy()
    return mean_embedding
def predict_water_status(file):
    """Classify an uploaded plant recording and return a status label.

    Args:
        file: Path to the audio file, or ``None``/empty when nothing was
            uploaded.

    Returns:
        A user-facing string: one of the two status labels, a prompt to
        upload a file when ``file`` is empty, or an "unknown status" message
        for any other predicted class index.
    """
    # Submit can be clicked before a file is chosen; without this guard the
    # call would fail inside audio loading.
    if not file:
        return "⚠️ กรุณาอัปโหลดไฟล์เสียงก่อน"

    labels = {
        0: "🌵 ขาดน้ำ",
        1: "💧 มีน้ำเพียงพอ",
    }

    vec = extract_mean_embedding(file).reshape(1, -1)
    vec_scaled = scaler.transform(vec)
    vec_tensor = torch.tensor(vec_scaled, dtype=torch.float32).to(device)

    with torch.no_grad():
        outputs = model_cls(vec_tensor)

    pred = outputs.argmax(dim=1).item()
    return labels.get(pred, "⚠️ ไม่ทราบสถานะ")
# Build the Gradio UI: an audio upload, a result box, and Submit/Clear buttons.
with gr.Blocks() as interface:
    gr.Markdown("## 🌱 Plant Sound Classifier (Fine-tuned HuBERT)")
    gr.Markdown("อัปโหลดเสียงพืช (.wav) เพื่อทำนายสถานะ: ขาดน้ำ หรือ มีน้ำเพียงพอ")

    # type="filepath" hands the handler a path on disk rather than raw samples.
    audio_input = gr.Audio(type="filepath", label="🎧 อัปโหลดเสียงพืช (.wav)")
    output_text = gr.Textbox(label="📋 ผลการทำนาย", lines=2)

    with gr.Row():
        submit_btn = gr.Button("Submit")
        clear_btn = gr.Button("Clear")

    # Event listeners are registered inside the Blocks context.
    submit_btn.click(
        fn=predict_water_status,
        inputs=audio_input,
        outputs=output_text,
    )
    # Clear resets both the uploaded audio and the result text.
    clear_btn.click(
        fn=lambda: (None, ""),
        inputs=[],
        outputs=[audio_input, output_text],
    )

interface.launch()