import os

import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm

from .utils import get_lr


def log_rmse(outputs, labels, loss):
    with torch.no_grad():
        # Clamp predictions smaller than 1 to 1 so that taking the logarithm is numerically stable
        clipped_preds = torch.max(outputs, torch.tensor(1.0))
        rmse = torch.sqrt(2 * loss(clipped_preds.log(), labels.log()).mean())
    return rmse
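
# Illustrative note (assumption, not from the original file): with loss = nn.MSELoss(),
# loss(...) already returns a scalar, so the trailing .mean() above is a no-op and the
# factor of 2 only rescales the reported metric by sqrt(2). The clamp means that, for
# example, outputs = torch.tensor([0.5]) is treated as 1.0, so log() never sees a
# non-positive value.
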
def fit_one_epoch(model_train, model, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0):
    total_loss = 0
    total_rmse = 0
    val_loss   = 0
    val_rmse   = 0

    # Define the loss function
    loss = nn.MSELoss()

    if local_rank == 0:
        print('Start Train')
        pbar = tqdm(total=epoch_step, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3)
    model_train.train()
    for iteration, batch in enumerate(gen):
        if iteration >= epoch_step:
            break
        images, targets = batch
        with torch.no_grad():
            if cuda:
                images  = images.cuda(local_rank)
                targets = targets.cuda(local_rank)

        #----------------------#
        #   Zero the gradients
        #----------------------#
        optimizer.zero_grad()
        if not fp16:
            #----------------------#
            #   Forward pass
            #----------------------#
            outputs = model_train(images)
            #----------------------#
            #   Compute the loss
            #----------------------#
            loss_value = loss(outputs, targets)
            loss_value.backward()
            optimizer.step()
        else:
            from torch.cuda.amp import autocast
            with autocast():
                #----------------------#
                #   Forward pass
                #----------------------#
                outputs = model_train(images)
                #----------------------#
                #   Compute the loss
                #----------------------#
                loss_value = loss(outputs, targets)
            #----------------------#
            #   Backward pass with the gradient scaler
            #----------------------#
            scaler.scale(loss_value).backward()
            scaler.step(optimizer)
            scaler.update()

        total_loss += loss_value.item()
        # Compute the log root-mean-squared error
        with torch.no_grad():
            rmse = log_rmse(outputs, targets, loss)
        total_rmse += rmse.item()

        if local_rank == 0:
            pbar.set_postfix(**{'total_loss': total_loss / (iteration + 1),
                                'total_rmse': total_rmse / (iteration + 1),
                                'lr': get_lr(optimizer)})
            pbar.update(1)

    if local_rank == 0:
        pbar.close()
        print('Finish Train')
        print('Start Validation')
        pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3)

    model_train.eval()
    for iteration, batch in enumerate(gen_val):
        if iteration >= epoch_step_val:
            break
        images, targets = batch
        with torch.no_grad():
            if cuda:
                images  = images.cuda(local_rank)
                targets = targets.cuda(local_rank)

            optimizer.zero_grad()
            outputs    = model_train(images)
            loss_value = loss(outputs, targets)
            val_loss  += loss_value.item()

            rmse      = log_rmse(outputs, targets, loss)
            val_rmse += rmse.item()

        if local_rank == 0:
            pbar.set_postfix(**{'total_loss': val_loss / (iteration + 1),
                                'total_rmse': val_rmse / (iteration + 1),
                                'lr': get_lr(optimizer)})
            pbar.update(1)

    if local_rank == 0:
        pbar.close()
        print('Finish Validation')

        loss_history.append_loss(epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)
        print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
        print('Total Loss: %.3f || Val Loss: %.3f ' % (total_loss / epoch_step, val_loss / epoch_step_val))

        #-----------------------------------------------#
        #   Save the weights
        #-----------------------------------------------#
        if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
            torch.save(model.state_dict(), os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)))

        if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
            print('Save best model to best_epoch_weights.pth')
            torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth"))

        torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth"))
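
#-----------------------------------------------------------------------------------------#
#   Usage sketch (illustrative only, not part of this module): fit_one_epoch is expected
#   to be driven by an external training script. The names below (model, model_train,
#   train_loader, val_loader, loss_history, and the flags) are assumptions about that
#   script, not definitions provided here.
#
#       from torch.cuda.amp import GradScaler
#
#       scaler = GradScaler() if fp16 else None
#       for epoch in range(Epoch):
#           fit_one_epoch(model_train, model, loss_history, optimizer, epoch,
#                         epoch_step, epoch_step_val, train_loader, val_loader,
#                         Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank)
#-----------------------------------------------------------------------------------------#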