NonSittinon commited on
Commit
b96b8e8
·
verified ·
1 Parent(s): 7ccb591

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -9
app.py CHANGED
@@ -6,46 +6,63 @@ from transformers import Wav2Vec2Processor, HubertModel
6
  from sklearn.preprocessing import StandardScaler
7
 
8
  import gradio as gr
 
9
 
10
- # โหลด processor, model HuBERT และ fine-tuned classifier
11
  processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
12
  model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
13
  model.eval()
14
 
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
- # โหลดไฟล์ scaler และ model classifier
18
- import joblib
19
  scaler = joblib.load("scaler.joblib")
20
- from your_module import PlantSoundClassifier # นำเข้า class ของคุณเอง
21
 
 
 
 
 
22
  torch.serialization.add_safe_globals({"__main__.PlantSoundClassifier": PlantSoundClassifier})
23
 
24
- model_cls = torch.load("classifier.pth", map_location=device, weights_only=True)
 
 
 
 
 
25
  model_cls.eval()
26
 
27
  def extract_mean_embedding(wav_path):
28
  waveform, sample_rate = torchaudio.load(wav_path)
29
  waveform = waveform.squeeze()
 
30
  inputs = processor(waveform, sampling_rate=sample_rate, return_tensors="pt", padding=True)
31
  with torch.no_grad():
32
  outputs = model(**inputs)
 
33
  embedding = outputs.last_hidden_state
34
- mean_embedding = embedding.mean(dim=1).squeeze().numpy()
35
  return mean_embedding
36
 
37
  def predict_water_status(file):
38
  vec = extract_mean_embedding(file).reshape(1, -1)
39
  vec_scaled = scaler.transform(vec)
40
  vec_tensor = torch.tensor(vec_scaled, dtype=torch.float32).to(device)
 
41
  with torch.no_grad():
42
  outputs = model_cls(vec_tensor)
43
  pred = outputs.argmax(dim=1).item()
44
- return "🌵 ขาดน้ำ" if pred == 0 else "💧 มีน้ำเพียงพอ"
 
 
 
 
 
 
45
 
46
  with gr.Blocks() as interface:
47
- gr.Markdown("## 🌱 Plant Sound Classifier (Fine-tuned)")
48
- gr.Markdown("อัปโหลดเสียงพืชเพื่อทำนายสถานะ: ขาดน้ำ หรือ มีน้ำเพียงพอ")
49
 
50
  audio_input = gr.Audio(type="filepath", label="🎧 อัปโหลดเสียงพืช (.wav)")
51
  output_text = gr.Textbox(label="📋 ผลการทำนาย", lines=2)
 
6
  from sklearn.preprocessing import StandardScaler
7
 
8
  import gradio as gr
9
+ import joblib
10
 
11
+ # 👇 โหลด processor และโมเดล HuBERT
12
  processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
13
  model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
14
  model.eval()
15
 
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
18
+ # 👇 โหลด scaler
 
19
  scaler = joblib.load("scaler.joblib")
 
20
 
21
+ # 👇 โหลด classifier
22
+ from your_module import PlantSoundClassifier # แก้เป็น module/class ที่คุณสร้างเอง
23
+
24
+ # ✅ ให้ torch รู้จัก class (safe!)
25
  torch.serialization.add_safe_globals({"__main__.PlantSoundClassifier": PlantSoundClassifier})
26
 
27
+ # โหลด state_dict ของโมเดลอย่างปลอดภัย
28
+ model_cls_state_dict = torch.load("classifier.pth", map_location=device, weights_only=True)
29
+
30
+ # ✅ สร้าง instance classifier แล้วโหลด state_dict
31
+ model_cls = PlantSoundClassifier().to(device)
32
+ model_cls.load_state_dict(model_cls_state_dict)
33
  model_cls.eval()
34
 
35
  def extract_mean_embedding(wav_path):
36
  waveform, sample_rate = torchaudio.load(wav_path)
37
  waveform = waveform.squeeze()
38
+
39
  inputs = processor(waveform, sampling_rate=sample_rate, return_tensors="pt", padding=True)
40
  with torch.no_grad():
41
  outputs = model(**inputs)
42
+
43
  embedding = outputs.last_hidden_state
44
+ mean_embedding = embedding.mean(dim=1).squeeze().cpu().numpy()
45
  return mean_embedding
46
 
47
  def predict_water_status(file):
48
  vec = extract_mean_embedding(file).reshape(1, -1)
49
  vec_scaled = scaler.transform(vec)
50
  vec_tensor = torch.tensor(vec_scaled, dtype=torch.float32).to(device)
51
+
52
  with torch.no_grad():
53
  outputs = model_cls(vec_tensor)
54
  pred = outputs.argmax(dim=1).item()
55
+
56
+ if pred == 0:
57
+ return "🌵 ขาดน้ำ"
58
+ elif pred == 1:
59
+ return "💧 มีน้ำเพียงพอ"
60
+ else:
61
+ return "⚠️ ไม่ทราบสถานะ"
62
 
63
  with gr.Blocks() as interface:
64
+ gr.Markdown("## 🌱 Plant Sound Classifier (Fine-tuned HuBERT)")
65
+ gr.Markdown("อัปโหลดเสียงพืช (.wav) เพื่อทำนายสถานะ: ขาดน้ำ หรือ มีน้ำเพียงพอ")
66
 
67
  audio_input = gr.Audio(type="filepath", label="🎧 อัปโหลดเสียงพืช (.wav)")
68
  output_text = gr.Textbox(label="📋 ผลการทำนาย", lines=2)