NonSittinon commited on
Commit
19da833
·
verified ·
1 Parent(s): 9e63030

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -2
app.py CHANGED
@@ -8,6 +8,20 @@ from sklearn.preprocessing import StandardScaler
8
  import gradio as gr
9
  import joblib
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # 👇 โหลด processor และโมเดล HuBERT
12
  processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
13
  model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
@@ -22,8 +36,8 @@ scaler = joblib.load("scaler.joblib")
22
  # ✅ ให้ torch รู้จัก class (safe!)
23
  torch.serialization.add_safe_globals({"__main__.PlantSoundClassifier": PlantSoundClassifier})
24
 
25
- # ✅ โหลด state_dict ของโมเดลอย่างปลอดภัย
26
- model_cls_state_dict = torch.load("classifier.pth", map_location=device, weights_only=True)
27
 
28
  # ✅ สร้าง instance classifier แล้วโหลด state_dict
29
  model_cls = PlantSoundClassifier().to(device)
 
8
  import gradio as gr
9
  import joblib
10
 
11
+ # ✅ ประกาศ PlantSoundClassifier (เพิ่มตรงนี้)
12
+ class PlantSoundClassifier(torch.nn.Module):
13
+ def __init__(self, input_dim=1024, hidden_dim=128, output_dim=2):
14
+ super().__init__()
15
+ self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
16
+ self.relu = torch.nn.ReLU()
17
+ self.fc2 = torch.nn.Linear(hidden_dim, output_dim)
18
+
19
+ def forward(self, x):
20
+ x = self.fc1(x)
21
+ x = self.relu(x)
22
+ x = self.fc2(x)
23
+ return x
24
+
25
  # 👇 โหลด processor และโมเดล HuBERT
26
  processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
27
  model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
 
36
  # ✅ ให้ torch รู้จัก class (safe!)
37
  torch.serialization.add_safe_globals({"__main__.PlantSoundClassifier": PlantSoundClassifier})
38
 
39
+ # ✅ โหลด state_dict ของโมเดล
40
+ model_cls_state_dict = torch.load("classifier.pth", map_location=device)
41
 
42
  # ✅ สร้าง instance classifier แล้วโหลด state_dict
43
  model_cls = PlantSoundClassifier().to(device)