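"""Training script for the ORMBG segmentation model.

Builds cached dataloaders for the training and validation datasets, runs the
training loop with periodic validation (max F-beta and mean MAE per dataset),
and saves a checkpoint whenever the validation F1 improves on any dataset.
"""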
import gc
import os
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from skimage import io
from torch.autograd import Variable

from models.ormbg import ORMBG
from basics import f1_mae_torch
from data_loader_cache import (
    get_im_gt_name_dict,
    create_dataloaders,
    GOSGridDropout,
    GOSRandomHFlip,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
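
# valid() runs the model over every validation dataloader and returns, per dataset,
# the maximum F-beta score and the mean MAE, together with the accumulated
# validation/target losses, the index of the last validation batch, and the
# per-batch inference times.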
def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
    net.eval()
    print("Validating...")
    epoch_num = hypar["max_epoch_num"]

    val_loss = 0.0
    tar_loss = 0.0
    val_cnt = 0.0

    tmp_f1 = []
    tmp_mae = []
    tmp_time = []

    start_valid = time.time()

    for k in range(len(valid_dataloaders)):
        valid_dataloader = valid_dataloaders[k]
        valid_dataset = valid_datasets[k]

        val_num = valid_dataset.__len__()
        mybins = np.arange(0, 256)
        PRE = np.zeros((val_num, len(mybins) - 1))
        REC = np.zeros((val_num, len(mybins) - 1))
        F1 = np.zeros((val_num, len(mybins) - 1))
        MAE = np.zeros(val_num)

        for i_val, data_val in enumerate(valid_dataloader):
            val_cnt = val_cnt + 1.0
            imidx_val, inputs_val, labels_val, shapes_val = (
                data_val["imidx"],
                data_val["image"],
                data_val["label"],
                data_val["shape"],
            )

            if hypar["model_digit"] == "full":
                inputs_val = inputs_val.type(torch.FloatTensor)
                labels_val = labels_val.type(torch.FloatTensor)
            else:
                inputs_val = inputs_val.type(torch.HalfTensor)
                labels_val = labels_val.type(torch.HalfTensor)

            # wrap them in Variable
            if torch.cuda.is_available():
                inputs_val_v, labels_val_v = Variable(
                    inputs_val.cuda(), requires_grad=False
                ), Variable(labels_val.cuda(), requires_grad=False)
            else:
                inputs_val_v, labels_val_v = Variable(
                    inputs_val, requires_grad=False
                ), Variable(labels_val, requires_grad=False)

            t_start = time.time()
            ds_val = net(inputs_val_v)[0]
            t_end = time.time() - t_start
            tmp_time.append(t_end)

            # loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v)
            loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)

            # compute the F measure for every image in the batch
            for t in range(hypar["batch_size_valid"]):
                i_test = imidx_val[t].data.numpy()

                pred_val = ds_val[0][t, :, :, :]  # B x 1 x H x W

                ## recover the prediction spatial size to the original image size
                pred_val = torch.squeeze(
                    F.interpolate(
                        torch.unsqueeze(pred_val, 0),
                        (shapes_val[t][0], shapes_val[t][1]),
                        mode="bilinear",
                    )
                )

                # pred_val = normPRED(pred_val)
                ma = torch.max(pred_val)
                mi = torch.min(pred_val)
                pred_val = (pred_val - mi) / (ma - mi)  # max = 1

                if len(valid_dataset.dataset["ori_gt_path"]) != 0:
                    gt = np.squeeze(
                        io.imread(valid_dataset.dataset["ori_gt_path"][i_test])
                    )  # max = 255
                    if gt.max() == 1:
                        gt = gt * 255
                else:
                    gt = np.zeros((shapes_val[t][0], shapes_val[t][1]))

                with torch.no_grad():
                    gt = torch.tensor(gt).to(device)

                pre, rec, f1, mae = f1_mae_torch(
                    pred_val * 255, gt, valid_dataset, i_test, mybins, hypar
                )

                PRE[i_test, :] = pre
                REC[i_test, :] = rec
                F1[i_test, :] = f1
                MAE[i_test] = mae

                del ds_val, gt
                gc.collect()
                torch.cuda.empty_cache()

            # if(loss_val.data[0]>1):
            val_loss += loss_val.item()  # data[0]
            tar_loss += loss2_val.item()  # data[0]

            print(
                "[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"
                % (
                    i_val,
                    val_num,
                    val_loss / (i_val + 1),
                    tar_loss / (i_val + 1),
                    np.amax(F1[i_test, :]),
                    MAE[i_test],
                    t_end,
                )
            )

            del loss2_val, loss_val

        print("============================")
        PRE_m = np.mean(PRE, 0)
        REC_m = np.mean(REC, 0)
        # F-beta with beta^2 = 0.3, the weighting commonly used for salient object segmentation
        f1_m = (1 + 0.3) * PRE_m * REC_m / (0.3 * PRE_m + REC_m + 1e-8)

        tmp_f1.append(np.amax(f1_m))
        tmp_mae.append(np.mean(MAE))

    return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
def train(
    net,
    optimizer,
    train_dataloaders,
    train_datasets,
    valid_dataloaders,
    valid_datasets,
    hypar,
):
    model_path = hypar["model_path"]
    model_save_fre = hypar["model_save_fre"]
    max_ite = hypar["max_ite"]
    batch_size_train = hypar["batch_size_train"]
    batch_size_valid = hypar["batch_size_valid"]

    if not os.path.exists(model_path):
        os.mkdir(model_path)

    ite_num = hypar["start_ite"]  # count the total iteration number
    ite_num4val = 0
    running_loss = 0.0  # count the total loss
    running_tar_loss = 0.0  # count the target output loss
    last_f1 = [0 for x in range(len(valid_dataloaders))]

    train_num = train_datasets[0].__len__()

    net.train()

    start_last = time.time()
    gos_dataloader = train_dataloaders[0]
    epoch_num = hypar["max_epoch_num"]
    notgood_cnt = 0

    for epoch in range(epoch_num):

        for i, data in enumerate(gos_dataloader):

            if ite_num >= max_ite:
                print("Training reached the maximal iteration number ", max_ite)
                exit()

            # start_read = time.time()
            ite_num = ite_num + 1
            ite_num4val = ite_num4val + 1

            # get the inputs
            inputs, labels = data["image"], data["label"]

            if hypar["model_digit"] == "full":
                inputs = inputs.type(torch.FloatTensor)
                labels = labels.type(torch.FloatTensor)
            else:
                inputs = inputs.type(torch.HalfTensor)
                labels = labels.type(torch.HalfTensor)

            # wrap them in Variable
            if torch.cuda.is_available():
                inputs_v, labels_v = Variable(
                    inputs.cuda(), requires_grad=False
                ), Variable(labels.cuda(), requires_grad=False)
            else:
                inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(
                    labels, requires_grad=False
                )

            # zero the parameter gradients
            start_inf_loss_back = time.time()
            optimizer.zero_grad()

            ds, _ = net(inputs_v)
            loss2, loss = net.compute_loss(ds, labels_v)

            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            running_tar_loss += loss2.item()

            # del outputs, loss
            del ds, loss2, loss
            end_inf_loss_back = time.time() - start_inf_loss_back

            print(
                ">>>"
                + model_path.split("/")[-1]
                + " - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f"
                % (
                    epoch + 1,
                    epoch_num,
                    (i + 1) * batch_size_train,
                    train_num,
                    ite_num,
                    running_loss / ite_num4val,
                    running_tar_loss / ite_num4val,
                    time.time() - start_last,
                    time.time() - start_last - end_inf_loss_back,
                )
            )
            start_last = time.time()

            if ite_num % model_save_fre == 0:  # validate every model_save_fre iterations
                notgood_cnt += 1
                net.eval()
                tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid(
                    net, valid_dataloaders, valid_datasets, hypar, epoch
                )
                net.train()  # resume training

                tmp_out = 0
                print("last_f1:", last_f1)
                print("tmp_f1:", tmp_f1)
                for fi in range(len(last_f1)):
                    if tmp_f1[fi] > last_f1[fi]:
                        tmp_out = 1
                print("tmp_out:", tmp_out)

                if tmp_out:
                    notgood_cnt = 0
                    last_f1 = tmp_f1
                    tmp_f1_str = [str(round(f1x, 4)) for f1x in tmp_f1]
                    tmp_mae_str = [str(round(mx, 4)) for mx in tmp_mae]
                    maxf1 = "_".join(tmp_f1_str)
                    meanM = "_".join(tmp_mae_str)
                    # .cpu().detach().numpy()
                    model_name = (
                        "/gpu_itr_"
                        + str(ite_num)
                        + "_traLoss_"
                        + str(np.round(running_loss / ite_num4val, 4))
                        + "_traTarLoss_"
                        + str(np.round(running_tar_loss / ite_num4val, 4))
                        + "_valLoss_"
                        + str(np.round(val_loss / (i_val + 1), 4))
                        + "_valTarLoss_"
                        + str(np.round(tar_loss / (i_val + 1), 4))
                        + "_maxF1_"
                        + maxf1
                        + "_mae_"
                        + meanM
                        + "_time_"
                        + str(
                            np.round(np.mean(np.array(tmp_time)) / batch_size_valid, 6)
                        )
                        + ".pth"
                    )
                    torch.save(net.state_dict(), model_path + model_name)

                running_loss = 0.0
                running_tar_loss = 0.0
                ite_num4val = 0

                if notgood_cnt >= hypar["early_stop"]:
                    print(
                        "No improvements in the last "
                        + str(notgood_cnt)
                        + " validation periods, so training stopped!"
                    )
                    exit()

    print("Training reached the maximum epoch number")
def main(train_datasets, valid_datasets, hypar):

    print("--- create training dataloader ---")
    train_nm_im_gt_list = get_im_gt_name_dict(train_datasets, flag="train")

    ## build dataloader for the training datasets
    train_dataloaders, train_datasets = create_dataloaders(
        train_nm_im_gt_list,
        cache_size=hypar["cache_size"],
        cache_boost=hypar["cache_boost_train"],
        my_transforms=[GOSGridDropout(), GOSRandomHFlip()],
        batch_size=hypar["batch_size_train"],
        shuffle=True,
    )

    ## build dataloader for the validation datasets
    valid_nm_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid")
    valid_dataloaders, valid_datasets = create_dataloaders(
        valid_nm_im_gt_list,
        cache_size=hypar["cache_size"],
        cache_boost=hypar["cache_boost_valid"],
        my_transforms=[],
        batch_size=hypar["batch_size_valid"],
        shuffle=False,
    )

    net = hypar["model"]

    ## convert the model to half precision, but keep BatchNorm layers in float32
    if hypar["model_digit"] == "half":
        net.half()
        for layer in net.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()

    if torch.cuda.is_available():
        net.cuda()

    if hypar["restore_model"] != "":
        print("restore model from:")
        print(hypar["model_path"] + "/" + hypar["restore_model"])
        if torch.cuda.is_available():
            net.load_state_dict(
                torch.load(hypar["model_path"] + "/" + hypar["restore_model"])
            )
        else:
            net.load_state_dict(
                torch.load(
                    hypar["model_path"] + "/" + hypar["restore_model"],
                    map_location="cpu",
                )
            )

    optimizer = optim.Adam(
        net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0
    )

    train(
        net,
        optimizer,
        train_dataloaders,
        train_datasets,
        valid_dataloaders,
        valid_datasets,
        hypar,
    )
if __name__ == "__main__":

    output_model_folder = "saved_models"
    Path(output_model_folder).mkdir(parents=True, exist_ok=True)

    ### --------------- STEP 1: Configuring the training and validation datasets ---------------

    dataset_training = {
        "name": "ormbg-training",
        "im_dir": str(Path("dataset", "training", "im")),
        "gt_dir": str(Path("dataset", "training", "gt")),
        "im_ext": ".png",
        "gt_ext": ".png",
        "cache_dir": str(Path("cache", "teacher", "training")),
    }

    dataset_validation = {
        "name": "ormbg-validation",
        "im_dir": str(Path("dataset", "validation", "im")),
        "gt_dir": str(Path("dataset", "validation", "gt")),
        "im_ext": ".png",
        "gt_ext": ".png",
        "cache_dir": str(Path("cache", "teacher", "validation")),
    }

    train_datasets = [dataset_training]
    valid_datasets = [dataset_validation]

    ### --------------- STEP 2: Configuring the hyperparameters for training, validation and inference ---------------

    hypar = {}

    hypar["model"] = ORMBG()
    hypar["seed"] = 0

    ## model weights path
    hypar["model_path"] = "saved_models"

    ## name of the segmentation model weights (.pth) used to resume training from the last stop or for inference
    hypar["restore_model"] = ""

    ## start iteration for the training; can be changed to match a restored training process
    hypar["start_ite"] = 0

    ## indicates "half" or "full" accuracy of float numbers
    hypar["model_digit"] = "full"

    ## To handle large input images, which take a long time to load during training,
    ## we introduce a cache mechanism that pre-converts and resizes the jpg and png images into .pt files.
    ## Cached input spatial resolution; can be configured to a different size.
    hypar["cache_size"] = [1024, 1024]

    ## "True" or "False", indicates whether to load all the training datasets into RAM;
    ## True greatly speeds up training but requires more RAM
    hypar["cache_boost_train"] = False

    ## "True" or "False", indicates whether to load all the validation datasets into RAM;
    ## True greatly speeds up training but requires more RAM
    hypar["cache_boost_valid"] = False

    ## stop training when there is no improvement in the past 20 validation periods; smaller numbers (e.g., 5 or 10) can be used here
    hypar["early_stop"] = 20

    ## validate and save model weights every 2000 iterations
    hypar["model_save_fre"] = 2000

    ## batch size for training
    hypar["batch_size_train"] = 8

    ## batch size for validation and inference
    hypar["batch_size_valid"] = 1

    ## if early stop cannot stop the training process, stop it by max_ite
    hypar["max_ite"] = 10000000

    ## if early stop and max_ite cannot stop the training process, stop it by max_epoch_num
    hypar["max_epoch_num"] = 1000000

    main(train_datasets, valid_datasets, hypar=hypar)
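
# Expected on-disk layout, based on the dataset paths configured above:
#   dataset/training/im/*.png     dataset/training/gt/*.png
#   dataset/validation/im/*.png   dataset/validation/gt/*.png
# The script is then launched directly, e.g. `python <this_script>.py`
# (the exact filename depends on how the repository names this file).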