#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2020 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
""" TCN for lipreading"""
import os
import time
import random
import argparse  # module for parsing command-line arguments
import numpy as np
from tqdm import tqdm  # progress-bar library
import torch  # PyTorch
import torch.nn as nn  # module classes; store state (weights) via attributes
import torch.nn.functional as F  # stateless functions; usable without instantiating a module
from lipreading.utils import get_save_folder
from lipreading.utils import load_json, save2npz
from lipreading.utils import load_model, CheckpointSaver
from lipreading.utils import get_logger, update_logger_batch
from lipreading.utils import showLR, calculateNorm2, AverageMeter
from lipreading.model import Lipreading
from lipreading.mixup import mixup_data, mixup_criterion
from lipreading.optim_utils import get_optimizer, CosineScheduler
from lipreading.dataloaders import get_data_loaders, get_preprocessing_pipelines
from pathlib import Path
import wandb  # experiment-tracking tool (auto-logs loss and accuracy)
# Parse command-line arguments.
def load_args(default_config=None):
    # Create the argument-parser instance.
    parser = argparse.ArgumentParser(description='PyTorch Lipreading')
    # List of accepted arguments:
# -- dataset config
parser.add_argument('--dataset', default='lrw', help='dataset selection')
parser.add_argument('--num-classes', type=int, default=30, help='Number of classes')
parser.add_argument('--modality', default='video', choices=['video', 'raw_audio'], help='choose the modality')
# -- directory
parser.add_argument('--data-dir', default='./datasets/visual', help='Loaded data directory')
parser.add_argument('--label-path', type=str, default='./labels/30VietnameseSort.txt', help='Path to txt file with labels')
    parser.add_argument('--annonation-direc', default=None, help='Annotation directory')
# -- model config
parser.add_argument('--backbone-type', type=str, default='resnet', choices=['resnet', 'shufflenet'], help='Architecture used for backbone')
    parser.add_argument('--relu-type', type=str, default='relu', choices=['relu', 'prelu'], help='ReLU variant to use')
parser.add_argument('--width-mult', type=float, default=1.0, help='Width multiplier for mobilenets and shufflenets')
# -- TCN config
parser.add_argument('--tcn-kernel-size', type=int, nargs="+", help='Kernel to be used for the TCN module')
parser.add_argument('--tcn-num-layers', type=int, default=4, help='Number of layers on the TCN module')
parser.add_argument('--tcn-dropout', type=float, default=0.2, help='Dropout value for the TCN module')
    parser.add_argument('--tcn-dwpw', default=False, action='store_true', help='If True, use depthwise-separable convolutions in the TCN architecture')
parser.add_argument('--tcn-width-mult', type=int, default=1, help='TCN width multiplier')
# -- train
parser.add_argument('--training-mode', default='tcn', help='tcn')
    parser.add_argument('--batch-size', type=int, default=8, help='Mini-batch size')  # changed from default=32 to 8 (to avoid OOM)
    parser.add_argument('--optimizer', type=str, default='adamw', choices=['adam', 'sgd', 'adamw'])
    parser.add_argument('--lr', default=3e-4, type=float, help='initial learning rate')
    parser.add_argument('--init-epoch', default=0, type=int, help='epoch to start at')
    parser.add_argument('--epochs', default=100, type=int, help='number of epochs')  # changed from default=80 to 10 for testing (currently 100)
    parser.add_argument('--test', default=False, action='store_true', help='evaluation mode: run on the test partition and exit')
parser.add_argument('--save-dir', type=Path, default=Path('/kaggle/working/result/'))
# -- mixup
parser.add_argument('--alpha', default=0.4, type=float, help='interpolation strength (uniform=1., ERM=0.)')
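    # Mixup draws lam ~ Beta(alpha, alpha) and blends sample pairs; assuming mixup_data /
    # mixup_criterion implement the standard formulation:
    #   x = lam * x_a + (1 - lam) * x_b
    #   loss = lam * CE(logits, y_a) + (1 - lam) * CE(logits, y_b)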
# -- test
parser.add_argument('--model-path', type=str, default=None, help='Pretrained model pathname')
    parser.add_argument('--allow-size-mismatch', default=False, action='store_true',
                        help='If True, allow initializing from a model with mismatched weight tensors. Useful when initializing from a model with a different number of classes')
# -- feature extractor
parser.add_argument('--extract-feats', default=False, action='store_true', help='Feature extractor')
parser.add_argument('--mouth-patch-path', type=str, default=None, help='Path to the mouth ROIs, assuming the file is saved as numpy.array')
    parser.add_argument('--mouth-embedding-out-path', type=str, default=None, help='Save mouth embeddings to a specified path')
# -- json pathname
parser.add_argument('--config-path', type=str, default=None, help='Model configuration with json format')
# -- other vars
parser.add_argument('--interval', default=50, type=int, help='display interval')
    parser.add_argument('--workers', default=2, type=int, help='number of data loading workers')  # changed from default=8 to 2 (half of the 4 GCP cores)
    # paths
    parser.add_argument('--logging-dir', type=str, default='/kaggle/working/train_logs', help='path to the directory in which to save the log file')
    # Parse the inputs into an argparse.Namespace.
    args = parser.parse_args()
    return args
args = load_args()  # parse and load args
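# For reference, a typical test-time invocation might look like this (the file names here
# are illustrative, not the repo's actual defaults):
#   python main.py --config-path configs/model.json --model-path models/model.tar --test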
# Fix random seeds so experiments are reproducible.
torch.manual_seed(1)  # fix the random seed in PyTorch, the main framework
np.random.seed(1)  # fix the random seed in NumPy
random.seed(1)  # fix the random seed in Python's random library
# Note: exact reproducibility also requires torch.backends.cudnn.deterministic = True and torch.backends.cudnn.benchmark = False.
torch.backends.cudnn.benchmark = True  # enable the built-in cuDNN auto-tuner, which picks the fastest algorithms for the current hardware (per tensor shape / conv configuration)
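# If CUDA-side randomness also needed pinning, one could additionally call, e.g.:
#   torch.cuda.manual_seed_all(1)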
# Feature extraction
def extract_feats(model):
    """
    :rtype: FloatTensor
    """
    model.eval()  # automatically switches off layers (e.g. dropout) that must not run during evaluation
    preprocessing_func = get_preprocessing_pipelines()['test']  # test-time preprocessing
mouth_patch_path = args.mouth_patch_path.replace('.','')
dir_name = os.path.dirname(os.path.abspath(__file__))
dir_name = dir_name + mouth_patch_path
data_paths = [os.path.join(pth, f) for pth, dirs, files in os.walk(dir_name) for f in files]
npz_files = np.load(data_paths[0])['data']
data = preprocessing_func(npz_files) # data: TxHxW
# data = preprocessing_func(np.load(args.mouth_patch_path)['data']) # data: TxHxW
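    # [None, None] prepends batch and channel dims: (T, H, W) -> (1, 1, T, H, W)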
return data_paths[0], model(torch.FloatTensor(data)[None, None, :, :, :].cuda(), lengths=[data.shape[0]])
# return model(torch.FloatTensor(data)[None, None, :, :, :].cuda(), lengths=[data.shape[0]])
# Evaluation
def evaluate(model, dset_loader, criterion, is_print=False):
    model.eval()  # automatically switches off layers (e.g. dropout) that must not run during evaluation
    # running_loss = 0.
    # running_corrects = 0.
    prediction = ''
    # model.eval() and torch.no_grad() are normally used together for evaluation/validation.
with torch.no_grad():
inferences = []
for batch_idx, (input, lengths, labels) in enumerate(tqdm(dset_loader)):
# ๋ชจ๋ธ ์ƒ์„ฑ
# input ํ…์„œ์˜ ์ฐจ์›์„ ํ•˜๋‚˜ ๋” ๋Š˜๋ฆฌ๊ณ  gpu ์— ํ• ๋‹น
logits = model(input.unsqueeze(1).cuda(), lengths=lengths)
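            # logits: raw (batch, num_classes) class scores from the model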
            # _, preds = torch.max(F.softmax(logits, dim=1).data, dim=1)  # per-sample argmax after softmax
            # running_corrects += preds.eq(labels.cuda().view_as(preds)).sum().item()  # accuracy count
            # loss = criterion(logits, labels.cuda())  # compute the loss
            # running_loss += loss.item() * input.size(0)  # loss.item(): the scalar value held by loss
            # ------------ Print prediction and confidence ------------
probs = torch.nn.functional.softmax(logits, dim=-1)
probs = probs[0].detach().cpu().numpy()
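            # Map the argmax index back to a class name; the label file is assumed to hold one class per line.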
label_path = args.label_path
with Path(label_path).open() as fp:
vocab = fp.readlines()
top = np.argmax(probs)
prediction = vocab[top].strip()
# confidence = np.round(probs[top], 3)
# inferences.append({
# 'prediction': prediction,
# 'confidence': confidence
# })
with open("/home/user/app/result/ho.txt", 'w') as f:
f.writelines(prediction)
if is_print:
print()
print(f'Prediction: {prediction}')
# print(f'Confidence: {confidence}')
print()
return prediction
    # ------------ Save prediction & confidence to a text file ------------
    # txt_save_path = str(args.save_dir) + f'/predict.txt'
    # # If the target directory does not exist yet
    # if not os.path.exists(os.path.dirname(txt_save_path)):
    #     os.makedirs(os.path.dirname(txt_save_path))  # create the directory
    # with open(txt_save_path, 'w') as f:
    #     for inference in inferences:
    #         prediction = inference['prediction']
    #         confidence = inference['confidence']
    #         f.writelines(f'Prediction: {prediction}, Confidence: {confidence}\n')
    # print('Test Dataset {} In Total \t CR: {}'.format(len(dset_loader.dataset), running_corrects/len(dset_loader.dataset)))  # print dataset size and accuracy
    # return running_corrects/len(dset_loader.dataset), running_loss/len(dset_loader.dataset), inferences  # return accuracy, loss, inferences
# ๋ชจ๋ธ ํ•™์Šต
# def train(wandb, model, dset_loader, criterion, epoch, optimizer, logger):
# data_time = AverageMeter() # ํ‰๊ท , ํ˜„์žฌ๊ฐ’ ์ €์žฅ
# batch_time = AverageMeter() # ํ‰๊ท , ํ˜„์žฌ๊ฐ’ ์ €์žฅ
# lr = showLR(optimizer) # LR ๋ณ€ํ™”๊ฐ’
# # ๋กœ๊ฑฐ INFO ์ž‘์„ฑ
# logger.info('-' * 10)
# logger.info('Epoch {}/{}'.format(epoch, args.epochs - 1)) # epoch ์ž‘์„ฑ
# logger.info('Current learning rate: {}'.format(lr)) # learning rate ์ž‘์„ฑ
# model.train() # train mode
# running_loss = 0.
# running_corrects = 0.
# running_all = 0.
# end = time.time() # ํ˜„์žฌ ์‹œ๊ฐ
# for batch_idx, (input, lengths, labels) in enumerate(dset_loader):
# # measure data loading time
# data_time.update(time.time() - end) # ํ‰๊ท , ํ˜„์žฌ๊ฐ’ ์—…๋ฐ์ดํŠธ
# # --
# # mixup augmentation ๊ณ„์‚ฐ
# input, labels_a, labels_b, lam = mixup_data(input, labels, args.alpha)
# labels_a, labels_b = labels_a.cuda(), labels_b.cuda() # tensor ๋ฅผ gpu ์— ํ• ๋‹น
# # Pytorch์—์„œ๋Š” gradients๊ฐ’๋“ค์„ ์ถ”ํ›„์— backward๋ฅผ ํ•ด์ค„๋•Œ ๊ณ„์† ๋”ํ•ด์ฃผ๊ธฐ ๋•Œ๋ฌธ
# optimizer.zero_grad() # ํ•ญ์ƒ backpropagation์„ ํ•˜๊ธฐ์ „์— gradients๋ฅผ zero๋กœ ๋งŒ๋“ค์–ด์ฃผ๊ณ  ์‹œ์ž‘์„ ํ•ด์•ผ ํ•จ
# # ๋ชจ๋ธ ์ƒ์„ฑ
# # input ํ…์„œ์˜ ์ฐจ์›์„ ํ•˜๋‚˜ ๋” ๋Š˜๋ฆฌ๊ณ  gpu ์— ํ• ๋‹น
# logits = model(input.unsqueeze(1).cuda(), lengths=lengths)
# loss_func = mixup_criterion(labels_a, labels_b, lam) # mixup ์ ์šฉ
# loss = loss_func(criterion, logits) # loss ๊ณ„์‚ฐ
# loss.backward() # gradient ๊ณ„์‚ฐ
# optimizer.step() # ์ €์žฅ๋œ gradient ๊ฐ’์„ ์ด์šฉํ•˜์—ฌ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์—…๋ฐ์ดํŠธ
# # measure elapsed time # ๊ฒฝ๊ณผ ์‹œ๊ฐ„ ์ธก์ •
# batch_time.update(time.time() - end) # ํ‰๊ท , ํ˜„์žฌ๊ฐ’ ์—…๋ฐ์ดํŠธ
# end = time.time() # ํ˜„์žฌ ์‹œ๊ฐ
# # -- compute running performance # ์ปดํ“จํŒ… ์‹คํ–‰ ์„ฑ๋Šฅ
# _, predicted = torch.max(F.softmax(logits, dim=1).data, dim=1) # softmax ์ ์šฉ ํ›„ ๊ฐ ์›์†Œ ์ค‘ ์ตœ๋Œ€๊ฐ’ ๊ฐ€์ ธ์˜ค๊ธฐ
# running_loss += loss.item()*input.size(0) # loss.item(): loss ๊ฐ€ ๊ฐ–๊ณ  ์žˆ๋Š” scalar ๊ฐ’
# running_corrects += lam * predicted.eq(labels_a.view_as(predicted)).sum().item() + (1 - lam) * predicted.eq(labels_b.view_as(predicted)).sum().item() # ์ •ํ™•๋„ ๊ณ„์‚ฐ
# running_all += input.size(0)
# # ------------------ wandb ๋กœ๊ทธ ์ž…๋ ฅ ------------------
# wandb.log({'loss': running_loss, 'acc': running_corrects}, step=epoch)
# # -- log intermediate results # ์ค‘๊ฐ„ ๊ฒฐ๊ณผ ๊ธฐ๋ก
# if batch_idx % args.interval == 0 or (batch_idx == len(dset_loader)-1):
# # ๋กœ๊ฑฐ INFO ์ž‘์„ฑ
# update_logger_batch( args, logger, dset_loader, batch_idx, running_loss, running_corrects, running_all, batch_time, data_time )
# return model # ๋ชจ๋ธ ๋ฐ˜ํ™˜
# Build the model from the JSON config file.
def get_model_from_json():
    # Check that the JSON config file exists; raise an AssertionError otherwise.
    assert args.config_path.endswith('.json') and os.path.isfile(args.config_path), \
        "'.json' config path does not exist. Path input: {}".format(args.config_path)  # guard the precondition on the config path
    args_loaded = load_json(args.config_path)  # read the JSON
    args.backbone_type = args_loaded['backbone_type']  # backbone_type from the JSON
    args.width_mult = args_loaded['width_mult']  # width_mult from the JSON
    args.relu_type = args_loaded['relu_type']  # relu_type from the JSON
    # TCN options
tcn_options = { 'num_layers': args_loaded['tcn_num_layers'],
'kernel_size': args_loaded['tcn_kernel_size'],
'dropout': args_loaded['tcn_dropout'],
'dwpw': args_loaded['tcn_dwpw'],
'width_mult': args_loaded['tcn_width_mult'],
}
    # Instantiate the lipreading model
model = Lipreading( modality=args.modality,
num_classes=args.num_classes,
tcn_options=tcn_options,
backbone_type=args.backbone_type,
relu_type=args.relu_type,
width_mult=args.width_mult,
extract_feats=args.extract_feats).cuda()
    calculateNorm2(model)  # sanity check that training progresses: the parameter L2 norm should generally grow during training
    return model
# Main entry point.
def main():
    # wandb setup (disabled)
    # wandb.init(project="Lipreading_using_TCN_running")
# wandb.config = {
# "learning_rate": args.lr,
# "epochs": args.epochs,
# "batch_size": args.batch_size
# }
# os.environ['CUDA_LAUNCH_BLOCKING']="1"
# os.environ["CUDA_VISIBLE_DEVICES"]="0" # GPU ์„ ํƒ ์ฝ”๋“œ ์ถ”๊ฐ€
    # -- logging
    save_path = get_save_folder(args)  # save directory
    print("Model and log being saved in: {}".format(save_path))  # print the save path
    logger = get_logger(args, save_path)  # create and configure the logger
    ckpt_saver = CheckpointSaver(save_path)  # checkpoint-saving setup
# -- get model
model = get_model_from_json()
# -- get dataset iterators
dset_loaders = get_data_loaders(args)
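    # dset_loaders is expected to be a dict of DataLoaders keyed by 'train' / 'val' / 'test'.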
# -- get loss function
criterion = nn.CrossEntropyLoss()
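    # CrossEntropyLoss applies log-softmax internally, so the model outputs raw logits.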
# -- get optimizer
optimizer = get_optimizer(args, optim_policies=model.parameters())
# -- get learning rate scheduler
    scheduler = CosineScheduler(args.lr, args.epochs)  # cosine learning-rate scheduler
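    # Assuming the usual cosine-annealing schedule, the epoch-t LR would be lr * 0.5 * (1 + cos(pi * t / T)).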
    if args.model_path:
        # Check that the .tar checkpoint exists; raise an AssertionError otherwise.
        assert args.model_path.endswith('.tar') and os.path.isfile(args.model_path), \
            "'.tar' model path does not exist. Path input: {}".format(args.model_path)  # guard the precondition on the model path
        # resume from checkpoint
        if args.init_epoch > 0:
            model, optimizer, epoch_idx, ckpt_dict = load_model(args.model_path, model, optimizer)  # load the model
            args.init_epoch = epoch_idx  # set the starting epoch
            ckpt_saver.set_best_from_ckpt(ckpt_dict)  # restore the best-checkpoint state
            logger.info('Model and states have been successfully loaded from {}'.format(args.model_path))  # logger INFO entry
        # init from trained model
        else:
            model = load_model(args.model_path, model, allow_size_mismatch=args.allow_size_mismatch)  # load the model
            logger.info('Model has been successfully loaded from {}'.format(args.model_path))  # logger INFO entry
    # feature extraction
    if args.mouth_patch_path:
        filename, embeddings = extract_feats(model)
        filename = filename.split('/')[-1]
        save_npz_path = os.path.join(args.mouth_embedding_out_path, filename)
        # NOTE: the embedding-extraction path still needs code changes!
        save2npz(save_npz_path, data=embeddings.cpu().detach().numpy())  # save as an .npz file
        # save2npz(args.mouth_embedding_out_path, data=extract_feats(model).cpu().detach().numpy())  # save as an .npz file
        return
# if test-time, performance on test partition and exit. Otherwise, performance on validation and continue (sanity check for reload)
    if args.test:
        predicthi = evaluate(model, dset_loaders['test'], criterion, is_print=False)  # evaluate the model
        # logging_sentence = 'Test-time performance on partition {}: Loss: {:.4f}\tAcc:{:.4f}'.format('test', loss_avg_test, acc_avg_test)
        # logger.info(logging_sentence)  # logger INFO entry
        return predicthi
    # -- fix the learning rate after loading the checkpoint (latency)
    if args.model_path and args.init_epoch > 0:
        scheduler.adjust_lr(optimizer, args.init_epoch - 1)  # update the learning rate
    epoch = args.init_epoch  # initialize the epoch counter
    while epoch < args.epochs:
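        # NOTE: this loop is stale as committed: train() is commented out above, and
        # evaluate() now returns a single prediction string rather than (acc, loss, inferences).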
        model = train(wandb, model, dset_loaders['train'], criterion, epoch, optimizer, logger)  # train the model
        acc_avg_val, loss_avg_val, inferences = evaluate(model, dset_loaders['val'], criterion)  # evaluate the model
        logger.info('{} Epoch:\t{:2}\tLoss val: {:.4f}\tAcc val:{:.4f}, LR: {}'.format('val', epoch, loss_avg_val, acc_avg_val, showLR(optimizer)))  # logger INFO entry
        # -- save checkpoint state
        save_dict = {
            'epoch_idx': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }
        ckpt_saver.save(save_dict, acc_avg_val)  # save the checkpoint
        scheduler.adjust_lr(optimizer, epoch)  # update the learning rate
        epoch += 1
    # -- evaluate the best-performing epoch on the test partition
    best_fp = os.path.join(ckpt_saver.save_dir, ckpt_saver.best_fn)  # path to the best checkpoint
    _ = load_model(best_fp, model)  # load the model
    acc_avg_test, loss_avg_test, inferences = evaluate(model, dset_loaders['test'], criterion)  # evaluate the model
    logger.info('Test time performance of best epoch: {} (loss: {})'.format(acc_avg_test, loss_avg_test))  # logger INFO entry
    torch.cuda.empty_cache()  # free cached GPU memory
# Run the code below only when this script is executed directly by the interpreter,
# not when the module is imported.
# => This is the first thing invoked when main.py is run.
if __name__ == '__main__':
    main()  # call main()