NonSittinon committed on
Commit
2b87c33
·
verified ·
1 Parent(s): fb9839a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchaudio
4
+ import numpy as np
5
+ from transformers import Wav2Vec2Processor, HubertModel
6
+ from sklearn.preprocessing import StandardScaler
7
+
8
+ import gradio as gr
9
+
10
# Device for the classifier head: the GPU when CUDA is available, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# HuBERT processor and encoder, both taken from the same pretrained checkpoint.
# The encoder is not moved to `device`; it is only switched to inference mode.
_HUBERT_CHECKPOINT = "facebook/hubert-large-ls960-ft"
processor = Wav2Vec2Processor.from_pretrained(_HUBERT_CHECKPOINT)
model = HubertModel.from_pretrained(_HUBERT_CHECKPOINT)
model.eval()

# Load the fitted feature scaler and the fine-tuned classifier from the
# current working directory.
import joblib

scaler = joblib.load("scaler.joblib")
# NOTE(review): torch.load unpickles a complete object here (the result is
# used as a module right away via .eval()). Only load checkpoints from a
# trusted source, and confirm the classifier's class is importable at this
# point and that the installed torch version still allows full-object loads.
model_cls = torch.load("classifier.pth", map_location=device)
model_cls.eval()
22
+
23
def extract_mean_embedding(wav_path):
    """Return the time-averaged HuBERT embedding of an audio file.

    The file is loaded with torchaudio, down-mixed to a single channel,
    resampled to the rate the processor's feature extractor is configured
    for, and passed through the module-level HuBERT ``model``. The last
    hidden state is then averaged over the time axis.

    Args:
        wav_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        A 1-D NumPy array: the mean of ``last_hidden_state`` over time.

    Raises:
        ValueError: If the file contains no audio samples.
    """
    waveform, sample_rate = torchaudio.load(wav_path)
    if waveform.numel() == 0:
        raise ValueError(f"No audio samples found in {wav_path!r}")

    # torchaudio returns (channels, frames). Average the channels rather than
    # calling squeeze(): squeeze() leaves stereo input 2-D, which the
    # processor would treat as a batch of two clips, and the flattened result
    # would no longer match the feature size the scaler was fitted on.
    waveform = waveform.mean(dim=0)

    # The feature extractor raises when the declared sampling rate differs
    # from the one it is configured for, so resample instead of forwarding
    # the file's native rate.
    target_rate = processor.feature_extractor.sampling_rate
    if sample_rate != target_rate:
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, target_rate
        )

    # A 1-D NumPy array is the processor's documented single-clip input.
    inputs = processor(
        waveform.numpy(),
        sampling_rate=target_rate,
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        outputs = model(**inputs)

    # (1, frames, hidden) -> (1, hidden) -> (hidden,)
    pooled = outputs.last_hidden_state.mean(dim=1)
    return pooled.squeeze(0).cpu().numpy()
32
+
33
def predict_water_status(file):
    """Classify an uploaded plant recording by its water status.

    The recording is embedded with ``extract_mean_embedding``, standardised
    with the module-level ``scaler`` and scored by ``model_cls``. Class index
    0 maps to the "dehydrated" label; any other index maps to the
    "sufficient water" label.

    Args:
        file: Path to the uploaded audio file, or ``None``/empty when no
            file is present (the Clear button resets the audio input to
            ``None``, so Submit can be pressed without a file).

    Returns:
        A user-facing string: one of the two status labels, or a prompt to
        upload a file when ``file`` is missing.
    """
    # Guard against Submit being pressed with no audio; without this the
    # call fails inside the audio loader and the UI only shows a generic error.
    if not file:
        return "❗ กรุณาอัปโหลดไฟล์เสียงก่อนกด Submit"

    # The scaler expects a 2-D (samples, features) array.
    features = extract_mean_embedding(file).reshape(1, -1)
    scaled = scaler.transform(features)
    batch = torch.tensor(scaled, dtype=torch.float32).to(device)

    with torch.no_grad():
        logits = model_cls(batch)
    pred = logits.argmax(dim=1).item()

    return "🌵 ขาดน้ำ" if pred == 0 else "💧 มีน้ำเพียงพอ"
41
+
42
# Build the Gradio page: a heading, an audio upload, a result box, and
# Submit / Clear buttons wired to the classifier.
with gr.Blocks() as interface:
    gr.Markdown("## 🌱 Plant Sound Classifier (Fine-tuned)")
    gr.Markdown("อัปโหลดเสียงพืชเพื่อทำนายสถานะ: ขาดน้ำ หรือ มีน้ำเพียงพอ")

    # The audio component hands the handler a file path, not raw samples.
    audio_input = gr.Audio(type="filepath", label="🎧 อัปโหลดเสียงพืช (.wav)")
    output_text = gr.Textbox(label="📋 ผลการทำนาย", lines=2)

    with gr.Row():
        submit_btn = gr.Button("Submit")
        clear_btn = gr.Button("Clear")

    # Submit: run the classifier on the uploaded file and show its label.
    submit_btn.click(fn=predict_water_status, inputs=audio_input, outputs=output_text)

    # Clear: reset the audio input to None and the result box to an empty string.
    clear_btn.click(fn=lambda: (None, ""), inputs=[], outputs=[audio_input, output_text])

# Start the web server for the page.
interface.launch()