# File size: 4,985 Bytes
# 29858c0
import os
import copy
import random
import argparse
import time
import numpy as np
from PIL import Image
import scipy.io as scio
import scipy.misc
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.autograd import Variable
import torch.optim as optim
import torch.nn as nn
from torch.backends import cudnn
from data_controller import SegDataset
from loss import Loss
from segnet import SegNet as segnet
import sys
sys.path.append("..")
from lib.utils import setup_logger
# Command-line configuration for the segmentation trainer.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_root', default='/home/data1/jeremy/YCB_Video_Dataset', help="dataset root dir (''YCB_Video Dataset'')")
# type=int / type=float so values supplied on the command line are parsed as
# numbers; without a converter argparse hands back str, which crashes
# DataLoader(batch_size=...), range(1, n_epochs) and optim.Adam(lr=...).
parser.add_argument('--batch_size', type=int, default=3, help="batch size")
parser.add_argument('--n_epochs', type=int, default=600, help="epochs to train")
parser.add_argument('--workers', type=int, default=10, help='number of data loading workers')
parser.add_argument('--lr', type=float, default=0.0001, help="learning rate")
parser.add_argument('--logs_path', default='logs/', help="path to save logs")
parser.add_argument('--model_save_path', default='trained_models/', help="path to save models")
# NOTE(review): --log_dir duplicates --logs_path; only --log_dir is read
# below. Kept for interface compatibility.
parser.add_argument('--log_dir', default='logs/', help="path to save logs")
parser.add_argument('--resume_model', default='', help="resume model name")
opt = parser.parse_args()
if __name__ == '__main__':
    # Seed all RNGs with a fresh random seed (recorded on opt for reference).
    opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    # Train/test datasets. SegDataset args: root dir, image-list file,
    # training flag, epoch length.
    dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/train_data_list.txt', True, 5000)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.workers))
    test_dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/test_data_list.txt', False, 1000)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=int(opt.workers))
    print(len(dataset), len(test_dataset))

    model = segnet()
    model = model.cuda()

    # Optionally resume from a saved state_dict under model_save_path.
    if opt.resume_model != '':
        checkpoint = torch.load('{0}/{1}'.format(opt.model_save_path, opt.resume_model))
        model.load_state_dict(checkpoint)
    # Fresh run: wipe logs left over from a previous run.
    for log in os.listdir(opt.log_dir):
        os.remove(os.path.join(opt.log_dir, log))

    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
    criterion = Loss()
    # np.Inf was removed in NumPy 2.0; float('inf') is the portable spelling.
    best_val_cost = float('inf')
    st_time = time.time()

    # NOTE(review): range(1, n_epochs) runs n_epochs - 1 epochs; kept as-is
    # to preserve existing behavior / checkpoint numbering.
    for epoch in range(1, opt.n_epochs):
        # ---- training pass ----
        model.train()
        train_all_cost = 0.0
        train_time = 0
        logger = setup_logger('epoch%d' % epoch, os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
        logger.info('Train time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Training started'))
        for i, data in enumerate(dataloader, 0):
            # Variable() wrappers are deprecated no-ops since PyTorch 0.4
            # (this file already uses the 0.4+ .item() API), so plain
            # tensors are moved to the GPU directly.
            rgb, target = data
            rgb, target = rgb.cuda(), target.cuda()
            semantic = model(rgb)
            optimizer.zero_grad()
            semantic_loss = criterion(semantic, target)
            train_all_cost += semantic_loss.item()
            semantic_loss.backward()
            optimizer.step()
            logger.info('Train time {0} Batch {1} CEloss {2}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), train_time, semantic_loss.item()))
            # Periodic checkpoint so a crash loses at most 1000 batches.
            if train_time != 0 and train_time % 1000 == 0:
                torch.save(model.state_dict(), os.path.join(opt.model_save_path, 'model_current.pth'))
            train_time += 1

        train_all_cost = train_all_cost / train_time
        logger.info('Train Finish Avg CEloss: {0}'.format(train_all_cost))

        # ---- evaluation pass ----
        model.eval()
        test_all_cost = 0.0
        test_time = 0
        logger = setup_logger('epoch%d_test' % epoch, os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
        logger.info('Test time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Testing started'))
        # no_grad(): evaluation needs no autograd graph — skipping it saves
        # GPU memory and time (the original built graphs it never used).
        with torch.no_grad():
            for j, data in enumerate(test_dataloader, 0):
                rgb, target = data
                rgb, target = rgb.cuda(), target.cuda()
                semantic = model(rgb)
                semantic_loss = criterion(semantic, target)
                test_all_cost += semantic_loss.item()
                test_time += 1
                logger.info('Test time {0} Batch {1} CEloss {2}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), test_time, semantic_loss.item()))

        test_all_cost = test_all_cost / test_time
        logger.info('Test Finish Avg CEloss: {0}'.format(test_all_cost))

        # Keep a checkpoint every time validation loss improves.
        if test_all_cost <= best_val_cost:
            best_val_cost = test_all_cost
            torch.save(model.state_dict(), os.path.join(opt.model_save_path, 'model_{}_{}.pth'.format(epoch, test_all_cost)))
            print('----------->BEST SAVED<-----------')
|