import os
from datetime import datetime

import torch
from datasets import load_dataset
from dotenv import load_dotenv
from fastapi import APIRouter

from .models.model import ChainsawDetector
from .utils.emissions import tracker, clean_emissions_data, get_space_info
from .utils.evaluation import AudioEvaluationRequest
from .utils.preprocess import get_dataloader
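
# Load environment variables (HF_TOKEN in particular) from a local .env file.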
load_dotenv()

router = APIRouter()

DESCRIPTION = "ChainsawDetector"
ROUTE = "/audio"

@router.post(ROUTE, tags=["Audio Task"], description=DESCRIPTION)
async def evaluate_audio(request: AudioEvaluationRequest):
"""
Evaluate audio classification for rainforest sound detection.
Current Model: ChainsawDetector
- STFT -> PCEN -> split into small time chunks -> CNN+LSTM for each chunk -> dense -> prediction
"""
    # Get space info
    username, space_url = get_space_info()

    # Define the label mapping
    LABEL_MAPPING = {
        "chainsaw": 0,
        "environment": 1,
    }

    # Load and prepare the dataset.
    # Because the dataset is gated, the HF_TOKEN environment variable is needed to authenticate.
    batch_size = 16
    device = "cuda" if torch.cuda.is_available() else "cpu"
    split = "test"
    test_dataset = load_dataset(request.dataset_name, split=split, token=os.getenv("HF_TOKEN"))
    dataloader = get_dataloader(test_dataset, device, batch_size=batch_size, shuffle=False)
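    # NOTE: get_dataloader is assumed to yield (inputs, labels) batches over the
    # test split; shuffle=False keeps the evaluation order deterministic.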

    # Load model
    model = ChainsawDetector(batch_size).to(device, dtype=torch.bfloat16)
    model = torch.compile(model)
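    # NOTE: this assumes the checkpoint was saved from a torch.compile'd model, so
    # its state_dict keys carry the `_orig_mod.` prefix; compiling before
    # load_state_dict keeps the key names aligned. A checkpoint saved from the
    # uncompiled model would need the prefix stripped (or loading before compile).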
    model.load_state_dict(torch.load("tasks/models/final-bf16.pth", weights_only=True))
    model.eval()

    num_correct = 0
    num_samples = len(test_dataset)

    # Start tracking emissions
    tracker.start()
    tracker.start_task("inference")

    #--------------------------------------------------------------------------------------------
    # MODEL INFERENCE
    # The batched inference below runs inside the tracked task, so its energy
    # consumption and emissions are measured.
    #--------------------------------------------------------------------------------------------
    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device, dtype=torch.bfloat16)
            y = y.to(device, dtype=torch.bfloat16)
            preds = model(X)
            num_correct += (y == preds).sum().item()  # count correct predictions
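            # NOTE: the equality check assumes the model emits hard 0/1 labels
            # matching LABEL_MAPPING; probability or logit outputs would need a
            # threshold (e.g. preds > 0.5) before comparison.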
    #--------------------------------------------------------------------------------------------
    # MODEL INFERENCE ENDS HERE
    #--------------------------------------------------------------------------------------------

    # Stop tracking emissions
    emissions_data = tracker.stop_task()

    # Calculate accuracy
    accuracy = num_correct / num_samples
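
    # codecarbon reports energy_consumed in kWh and emissions in kg CO2eq, hence
    # the *1000 conversions to Wh and gCO2eq in the results below.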
    # Prepare results dictionary
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": accuracy,
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed,
        },
    }

    return results
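
# A minimal sketch of how this router might be mounted, assuming a standard
# FastAPI layout (the `tasks.audio` module path is illustrative, not confirmed
# by this file):
#
#     from fastapi import FastAPI
#     from tasks.audio import router as audio_router
#
#     app = FastAPI()
#     app.include_router(audio_router)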