Spaces:
Sleeping
Sleeping
File size: 3,316 Bytes
2b87c33 b96b8e8 2b87c33 19da833 b96b8e8 2b87c33 b96b8e8 2b87c33 7ccb591 b96b8e8 7ccb591 19da833 b96b8e8 2b87c33 b96b8e8 2b87c33 b96b8e8 2b87c33 b96b8e8 2b87c33 b96b8e8 2b87c33 b96b8e8 2b87c33 b96b8e8 2b87c33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import os
import torch
import torchaudio
import numpy as np
from transformers import Wav2Vec2Processor, HubertModel
from sklearn.preprocessing import StandardScaler
import gradio as gr
import joblib
# Declare PlantSoundClassifier here so the saved classifier weights can be
# loaded into an instance of it further down.
class PlantSoundClassifier(torch.nn.Module):
    """Two-layer feed-forward classifier: Linear -> ReLU -> Linear.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of output logits (one per class).
    """

    def __init__(
        self,
        input_dim: int = 1024,
        hidden_dim: int = 128,
        output_dim: int = 2,
    ) -> None:
        super().__init__()
        # The attribute names below become the state_dict key prefixes
        # ("fc1.*", "fc2.*"), so they must stay as they are.
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw logits for ``x`` (no softmax is applied)."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
# Load the HuBERT processor and encoder used to turn audio into embeddings.
processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
model.eval()  # inference only

# Device used for the classifier head (the HuBERT encoder is left where
# from_pretrained put it).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the feature scaler applied to embeddings before classification.
# NOTE(review): presumably the scaler fitted at training time - confirm.
scaler = joblib.load("scaler.joblib")

# Allowlist the classifier class for weights-only unpickling.
# add_safe_globals takes a list of classes/callables, not a name -> class
# mapping (iterating a dict would register only its string keys).
torch.serialization.add_safe_globals([PlantSoundClassifier])

# Load the classifier's state_dict. weights_only=True restricts unpickling to
# tensors and allowlisted types, which is all a state_dict needs.
model_cls_state_dict = torch.load(
    "classifier.pth",
    map_location=device,
    weights_only=True,
)

# Build the classifier, restore its weights and switch to inference mode.
model_cls = PlantSoundClassifier().to(device)
model_cls.load_state_dict(model_cls_state_dict)
model_cls.eval()
def extract_mean_embedding(wav_path):
    """Return the time-averaged HuBERT embedding of an audio file.

    The audio is down-mixed to mono and resampled to the rate the feature
    extractor is configured for before it is encoded.

    Args:
        wav_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        1-D ``numpy.ndarray`` holding the mean of the encoder's
        ``last_hidden_state`` over the time axis.
    """
    waveform, sample_rate = torchaudio.load(wav_path)

    # torchaudio returns (channels, frames). Average the channels so that a
    # stereo file yields one signal instead of being treated as a batch of
    # two separate recordings.
    waveform = waveform.mean(dim=0)

    # The feature extractor only accepts audio at its own sampling rate.
    target_rate = processor.feature_extractor.sampling_rate
    if sample_rate != target_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, target_rate)

    inputs = processor(
        waveform.numpy(),
        sampling_rate=target_rate,
        return_tensors="pt",
        padding=True,
    )
    # Run on whichever device the encoder currently lives on.
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # last_hidden_state is (batch=1, frames, hidden): pool over frames and
    # drop the batch axis.
    embedding = outputs.last_hidden_state
    return embedding.mean(dim=1).squeeze(0).cpu().numpy()
def predict_water_status(file):
    """Classify an uploaded plant recording and return a status message.

    Args:
        file: Path to the uploaded audio file, or ``None``/empty when nothing
            has been uploaded.

    Returns:
        A user-facing (Thai) string: the "lacking water" message for class 0,
        the "enough water" message for class 1, an "unknown status" message
        for any other class index, or an upload prompt when ``file`` is
        missing.
    """
    # The UI passes None when Submit is clicked without an upload; answer
    # with a prompt instead of crashing inside the audio loader.
    if not file:
        return "⚠️ กรุณาอัปโหลดไฟล์เสียงก่อน"

    # Embed -> scale -> classify. reshape(1, -1) forms a single-sample batch.
    vec = extract_mean_embedding(file).reshape(1, -1)
    vec_scaled = scaler.transform(vec)
    vec_tensor = torch.tensor(vec_scaled, dtype=torch.float32).to(device)
    with torch.no_grad():
        outputs = model_cls(vec_tensor)
    pred = outputs.argmax(dim=1).item()

    labels = {
        0: "🌵 ขาดน้ำ",
        1: "💧 มีน้ำเพียงพอ",
    }
    return labels.get(pred, "⚠️ ไม่ทราบสถานะ")
def _reset_fields():
    """Return empty values for the audio input and the result text box."""
    return None, ""


# Gradio UI: an audio upload, a result text box, and Submit / Clear buttons.
with gr.Blocks() as interface:
    gr.Markdown("## 🌱 Plant Sound Classifier (Fine-tuned HuBERT)")
    gr.Markdown("อัปโหลดเสียงพืช (.wav) เพื่อทำนายสถานะ: ขาดน้ำ หรือ มีน้ำเพียงพอ")

    # type="filepath" hands the callback a path on disk rather than raw samples.
    audio_input = gr.Audio(type="filepath", label="🎧 อัปโหลดเสียงพืช (.wav)")
    output_text = gr.Textbox(label="📋 ผลการทำนาย", lines=2)

    with gr.Row():
        submit_btn = gr.Button("Submit")
        clear_btn = gr.Button("Clear")

    # Submit runs the classifier on the uploaded file and shows its message.
    submit_btn.click(fn=predict_water_status, inputs=audio_input, outputs=output_text)
    # Clear resets both the audio widget and the result box.
    clear_btn.click(fn=_reset_fields, inputs=[], outputs=[audio_input, output_text])

interface.launch()
|