File size: 3,316 Bytes
2b87c33
 
 
 
 
 
 
 
b96b8e8
2b87c33
19da833
 
 
 
 
 
 
 
 
 
 
 
 
 
b96b8e8
2b87c33
 
 
 
 
 
b96b8e8
2b87c33
7ccb591
b96b8e8
 
7ccb591
 
19da833
 
b96b8e8
 
 
 
2b87c33
 
 
 
 
b96b8e8
2b87c33
 
 
b96b8e8
2b87c33
b96b8e8
2b87c33
 
 
 
 
 
b96b8e8
2b87c33
 
 
b96b8e8
 
 
 
 
 
 
2b87c33
 
b96b8e8
 
2b87c33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import torch
import torchaudio
import numpy as np
from transformers import Wav2Vec2Processor, HubertModel
from sklearn.preprocessing import StandardScaler

import gradio as gr
import joblib

# ✅ Classifier head definition (must exist before the checkpoint is loaded).
class PlantSoundClassifier(torch.nn.Module):
    """Two-layer MLP head: Linear -> ReLU -> Linear.

    Maps a tensor whose last dimension is ``input_dim`` to unnormalised class
    scores whose last dimension is ``output_dim``.

    The attribute names ``fc1`` / ``relu`` / ``fc2`` determine the state_dict
    keys, so renaming them would break loading previously saved weights.
    """

    def __init__(self, input_dim=1024, hidden_dim=128, output_dim=2):
        super().__init__()
        # Registration names (and order) define the state_dict layout.
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return raw (pre-softmax) scores for ``x``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

# 👇 Load the HuBERT processor and encoder used to embed audio.
# NOTE: the encoder is not moved to `device`; it stays on the CPU.
processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 👇 Load the feature scaler applied to embeddings before classification
# (presumably a StandardScaler fitted on the training embeddings — confirm).
scaler = joblib.load("scaler.joblib")


# ✅ Allow-list the classifier class for torch's weights-only unpickler, in case
# classifier.pth holds a pickled module rather than a plain state_dict.
# add_safe_globals expects a *list* of classes/functions: the previous dict
# argument registered only its string key, which the unpickler cannot resolve
# (AttributeError) whenever it consults the allow-list. The API only exists in
# newer torch releases, so guard for older ones.
if hasattr(torch.serialization, "add_safe_globals"):
    torch.serialization.add_safe_globals([PlantSoundClassifier])

# ✅ Load the classifier weights onto the selected device.
model_cls_state_dict = torch.load("classifier.pth", map_location=device)
if isinstance(model_cls_state_dict, torch.nn.Module):
    # The file contained a whole module; keep only its parameters.
    model_cls_state_dict = model_cls_state_dict.state_dict()

# ✅ Build the classifier, load the weights, and switch to inference mode.
model_cls = PlantSoundClassifier().to(device)
model_cls.load_state_dict(model_cls_state_dict)
model_cls.eval()

def extract_mean_embedding(wav_path):
    """Return the time-averaged HuBERT embedding of an audio file.

    The file is decoded with torchaudio, down-mixed to mono, resampled to the
    sampling rate the processor expects, and run through the HuBERT encoder.
    The encoder's last hidden state is then averaged over the time axis.

    Args:
        wav_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        A 1-D NumPy array with one value per encoder hidden unit.
    """
    waveform, sample_rate = torchaudio.load(wav_path)
    # torchaudio returns (channels, frames). Averaging the channels gives one
    # mono signal, so stereo files yield a single embedding instead of one per
    # channel (which would later mismatch the scaler's feature count).
    waveform = waveform.mean(dim=0)

    # The HuBERT feature extractor raises if the sampling rate differs from
    # the one it expects, so resample first.
    target_rate = processor.feature_extractor.sampling_rate
    if sample_rate != target_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, target_rate)
        sample_rate = target_rate

    inputs = processor(
        waveform.numpy(),
        sampling_rate=sample_rate,
        return_tensors="pt",
        padding=True,
    )
    # Send the inputs to whichever device the encoder's weights live on.
    encoder_device = next(model.parameters()).device
    inputs = {name: tensor.to(encoder_device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # last_hidden_state is (batch, time, hidden) with batch == 1 here:
    # average over time, then drop the batch axis.
    mean_embedding = outputs.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
    return mean_embedding

# Classifier output index -> user-facing status label.
_WATER_STATUS_LABELS = {
    0: "🌵 ขาดน้ำ",
    1: "💧 มีน้ำเพียงพอ",
}


def predict_water_status(file):
    """Classify a plant recording as water-stressed or sufficiently watered.

    Args:
        file: Filesystem path of the uploaded audio, or a falsy value
            (presumably ``None``) when Submit is pressed without an upload.

    Returns:
        A user-facing status string. Returns a prompt to upload a file when
        ``file`` is empty, and an "unknown status" label for any class index
        other than 0 or 1.
    """
    if not file:
        # Nothing uploaded: avoid passing None to torchaudio.load.
        return "⚠️ กรุณาอัปโหลดไฟล์เสียงก่อน"

    vec = extract_mean_embedding(file).reshape(1, -1)
    vec_scaled = scaler.transform(vec)
    vec_tensor = torch.tensor(vec_scaled, dtype=torch.float32).to(device)

    with torch.no_grad():
        outputs = model_cls(vec_tensor)
        pred = outputs.argmax(dim=1).item()

    return _WATER_STATUS_LABELS.get(pred, "⚠️ ไม่ทราบสถานะ")

# Gradio front end: upload a recording, press Submit to classify it,
# or Clear to reset both the upload widget and the result box.
with gr.Blocks() as interface:
    gr.Markdown("## 🌱 Plant Sound Classifier (Fine-tuned HuBERT)")
    gr.Markdown("อัปโหลดเสียงพืช (.wav) เพื่อทำนายสถานะ: ขาดน้ำ หรือ มีน้ำเพียงพอ")

    # type="filepath" makes the callback receive a path on disk.
    audio_input = gr.Audio(type="filepath", label="🎧 อัปโหลดเสียงพืช (.wav)")
    output_text = gr.Textbox(label="📋 ผลการทำนาย", lines=2)

    with gr.Row():
        submit_btn = gr.Button("Submit")
        clear_btn = gr.Button("Clear")

    # Submit -> classify the uploaded file into the textbox.
    submit_btn.click(fn=predict_water_status, inputs=audio_input, outputs=output_text)
    # Clear -> blank the audio widget and the textbox together.
    clear_btn.click(fn=lambda: (None, ""), inputs=[], outputs=[audio_input, output_text])

# Start the web server immediately when this script runs.
interface.launch()