henry000 commited on
Commit
360e5d7
·
1 Parent(s): 3b24cc8

👽️ [Update] cudaamp, most torch not support amp yet

Browse files
Files changed (1) hide show
  1. yolo/tools/solver.py +3 -3
yolo/tools/solver.py CHANGED
@@ -11,7 +11,7 @@ import torch
11
  from loguru import logger
12
  from pycocotools.coco import COCO
13
  from torch import Tensor, distributed
14
- from torch.amp import GradScaler, autocast
15
  from torch.nn.parallel import DistributedDataParallel as DDP
16
  from torch.utils.data import DataLoader
17
 
@@ -63,13 +63,13 @@ class ModelTrainer:
63
  self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
64
  else:
65
  self.ema = None
66
- self.scaler = GradScaler(str(self.device))
67
 
68
  def train_one_batch(self, images: Tensor, targets: Tensor):
69
  images, targets = images.to(self.device), targets.to(self.device)
70
  self.optimizer.zero_grad()
71
 
72
- with autocast(str(self.device)):
73
  predicts = self.model(images)
74
  aux_predicts = self.vec2box(predicts["AUX"])
75
  main_predicts = self.vec2box(predicts["Main"])
 
11
  from loguru import logger
12
  from pycocotools.coco import COCO
13
  from torch import Tensor, distributed
14
+ from torch.cuda.amp import GradScaler, autocast
15
  from torch.nn.parallel import DistributedDataParallel as DDP
16
  from torch.utils.data import DataLoader
17
 
 
63
  self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
64
  else:
65
  self.ema = None
66
+ self.scaler = GradScaler()
67
 
68
  def train_one_batch(self, images: Tensor, targets: Tensor):
69
  images, targets = images.to(self.device), targets.to(self.device)
70
  self.optimizer.zero_grad()
71
 
72
+ with autocast():
73
  predicts = self.model(images)
74
  aux_predicts = self.vec2box(predicts["AUX"])
75
  main_predicts = self.vec2box(predicts["Main"])