"""Train a SegNet semantic-segmentation model on the YCB-Video dataset.

Runs a standard train/eval loop: one logger per epoch, periodic checkpoint
of the current model every 1000 training batches, and a "best" checkpoint
whenever the epoch's average test cross-entropy loss improves.
"""
import os
import copy
import random
import argparse
import time

import numpy as np
from PIL import Image
import scipy.io as scio
import scipy.misc
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.autograd import Variable
import torch.optim as optim
import torch.nn as nn
from torch.backends import cudnn

from data_controller import SegDataset
from loss import Loss
from segnet import SegNet as segnet

import sys
sys.path.append("..")
from lib.utils import setup_logger

parser = argparse.ArgumentParser()
parser.add_argument('--dataset_root', default='/home/data1/jeremy/YCB_Video_Dataset',
                    help="dataset root dir (''YCB_Video Dataset'')")
# type= added so values supplied on the command line are parsed as numbers
# (without it, any user-passed value arrives as str and breaks DataLoader,
# range() and Adam below).
parser.add_argument('--batch_size', type=int, default=3, help="batch size")
parser.add_argument('--n_epochs', type=int, default=600, help="epochs to train")
parser.add_argument('--workers', type=int, default=10, help='number of data loading workers')
parser.add_argument('--lr', type=float, default=0.0001, help="learning rate")
parser.add_argument('--logs_path', default='logs/', help="path to save logs")
parser.add_argument('--model_save_path', default='trained_models/', help="path to save models")
parser.add_argument('--log_dir', default='logs/', help="path to save logs")
parser.add_argument('--resume_model', default='', help="resume model name")
opt = parser.parse_args()


if __name__ == '__main__':
    # Seed all RNGs with a fresh random seed (logged implicitly via opt).
    opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    # 5000 train / 1000 test samples per epoch (third arg toggles augmentation
    # in SegDataset — presumably; confirm against data_controller).
    dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/train_data_list.txt', True, 5000)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size,
                                             shuffle=True, num_workers=int(opt.workers))
    test_dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/test_data_list.txt', False, 1000)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
                                                  shuffle=True, num_workers=int(opt.workers))

    print(len(dataset), len(test_dataset))

    model = segnet()
    model = model.cuda()

    if opt.resume_model != '':
        checkpoint = torch.load('{0}/{1}'.format(opt.model_save_path, opt.resume_model))
        model.load_state_dict(checkpoint)

    # Ensure output dirs exist before listing/writing into them
    # (os.listdir / torch.save would otherwise raise on a fresh checkout).
    os.makedirs(opt.log_dir, exist_ok=True)
    os.makedirs(opt.model_save_path, exist_ok=True)

    # Start each run with a clean log directory.
    for log in os.listdir(opt.log_dir):
        os.remove(os.path.join(opt.log_dir, log))

    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
    criterion = Loss()
    best_val_cost = np.inf  # np.Inf was removed in NumPy 2.0
    st_time = time.time()

    # range end is n_epochs + 1 so exactly n_epochs epochs run
    # (the original range(1, n_epochs) trained one epoch too few).
    for epoch in range(1, opt.n_epochs + 1):
        # ---- training ----
        model.train()
        train_all_cost = 0.0
        train_time = 0
        logger = setup_logger('epoch%d' % epoch, os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
        logger.info('Train time {0}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Training started'))

        for i, data in enumerate(dataloader, 0):
            rgb, target = data
            rgb, target = rgb.cuda(), target.cuda()
            semantic = model(rgb)
            optimizer.zero_grad()
            semantic_loss = criterion(semantic, target)
            train_all_cost += semantic_loss.item()
            semantic_loss.backward()
            optimizer.step()
            logger.info('Train time {0} Batch {1} CEloss {2}'.format(
                time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)),
                train_time, semantic_loss.item()))
            # Periodic safety checkpoint of the in-progress model.
            if train_time != 0 and train_time % 1000 == 0:
                torch.save(model.state_dict(), os.path.join(opt.model_save_path, 'model_current.pth'))
            train_time += 1

        # max(..., 1) guards against ZeroDivisionError on an empty dataloader.
        train_all_cost = train_all_cost / max(train_time, 1)
        logger.info('Train Finish Avg CEloss: {0}'.format(train_all_cost))

        # ---- evaluation ----
        model.eval()
        test_all_cost = 0.0
        test_time = 0
        logger = setup_logger('epoch%d_test' % epoch,
                              os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
        logger.info('Test time {0}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Testing started'))

        # no_grad: evaluation must not build autograd graphs
        # (the original accumulated graphs for every test batch).
        with torch.no_grad():
            for j, data in enumerate(test_dataloader, 0):
                rgb, target = data
                rgb, target = rgb.cuda(), target.cuda()
                semantic = model(rgb)
                semantic_loss = criterion(semantic, target)
                test_all_cost += semantic_loss.item()
                test_time += 1
                logger.info('Test time {0} Batch {1} CEloss {2}'.format(
                    time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)),
                    test_time, semantic_loss.item()))

        test_all_cost = test_all_cost / max(test_time, 1)
        logger.info('Test Finish Avg CEloss: {0}'.format(test_all_cost))

        # Keep a checkpoint whenever the validation loss improves.
        if test_all_cost <= best_val_cost:
            best_val_cost = test_all_cost
            torch.save(model.state_dict(),
                       os.path.join(opt.model_save_path, 'model_{}_{}.pth'.format(epoch, test_all_cost)))
            print('----------->BEST SAVED<-----------')