Shr3ezy's picture
Uploading all files
29858c0 verified
import os
import copy
import random
import argparse
import time
import numpy as np
from PIL import Image
import scipy.io as scio
import scipy.misc
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.autograd import Variable
import torch.optim as optim
import torch.nn as nn
from torch.backends import cudnn
from data_controller import SegDataset
from loss import Loss
from segnet import SegNet as segnet
import sys
sys.path.append("..")
from lib.utils import setup_logger
# Command-line options for training SegNet on the YCB-Video dataset.
# Numeric options declare an explicit ``type`` so values supplied on the
# command line arrive as numbers instead of strings (argparse applies
# ``type`` only to string values, so the numeric defaults are unaffected).
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_root', default='/home/data1/jeremy/YCB_Video_Dataset', help="dataset root dir (''YCB_Video Dataset'')")
parser.add_argument('--batch_size', type=int, default=3, help="batch size")
parser.add_argument('--n_epochs', type=int, default=600, help="epochs to train")
parser.add_argument('--workers', type=int, default=10, help='number of data loading workers')
parser.add_argument('--lr', type=float, default=0.0001, help="learning rate")
# NOTE(review): --logs_path is not read anywhere in this script; --log_dir is
# the option actually used for log output. Kept for command-line compatibility.
parser.add_argument('--logs_path', default='logs/', help="path to save logs")
parser.add_argument('--model_save_path', default='trained_models/', help="path to save models")
parser.add_argument('--log_dir', default='logs/', help="path to save logs")
parser.add_argument('--resume_model', default='', help="resume model name")
# Parse the real command line only when executed as a script. On import
# (tools, tests) fall back to the defaults rather than parsing -- and possibly
# rejecting -- the importing process's sys.argv.
opt = parser.parse_args() if __name__ == '__main__' else parser.parse_args([])
if __name__ == '__main__':
opt.manualSeed = random.randint(1, 10000)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/train_data_list.txt', True, 5000)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.workers))
test_dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/test_data_list.txt', False, 1000)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=int(opt.workers))
print(len(dataset), len(test_dataset))
model = segnet()
model = model.cuda()
if opt.resume_model != '':
checkpoint = torch.load('{0}/{1}'.format(opt.model_save_path, opt.resume_model))
model.load_state_dict(checkpoint)
for log in os.listdir(opt.log_dir):
os.remove(os.path.join(opt.log_dir, log))
optimizer = optim.Adam(model.parameters(), lr=opt.lr)
criterion = Loss()
best_val_cost = np.Inf
st_time = time.time()
for epoch in range(1, opt.n_epochs):
model.train()
train_all_cost = 0.0
train_time = 0
logger = setup_logger('epoch%d' % epoch, os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
logger.info('Train time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Training started'))
for i, data in enumerate(dataloader, 0):
rgb, target = data
rgb, target = Variable(rgb).cuda(), Variable(target).cuda()
semantic = model(rgb)
optimizer.zero_grad()
semantic_loss = criterion(semantic, target)
train_all_cost += semantic_loss.item()
semantic_loss.backward()
optimizer.step()
logger.info('Train time {0} Batch {1} CEloss {2}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), train_time, semantic_loss.item()))
if train_time != 0 and train_time % 1000 == 0:
torch.save(model.state_dict(), os.path.join(opt.model_save_path, 'model_current.pth'))
train_time += 1
train_all_cost = train_all_cost / train_time
logger.info('Train Finish Avg CEloss: {0}'.format(train_all_cost))
model.eval()
test_all_cost = 0.0
test_time = 0
logger = setup_logger('epoch%d_test' % epoch, os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
logger.info('Test time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Testing started'))
for j, data in enumerate(test_dataloader, 0):
rgb, target = data
rgb, target = Variable(rgb).cuda(), Variable(target).cuda()
semantic = model(rgb)
semantic_loss = criterion(semantic, target)
test_all_cost += semantic_loss.item()
test_time += 1
logger.info('Test time {0} Batch {1} CEloss {2}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), test_time, semantic_loss.item()))
test_all_cost = test_all_cost / test_time
logger.info('Test Finish Avg CEloss: {0}'.format(test_all_cost))
if test_all_cost <= best_val_cost:
best_val_cost = test_all_cost
torch.save(model.state_dict(), os.path.join(opt.model_save_path, 'model_{}_{}.pth'.format(epoch, test_all_cost)))
print('----------->BEST SAVED<-----------')