👽️ [Update] Switch to torch.cuda.amp; most torch versions do not support torch.amp yet
Browse files
- yolo/tools/solver.py +3 -3
yolo/tools/solver.py
CHANGED
@@ -11,7 +11,7 @@ import torch
|
|
11 |
from loguru import logger
|
12 |
from pycocotools.coco import COCO
|
13 |
from torch import Tensor, distributed
|
14 |
-
from torch.amp import GradScaler, autocast
|
15 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
16 |
from torch.utils.data import DataLoader
|
17 |
|
@@ -63,13 +63,13 @@ class ModelTrainer:
|
|
63 |
self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
|
64 |
else:
|
65 |
self.ema = None
|
66 |
-
self.scaler = GradScaler(
|
67 |
|
68 |
def train_one_batch(self, images: Tensor, targets: Tensor):
|
69 |
images, targets = images.to(self.device), targets.to(self.device)
|
70 |
self.optimizer.zero_grad()
|
71 |
|
72 |
-
with autocast(
|
73 |
predicts = self.model(images)
|
74 |
aux_predicts = self.vec2box(predicts["AUX"])
|
75 |
main_predicts = self.vec2box(predicts["Main"])
|
|
|
11 |
from loguru import logger
|
12 |
from pycocotools.coco import COCO
|
13 |
from torch import Tensor, distributed
|
14 |
+
from torch.cuda.amp import GradScaler, autocast
|
15 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
16 |
from torch.utils.data import DataLoader
|
17 |
|
|
|
63 |
self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
|
64 |
else:
|
65 |
self.ema = None
|
66 |
+
self.scaler = GradScaler()
|
67 |
|
68 |
def train_one_batch(self, images: Tensor, targets: Tensor):
|
69 |
images, targets = images.to(self.device), targets.to(self.device)
|
70 |
self.optimizer.zero_grad()
|
71 |
|
72 |
+
with autocast():
|
73 |
predicts = self.model(images)
|
74 |
aux_predicts = self.vec2box(predicts["AUX"])
|
75 |
main_predicts = self.vec2box(predicts["Main"])
|