Update tasks/image.py
Browse files- tasks/image.py +3 -3
tasks/image.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import torch
|
|
|
2 |
import torch.nn as nn
|
3 |
import torch.optim as optim
|
4 |
from torchvision import transforms
|
@@ -93,10 +94,9 @@ async def evaluate_image(request: ImageEvaluationRequest):
|
|
93 |
|
94 |
# Load and prepare the dataset
|
95 |
dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))
|
96 |
-
|
97 |
# Split dataset
|
98 |
-
|
99 |
-
test_dataset = dataset["val"]
|
100 |
|
101 |
# Start tracking emissions
|
102 |
tracker.start()
|
|
|
1 |
import torch
|
2 |
+
import os
|
3 |
import torch.nn as nn
|
4 |
import torch.optim as optim
|
5 |
from torchvision import transforms
|
|
|
94 |
|
95 |
# Load and prepare the dataset
|
96 |
dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))
|
97 |
+
|
98 |
# Split dataset
|
99 |
+
test_dataset = dataset["test"]
|
|
|
100 |
|
101 |
# Start tracking emissions
|
102 |
tracker.start()
|