"""Gradio demo: classify a plant sound recording as "dry" or "enough water".

Pipeline: load a .wav file -> HuBERT hidden states -> mean over time ->
scaler loaded from ``scaler.npy`` -> small MLP classifier -> label.
"""

import os

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchaudio
from sklearn.preprocessing import StandardScaler
from transformers import HubertModel, Wav2Vec2Processor


# 1) Classifier model
class PlantSoundClassifier(nn.Module):
    """Two-layer MLP applied to a mean-pooled HuBERT embedding.

    Args:
        input_dim: Size of the input embedding (default 1024).
        hidden_dim: Width of the hidden layer (default 128).
        num_classes: Number of output classes (default 2:
            0 = dry, 1 = enough water).
    """

    def __init__(self, input_dim: int = 1024, hidden_dim: int = 128,
                 num_classes: int = 2) -> None:
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class scores (logits) for ``x``."""
        return self.fc2(self.relu(self.fc1(x)))


# 2) Load HuBERT
model_name = "facebook/hubert-large-ls960-ft"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model_hubert = HubertModel.from_pretrained(model_name)
model_hubert.eval()

# Sample rate the feature extractor was configured for; audio at any other
# rate is rejected by the processor, so inputs are resampled to this value.
TARGET_SAMPLE_RATE = processor.feature_extractor.sampling_rate

# 3) Load the fine-tuned classifier
model_cls = PlantSoundClassifier()
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary
# objects. Only load a "classifier.pt" that comes from a trusted source.
state_dict = torch.load("classifier.pt", map_location="cpu", weights_only=False)
if isinstance(state_dict, nn.Module):
    # The checkpoint holds a whole pickled model rather than a state dict.
    state_dict = state_dict.state_dict()
# Without this call the classifier would run with random initial weights.
model_cls.load_state_dict(state_dict)
model_cls.eval()

# 4) Load the scaler
# NOTE(security): allow_pickle=True unpickles arbitrary objects; only load a
# "scaler.npy" that comes from a trusted source.
scaler = np.load("scaler.npy", allow_pickle=True).item()


# 5) Mean-embedding extraction
def extract_mean_embedding(wav_path: str) -> np.ndarray:
    """Return the time-averaged HuBERT embedding of an audio file.

    The audio is downmixed to mono and resampled to ``TARGET_SAMPLE_RATE``
    before being passed to the processor.

    Args:
        wav_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        A 1-D NumPy array: ``last_hidden_state`` averaged over the time axis.

    Raises:
        ValueError: If the file contains no audio samples.
    """
    waveform, sample_rate = torchaudio.load(wav_path)  # (channels, frames)
    if waveform.numel() == 0:
        raise ValueError(f"Audio file contains no samples: {wav_path}")

    # Downmix to mono. A plain squeeze() would leave a stereo file as
    # (2, frames), which the processor would treat as a batch of two clips.
    waveform = waveform.mean(dim=0)

    if sample_rate != TARGET_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE
        )

    inputs = processor(
        waveform.numpy(),
        sampling_rate=TARGET_SAMPLE_RATE,
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        outputs = model_hubert(**inputs)

    embedding = outputs.last_hidden_state
    # Average over dim=1 (time), then drop the singleton batch dimension.
    return embedding.mean(dim=1).squeeze().numpy()


# 6) Prediction
def predict_water_status(file):
    """Predict the water status for an uploaded audio file.

    Args:
        file: Path to the uploaded audio file, or ``None``/empty when nothing
            has been uploaded.

    Returns:
        A user-facing string: the "dry" label for class 0, the "enough water"
        label otherwise, or a prompt to upload a file when ``file`` is empty.
    """
    if not file:
        # Predict was clicked without an upload; avoid crashing in the loader.
        return "⚠️ กรุณาอัปโหลดไฟล์เสียงก่อนทำนาย"

    mean_emb = extract_mean_embedding(file).reshape(1, -1)
    mean_emb_scaled = scaler.transform(mean_emb)
    mean_emb_tensor = torch.tensor(mean_emb_scaled, dtype=torch.float32)
    with torch.no_grad():
        output = model_cls(mean_emb_tensor)
    pred = output.argmax(dim=1).item()
    return "🌵 ขาดน้ำ" if pred == 0 else "💧 มีน้ำเพียงพอ"


# 7) Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## 🌱 Plant Sound Classifier (Fine-tuned)")
    gr.Markdown("อัปโหลดเสียงพืช (.wav) เพื่อทำนายสถานะ: ขาดน้ำ หรือ มีน้ำเพียงพอ")

    audio_input = gr.Audio(type="filepath", label="🎧 Upload Plant Sound (.wav)")
    output_text = gr.Textbox(label="📋 ผลการทำนาย", lines=2)

    with gr.Row():
        submit_btn = gr.Button("Predict")
        clear_btn = gr.Button("Clear")

    submit_btn.click(
        fn=predict_water_status,
        inputs=audio_input,
        outputs=output_text,
    )
    # Reset both the audio input and the result box.
    clear_btn.click(
        fn=lambda: (None, ""),
        inputs=[],
        outputs=[audio_input, output_text],
    )

# 8) Run Gradio
demo.launch()