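"""Training loop for the EasyOCR recognition model.

Builds a batch-balanced training set, configures a CTC- or attention-based
recognizer, then alternates training iterations with periodic validation,
logging, and checkpointing under ./saved_models/<experiment_name>/.
"""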
import os
import sys
import time
import random
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torch.utils.data
from torch.cuda.amp import autocast, GradScaler
import numpy as np
from utils import CTCLabelConverter, AttnLabelConverter, Averager
from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset
from model import Model
from test import validation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def count_parameters(model):
    print("Modules, Parameters")
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        param = parameter.numel()
        total_params += param
        print(name, param)
    print(f"Total Trainable Params: {total_params}")
    return total_params
def train(opt, show_number=2, amp=False):
""" dataset preparation """
if not opt.data_filtering_off:
print('Filtering the images containing characters which are not in opt.character')
print('Filtering the images whose label is longer than opt.batch_max_length')
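    # select_data and batch_ratio are parallel dash-separated lists: each named
    # sub-dataset fills the matching fraction of every training batch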
opt.select_data = opt.select_data.split('-')
opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # make sure the experiment directory exists before any log or checkpoint is written
    os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)
    log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a', encoding="utf8")
AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, contrast_adjust=opt.contrast_adjust)
valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
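    # note: prefetch_factor is only honoured when num_workers > 0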
valid_loader = torch.utils.data.DataLoader(
valid_dataset, batch_size=min(32, opt.batch_size),
shuffle=True, # 'True' to check training progress with validation function.
num_workers=int(opt.workers), prefetch_factor=512,
collate_fn=AlignCollate_valid, pin_memory=True)
log.write(valid_dataset_log)
print('-' * 80)
log.write('-' * 80 + '\n')
log.close()
""" model configuration """
if 'CTC' in opt.Prediction:
converter = CTCLabelConverter(opt.character)
else:
converter = AttnLabelConverter(opt.character)
opt.num_class = len(converter.character)
if opt.rgb:
opt.input_channel = 3
model = Model(opt)
print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
opt.SequenceModeling, opt.Prediction)
if opt.saved_model != '':
        pretrained_dict = torch.load(opt.saved_model, map_location=device)
if opt.new_prediction:
model.Prediction = nn.Linear(model.SequenceModeling_output, len(pretrained_dict['module.Prediction.weight']))
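        # wrap in DataParallel before loading: checkpoint keys carry the 'module.' prefix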
model = torch.nn.DataParallel(model).to(device)
print(f'loading pretrained model from {opt.saved_model}')
if opt.FT:
model.load_state_dict(pretrained_dict, strict=False)
else:
model.load_state_dict(pretrained_dict)
if opt.new_prediction:
model.module.Prediction = nn.Linear(model.module.SequenceModeling_output, opt.num_class)
for name, param in model.module.Prediction.named_parameters():
if 'bias' in name:
init.constant_(param, 0.0)
elif 'weight' in name:
init.kaiming_normal_(param)
model = model.to(device)
else:
# weight initialization
for name, param in model.named_parameters():
if 'localization_fc2' in name:
print(f'Skip {name} as it is already initialized')
continue
try:
if 'bias' in name:
init.constant_(param, 0.0)
elif 'weight' in name:
init.kaiming_normal_(param)
            except Exception:  # 1-D BatchNorm weights reject kaiming init
                if 'weight' in name:
                    param.data.fill_(1)
                continue
model = torch.nn.DataParallel(model).to(device)
model.train()
print("Model:")
print(model)
count_parameters(model)
""" setup loss """
if 'CTC' in opt.Prediction:
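        # zero_infinity zeroes the loss for samples whose label is longer than
        # the prediction sequence, instead of propagating inf/NaN gradients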
criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
else:
criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0
# loss averager
loss_avg = Averager()
    # freeze some layers (the misspelled `freeze_FeatureFxtraction` must match
    # the option name used in the training config)
    try:
        if opt.freeze_FeatureFxtraction:
            for param in model.module.FeatureExtraction.parameters():
                param.requires_grad = False
        if opt.freeze_SequenceModeling:
            for param in model.module.SequenceModeling.parameters():
                param.requires_grad = False
    except AttributeError:
        pass  # freeze options not present in opt: leave all layers trainable
    # gather only the parameters that still require gradient descent (frozen layers drop out here)
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # setup optimizer
    if opt.optim == 'adam':
        # note: Adam runs at its default lr (1e-3); opt.lr is only applied to Adadelta
        optimizer = optim.Adam(filtered_parameters)
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
print("Optimizer:")
print(optimizer)
""" final options """
# print(opt)
with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a', encoding="utf8") as opt_file:
opt_log = '------------ Options -------------\n'
args = vars(opt)
for k, v in args.items():
opt_log += f'{str(k)}: {str(v)}\n'
opt_log += '---------------------------------------\n'
print(opt_log)
opt_file.write(opt_log)
""" start training """
start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass  # checkpoint name carries no iteration suffix; start from 0
start_time = time.time()
best_accuracy = -1
best_norm_ED = -1
i = start_iter
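    # loss scaler for mixed-precision training; only exercised in the amp branch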
scaler = GradScaler()
    t1 = time.time()
    while True:
# train part
optimizer.zero_grad(set_to_none=True)
        if amp:
            # run the forward pass and loss under autocast; backward and the
            # optimizer step happen outside it, on the scaled loss
            with autocast():
                image_tensors, labels = train_dataset.get_batch()
                image = image_tensors.to(device)
                text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
                batch_size = image.size(0)
                if 'CTC' in opt.Prediction:
                    preds = model(image, text).log_softmax(2)
                    preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                    preds = preds.permute(1, 0, 2)
                    # disable cuDNN around CTCLoss: its CTC kernel misbehaves for
                    # some input/label length combinations
                    torch.backends.cudnn.enabled = False
                    cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
                    torch.backends.cudnn.enabled = True
                else:
                    preds = model(image, text[:, :-1])  # align with Attention.forward
                    target = text[:, 1:]  # without [GO] symbol
                    cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
            scaler.scale(cost).backward()
            # unscale before clipping so the norm threshold applies to true gradients
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            scaler.step(optimizer)
            scaler.update()
else:
image_tensors, labels = train_dataset.get_batch()
image = image_tensors.to(device)
text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
batch_size = image.size(0)
if 'CTC' in opt.Prediction:
preds = model(image, text).log_softmax(2)
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
preds = preds.permute(1, 0, 2)
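                # as in the amp branch: disable cuDNN around CTCLoss to dodge its
                # flaky CTC kernel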
torch.backends.cudnn.enabled = False
cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
torch.backends.cudnn.enabled = True
else:
preds = model(image, text[:, :-1]) # align with Attention.forward
target = text[:, 1:] # without [GO] Symbol
cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
cost.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
optimizer.step()
loss_avg.add(cost)
# validation part
        if (i % opt.valInterval == 0) and (i != 0):
            print('training time: ', time.time() - t1)
            t1 = time.time()
elapsed_time = time.time() - start_time
# for log
with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a', encoding="utf8") as log:
model.eval()
with torch.no_grad():
valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels,\
infer_time, length_of_data = validation(model, criterion, valid_loader, converter, opt, device)
model.train()
# training loss and validation loss
loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
loss_avg.reset()
current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.4f}'
# keep best accuracy model (on valid dataset)
if current_accuracy > best_accuracy:
best_accuracy = current_accuracy
torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
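                    # norm_ED is tracked as higher-is-better, mirroring the accuracy check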
if current_norm_ED > best_norm_ED:
best_norm_ED = current_norm_ED
torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.4f}'
loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
print(loss_model_log)
log.write(loss_model_log + '\n')
# show some predicted results
dashed_line = '-' * 80
head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                # cap show_number so the random sample window always fits the batch
                show_number = min(show_number, len(labels))
                start = random.randint(0, len(labels) - show_number)
for gt, pred, confidence in zip(labels[start:start+show_number], preds[start:start+show_number], confidence_score[start:start+show_number]):
if 'Attn' in opt.Prediction:
gt = gt[:gt.find('[s]')]
pred = pred[:pred.find('[s]')]
predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
predicted_result_log += f'{dashed_line}'
print(predicted_result_log)
log.write(predicted_result_log + '\n')
                print('validation time: ', time.time() - t1)
                t1 = time.time()
        # save a checkpoint every 10,000 iterations
        if (i + 1) % 10000 == 0:
            torch.save(
                model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')
if i == opt.num_iter:
print('end the training')
sys.exit()
i += 1
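
if __name__ == '__main__':
    # Hypothetical minimal driver, shown for illustration only: in the real
    # project `opt` normally comes from a config file, and the dataset/model
    # modules read additional fields beyond the ones train() touches directly.
    # Every value below is a placeholder.
    from types import SimpleNamespace

    opt = SimpleNamespace(
        experiment_name='en_sample',
        train_data='all_data', valid_data='all_data/en_val',
        select_data='en_train', batch_ratio='1', total_data_usage_ratio=1.0,
        batch_size=32, workers=4, num_iter=300000, valInterval=2000,
        imgH=64, imgW=600, PAD=True, contrast_adjust=0.0,
        data_filtering_off=False, batch_max_length=34,
        character='0123456789abcdefghijklmnopqrstuvwxyz',
        rgb=False, input_channel=1, output_channel=256, hidden_size=256,
        num_fiducial=20, Transformation='None', FeatureExtraction='VGG',
        SequenceModeling='BiLSTM', Prediction='CTC',
        saved_model='', FT=False, new_prediction=False,
        freeze_FeatureFxtraction=False, freeze_SequenceModeling=False,
        optim='adadelta', lr=1.0, rho=0.95, eps=1e-8, grad_clip=5.0,
    )
    train(opt, amp=False)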