Nicolas Denier committed
Commit 360633d · 1 Parent(s): baff0a5

fix accuracy compute

Files changed (1)
  1. tasks/audio.py +8 -8
tasks/audio.py CHANGED
@@ -1,7 +1,7 @@
 from fastapi import APIRouter
 from datetime import datetime
 from datasets import load_dataset
-from sklearn.metrics import accuracy_score
+#from sklearn.metrics import accuracy_score
 import os
 import torch
 
@@ -48,7 +48,8 @@ async def evaluate_audio(request: AudioEvaluationRequest):
     model = torch.compile(model)
     model.load_state_dict(torch.load('tasks/models/final-bf16.pth', weights_only=True))
     model.eval()
-
+    num_correct = 0
+    num_samples = len(test_dataset)
     # Start tracking emissions
     tracker.start()
     tracker.start_task("inference")
@@ -59,14 +60,14 @@ async def evaluate_audio(request: AudioEvaluationRequest):
     #--------------------------------------------------------------------------------------------
 
     predictions = []
-    with torch.no_grad():#, torch.amp.autocast(device_type=device):
+    with torch.no_grad():
         for (X, y) in dataloader:
             X = X.to(device, dtype=torch.bfloat16)
             y = y.to(device, dtype=torch.bfloat16)
 
-            predictions.append(model(X))
-    predictions = torch.cat(predictions, dim=0)
-
+            predictions = model(X)
+            num_correct += (y==predictions).sum() # count correct predictions
+
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
     #--------------------------------------------------------------------------------------------
@@ -75,8 +76,7 @@ async def evaluate_audio(request: AudioEvaluationRequest):
     emissions_data = tracker.stop_task()
 
     # Calculate accuracy
-    true_labels = test_dataset["label"]
-    accuracy = accuracy_score(true_labels, predictions)
+    accuracy = float(num_correct) / float(num_samples)
 
     # Prepare results dictionary
     results = {
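
For context on the fix itself: the old code appears to have passed raw model outputs (a concatenated bfloat16 tensor, still on the device) straight into sklearn's accuracy_score, while the new code keeps a running count of correct predictions per batch and divides by the dataset size at the end. Below is a minimal, self-contained sketch of that counting pattern. The dummy model, shapes, and data are stand-ins, not from this repo, and it assumes the forward pass returns hard class labels, which is what makes the element-wise == comparison meaningful.

# Minimal sketch of the counting-based accuracy pattern used in this commit.
# Model, shapes, and data are illustrative stand-ins, not the repo's.
import torch
from torch.utils.data import DataLoader, TensorDataset

class DummyClassifier(torch.nn.Module):
    """Stand-in model that emits a hard 0/1 prediction per sample."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 1)

    def forward(self, x):
        # Threshold the logit so the output is a discrete label,
        # directly comparable to y with ==.
        return (self.linear(x).squeeze(-1) > 0).to(x.dtype)

dataset = TensorDataset(torch.randn(128, 16), torch.randint(0, 2, (128,)).float())
dataloader = DataLoader(dataset, batch_size=32)

model = DummyClassifier()
model.eval()

num_correct = 0
num_samples = len(dataset)

with torch.no_grad():                            # no autograd bookkeeping during inference
    for X, y in dataloader:
        predictions = model(X)                   # hard labels, same shape as y
        num_correct += (y == predictions).sum()  # accumulate correct count per batch

# .sum() returns a 0-d tensor, hence the float() casts before dividing
accuracy = float(num_correct) / float(num_samples)
print(f"accuracy: {accuracy:.3f}")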
 
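As a sanity check on the replacement metric: once predictions are discrete, the running count reduces to exactly what the commented-out sklearn call computed, namely the mean of element-wise equality. A tiny illustrative comparison (values invented):

import torch
from sklearn.metrics import accuracy_score

y_true = torch.tensor([0., 1., 1., 0., 1.])
y_pred = torch.tensor([0., 1., 0., 0., 1.])

manual = float((y_true == y_pred).sum()) / float(len(y_true))  # 4 of 5 correct
sklearn_acc = accuracy_score(y_true.numpy(), y_pred.numpy())   # same quantity
assert manual == sklearn_acc == 0.8

One caveat worth hedging on: (y == predictions).sum() only counts a sample as correct when the model emits exactly the label value, so this accounting is only valid if model(X) yields discrete labels rather than logits or probabilities; with raw scores you would threshold or argmax before comparing.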