diff --git a/.gitattributes b/.gitattributes
deleted file mode 100644
index 818d649bf21cdef29b21f885c8f770f9baa1714e..0000000000000000000000000000000000000000
--- a/.gitattributes
+++ /dev/null
@@ -1,31 +0,0 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index 203d4cf2c27671ecfa10eea1fe857288b9eb7130..2a15f89cdca1db0833c2a31b9bcf815544d0f402 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,136 @@
----
-title: Pixelplanetocr
-emoji: 🦀
-colorFrom: gray
-colorTo: purple
-sdk: gradio
-sdk_version: 3.3.1
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# IterVM: Iterative Vision Modeling Module for Scene Text Recognition
+
+The official code of [IterNet](https://arxiv.org/abs/2204.02630).
+
+We propose IterVM, an iterative approach to visual feature extraction that can significantly improve scene text recognition accuracy.
+IterVM repeatedly uses the high-level visual feature extracted at the previous iteration to enhance the multi-level features extracted at the subsequent iteration.
+
+![Framework of IterVM](figures/framework.png)
+
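+The following is a minimal, illustrative sketch of this iterative idea. It is not the actual module (which lives in `modules/model_vision.py` and is not part of this diff); the `backbone(images, prev=...)` interface is an assumption for illustration only, while `iter_size=3` matches the configs in this repo.
+```
+import torch.nn as nn
+
+class IterVMSketch(nn.Module):
+    """Illustrative only: iterate the visual feature extractor, feeding the
+    previous high-level feature back in to enhance the multi-level features."""
+    def __init__(self, backbone, fuse, head, iter_size=3):
+        super().__init__()
+        self.backbone, self.fuse, self.head = backbone, fuse, head
+        self.iter_size = iter_size
+
+    def forward(self, images):
+        high_level, outputs = None, []
+        for _ in range(self.iter_size):
+            feats = self.backbone(images, prev=high_level)  # multi-level features, enhanced by the previous high-level feature
+            high_level = self.fuse(feats)                    # fuse into a new high-level feature
+            outputs.append(self.head(high_level))            # per-iteration prediction
+        return outputs
+```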
+## Runtime Environment
+```
+pip install -r requirements.txt
+```
+Note: `fastai==1.0.60` is required.
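+
+A quick sanity check (not part of the repo) to confirm the pinned version is active:
+```
+import fastai
+assert fastai.__version__ == '1.0.60', fastai.__version__
+```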
+
+## Datasets
+### Training datasets
+ 1. [MJSynth](http://www.robots.ox.ac.uk/~vgg/data/text/) (MJ):
+    - Use `tools/create_lmdb_dataset.py` to convert images into an LMDB dataset (a quick verification snippet follows this list)
+ - [LMDB dataset BaiduNetdisk(passwd:n23k)](https://pan.baidu.com/s/1mgnTiyoR8f6Cm655rFI4HQ)
+ 2. [SynthText](http://www.robots.ox.ac.uk/~vgg/data/scenetext/) (ST):
+ - Use `tools/crop_by_word_bb.py` to crop images from original [SynthText](http://www.robots.ox.ac.uk/~vgg/data/scenetext/) dataset, and convert images into LMDB dataset by `tools/create_lmdb_dataset.py`
+ - [LMDB dataset BaiduNetdisk(passwd:n23k)](https://pan.baidu.com/s/1mgnTiyoR8f6Cm655rFI4HQ)
+ 3. [WikiText103](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip), which is only used for pre-training language models:
+ - Use `notebooks/prepare_wikitext103.ipynb` to convert text into CSV format.
+ - [CSV dataset BaiduNetdisk(passwd:dk01)](https://pan.baidu.com/s/1yabtnPYDKqhBb_Ie9PGFXA)
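+
+After converting MJ/ST with `tools/create_lmdb_dataset.py`, the resulting LMDB can be sanity-checked with a short snippet. The key names follow the `num-samples` / `image-%09d` / `label-%09d` convention used by `dataset.py`; the path below is only an example.
+```
+import lmdb, six, PIL.Image
+
+env = lmdb.open('data/training/MJ/MJ_train', readonly=True, lock=False)
+with env.begin(write=False) as txn:
+    num_samples = int(txn.get('num-samples'.encode()))
+    label = txn.get('label-000000001'.encode()).decode()
+    buf = six.BytesIO(txn.get('image-000000001'.encode()))
+    image = PIL.Image.open(buf)
+print(num_samples, label, image.size)
+```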
+### Evaluation datasets
+- The evaluation LMDB datasets can be downloaded from [BaiduNetdisk (passwd:1dbv)](https://pan.baidu.com/s/1RUg3Akwp7n8kZYJ55rU5LQ) or [GoogleDrive](https://drive.google.com/file/d/1dTI0ipu14Q1uuK4s4z32DqbqF3dJPdkk/view?usp=sharing).
+ 1. ICDAR 2013 (IC13)
+ 2. ICDAR 2015 (IC15)
+ 3. IIIT5K Words (IIIT)
+ 4. Street View Text (SVT)
+ 5. Street View Text-Perspective (SVTP)
+ 6. CUTE80 (CUTE)
+### Structure of the `data` directory
+- The `data` directory is organized as follows:
+ ```
+ data
+ ├── charset_36.txt
+ ├── evaluation
+ │  ├── CUTE80
+ │  ├── IC13_857
+ │  ├── IC15_1811
+ │  ├── IIIT5k_3000
+ │  ├── SVT
+ │  └── SVTP
+ ├── training
+ │  ├── MJ
+ │  │  ├── MJ_test
+ │  │  ├── MJ_train
+ │  │  └── MJ_valid
+ │  └── ST
+ ├── WikiText-103.csv
+ └── WikiText-103_eval_d1.csv
+ ```
+## Run Demo
+A runnable demo is also available on [Google Colab](https://colab.research.google.com/drive/1XmZGJzFF95uafmARtJMudPLLKBO2eXLv?usp=sharing). To run it locally:
+
+```
+python demo.py --config=configs/train_iternet.yaml --input=figures/demo
+```
+Additional flags:
+- `--config /path/to/config` set the path of configuration file
+- `--input /path/to/image-directory` set the path of an image directory or a wildcard path, e.g., `--input='figs/test/*.png'`
+- `--checkpoint /path/to/checkpoint` set the path of the trained model
+- `--cuda [-1|0|1|2|3...]` set the CUDA device id; the default `-1` runs on CPU
+- `--model_eval [alignment|vision]` select which sub-model to evaluate
+- `--image_only` disable dumping visualizations of attention masks
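+
+Combining the flags above, a typical local run might look like (paths are illustrative):
+```
+python demo.py --config=configs/train_iternet.yaml \
+               --input='figs/test/*.png' \
+               --checkpoint=workdir/train-iternet/best-train-iternet.pth \
+               --cuda=0 --model_eval=alignment
+```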
+
+
+## Citation
+If you find our method useful for your research, please cite:
+```bibtex
+@article{chu2022itervm,
+ title={IterVM: Iterative Vision Modeling Module for Scene Text Recognition},
+ author={Chu, Xiaojie and Wang, Yongtao},
+ journal={arXiv preprint arXiv:2204.02630},
+ year={2022}
+}
+```
+
+## License
+This project is free for academic research purposes only; commercial use requires authorization. For commercial licensing, please contact wyt@pku.edu.cn.
+
+## Acknowledgements
+This project is based on [ABINet](https://github.com/FangShancheng/ABINet.git).
+Thanks for their great work.
+
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..2876b43f068a83aea4f308140a773d70f72129c3
--- /dev/null
+++ b/app.py
@@ -0,0 +1,30 @@
+import glob
+import gradio as gr
+from demo import get_model, preprocess, postprocess, load
+from utils import Config, Logger, CharsetMapper
+
+config = Config('configs/train_iternet.yaml')
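+# Presumably cleared so that only the full IterNet checkpoint loaded below is used (mirrors demo.py).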
+config.model_vision_checkpoint = None
+model = get_model(config)
+model = load(model, 'workdir/train-iternet/best-train-iternet.pth')
+charset = CharsetMapper(filename=config.dataset_charset_path, max_length=config.dataset_max_length + 1)
+
+def process_image(image):
+ img = image.convert('RGB')
+ img = preprocess(img, config.dataset_image_width, config.dataset_image_height)
+ res = model(img)
+ return postprocess(res, charset, 'alignment')[0][0]
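+
+# Illustrative usage outside Gradio (096314.png ships with this repo under figures/demo):
+#   from PIL import Image
+#   print(process_image(Image.open('figures/demo/096314.png')))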
+
+title = "Interactive demo: ABINet"
+description = "Demo for ABINet. ABINet uses a vision model and an explicit language model, trained jointly in an end-to-end way, to recognize text in the wild. The language model (BCN) achieves bidirectional language representation by simulating a cloze test and additionally applies an iterative correction strategy. To use it, simply upload a (single-text-line) image or use one of the example images below and click 'submit'. Results will show up in a few seconds."
+article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2103.06495'>Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition</a> | <a href='https://github.com/FangShancheng/ABINet'>Github Repo</a></p>
" + +iface = gr.Interface(fn=process_image, + inputs=gr.inputs.Image(type="pil"), + outputs=gr.outputs.Textbox(), + title=title, + description=description, + article=article, + examples=glob.glob('figs/test/*.png')) + +iface.launch(debug=True, share=True,enable_queue=True) \ No newline at end of file diff --git a/callbacks.py b/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..82fb9e34da2a819ce849857c304bb3cd23973e81 --- /dev/null +++ b/callbacks.py @@ -0,0 +1,360 @@ +import logging +import shutil +import time + +import editdistance as ed +import torchvision.utils as vutils +from fastai.callbacks.tensorboard import (LearnerTensorboardWriter, + SummaryWriter, TBWriteRequest, + asyncTBWriter) +from fastai.vision import * +from torch.nn.parallel import DistributedDataParallel +from torchvision import transforms + +import dataset +from utils import CharsetMapper, Timer, blend_mask + + +class IterationCallback(LearnerTensorboardWriter): + "A `TrackerCallback` that monitor in each iteration." + def __init__(self, learn:Learner, name:str='model', checpoint_keep_num=5, + show_iters:int=50, eval_iters:int=1000, save_iters:int=20000, + start_iters:int=0, stats_iters=20000): + #if self.learn.rank is not None: time.sleep(self.learn.rank) # keep all event files + super().__init__(learn, base_dir='.', name=learn.path, loss_iters=show_iters, + stats_iters=stats_iters, hist_iters=stats_iters) + self.name, self.bestname = Path(name).name, f'best-{Path(name).name}' + self.show_iters = show_iters + self.eval_iters = eval_iters + self.save_iters = save_iters + self.start_iters = start_iters + self.checpoint_keep_num = checpoint_keep_num + self.metrics_root = 'metrics/' # rewrite + self.timer = Timer() + self.host = self.learn.rank is None or self.learn.rank == 0 + + def _write_metrics(self, iteration:int, names:List[str], last_metrics:MetricsList)->None: + "Writes training metrics to Tensorboard." + for i, name in enumerate(names): + if last_metrics is None or len(last_metrics) < i+1: return + scalar_value = last_metrics[i] + self._write_scalar(name=name, scalar_value=scalar_value, iteration=iteration) + + def _write_sub_loss(self, iteration:int, last_losses:dict)->None: + "Writes sub loss to Tensorboard." + for name, loss in last_losses.items(): + scalar_value = to_np(loss) + tag = self.metrics_root + name + self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration) + + def _save(self, name): + if isinstance(self.learn.model, DistributedDataParallel): + tmp = self.learn.model + self.learn.model = self.learn.model.module + self.learn.save(name) + self.learn.model = tmp + else: self.learn.save(name) + + def _validate(self, dl=None, callbacks=None, metrics=None, keeped_items=False): + "Validate on `dl` with potential `callbacks` and `metrics`." 
+ dl = ifnone(dl, self.learn.data.valid_dl) + metrics = ifnone(metrics, self.learn.metrics) + cb_handler = CallbackHandler(ifnone(callbacks, []), metrics) + cb_handler.on_train_begin(1, None, metrics); cb_handler.on_epoch_begin() + if keeped_items: cb_handler.state_dict.update(dict(keeped_items=[])) + val_metrics = validate(self.learn.model, dl, self.loss_func, cb_handler) + cb_handler.on_epoch_end(val_metrics) + if keeped_items: return cb_handler.state_dict['keeped_items'] + else: return cb_handler.state_dict['last_metrics'] + + def jump_to_epoch_iter(self, epoch:int, iteration:int)->None: + try: + self.learn.load(f'{self.name}_{epoch}_{iteration}', purge=False) + logging.info(f'Loaded {self.name}_{epoch}_{iteration}') + except: logging.info(f'Model {self.name}_{epoch}_{iteration} not found.') + + def on_train_begin(self, n_epochs, **kwargs): + # TODO: can not write graph here + # super().on_train_begin(**kwargs) + self.best = -float('inf') + self.timer.tic() + if self.host: + checkpoint_path = self.learn.path/'checkpoint.yaml' + if checkpoint_path.exists(): + os.remove(checkpoint_path) + open(checkpoint_path, 'w').close() + return {'skip_validate': True, 'iteration':self.start_iters} # disable default validate + + def on_batch_begin(self, **kwargs:Any)->None: + self.timer.toc_data() + super().on_batch_begin(**kwargs) + + def on_batch_end(self, iteration, epoch, last_loss, smooth_loss, train, **kwargs): + super().on_batch_end(last_loss, iteration, train, **kwargs) + if iteration == 0: return + + if iteration % self.loss_iters == 0: + last_losses = self.learn.loss_func.last_losses + self._write_sub_loss(iteration=iteration, last_losses=last_losses) + self.tbwriter.add_scalar(tag=self.metrics_root + 'lr', + scalar_value=self.opt.lr, global_step=iteration) + + if iteration % self.show_iters == 0: + log_str = f'epoch {epoch} iter {iteration}: loss = {last_loss:6.4f}, ' \ + f'smooth loss = {smooth_loss:6.4f}' + logging.info(log_str) + # log_str = f'data time = {self.timer.data_diff:.4f}s, runing time = {self.timer.running_diff:.4f}s' + # logging.info(log_str) + + if iteration % self.eval_iters == 0: + # TODO: or remove time to on_epoch_end + # 1. Record time + log_str = f'average data time = {self.timer.average_data_time():.4f}s, ' \ + f'average running time = {self.timer.average_running_time():.4f}s' + logging.info(log_str) + + # 2. Call validate + last_metrics = self._validate() + self.learn.model.train() + log_str = f'epoch {epoch} iter {iteration}: eval loss = {last_metrics[0]:6.4f}, ' \ + f'ccr = {last_metrics[1]:6.4f}, cwr = {last_metrics[2]:6.4f}, ' \ + f'ted = {last_metrics[3]:6.4f}, ned = {last_metrics[4]:6.4f}, ' \ + f'ted/w = {last_metrics[5]:6.4f}, ' + logging.info(log_str) + names = ['eval_loss', 'ccr', 'cwr', 'ted', 'ned', 'ted/w'] + self._write_metrics(iteration, names, last_metrics) + + # 3. 
Save best model + current = last_metrics[2] + if current is not None and current > self.best: + logging.info(f'Better model found at epoch {epoch}, '\ + f'iter {iteration} with accuracy value: {current:6.4f}.') + self.best = current + self._save(f'{self.bestname}') + + if iteration % self.save_iters == 0 and self.host: + logging.info(f'Save model {self.name}_{epoch}_{iteration}') + filename = f'{self.name}_{epoch}_{iteration}' + self._save(filename) + + checkpoint_path = self.learn.path/'checkpoint.yaml' + if not checkpoint_path.exists(): + open(checkpoint_path, 'w').close() + with open(checkpoint_path, 'r') as file: + checkpoints = yaml.load(file, Loader=yaml.FullLoader) or dict() + checkpoints['all_checkpoints'] = ( + checkpoints.get('all_checkpoints') or list()) + checkpoints['all_checkpoints'].insert(0, filename) + if len(checkpoints['all_checkpoints']) > self.checpoint_keep_num: + removed_checkpoint = checkpoints['all_checkpoints'].pop() + removed_checkpoint = self.learn.path/self.learn.model_dir/f'{removed_checkpoint}.pth' + os.remove(removed_checkpoint) + checkpoints['current_checkpoint'] = filename + with open(checkpoint_path, 'w') as file: + yaml.dump(checkpoints, file) + + + self.timer.toc_running() + + def on_train_end(self, **kwargs): + #self.learn.load(f'{self.bestname}', purge=False) + pass + + def on_epoch_end(self, last_metrics:MetricsList, iteration:int, **kwargs)->None: + self._write_embedding(iteration=iteration) + + +class TextAccuracy(Callback): + _names = ['ccr', 'cwr', 'ted', 'ned', 'ted/w'] + def __init__(self, charset_path, max_length, case_sensitive, model_eval): + self.charset_path = charset_path + self.max_length = max_length + self.case_sensitive = case_sensitive + self.charset = CharsetMapper(charset_path, self.max_length) + self.names = self._names + + self.model_eval = model_eval or 'alignment' + assert self.model_eval in ['vision', 'language', 'alignment'] + + def on_epoch_begin(self, **kwargs): + self.total_num_char = 0. + self.total_num_word = 0. + self.correct_num_char = 0. + self.correct_num_word = 0. + self.total_ed = 0. + self.total_ned = 0. 
+ + def _get_output(self, last_output): + if isinstance(last_output, (tuple, list)): + for res in last_output: + if res['name'] == self.model_eval: output = res + else: output = last_output + return output + + def _update_output(self, last_output, items): + if isinstance(last_output, (tuple, list)): + for res in last_output: + if res['name'] == self.model_eval: res.update(items) + else: last_output.update(items) + return last_output + + def on_batch_end(self, last_output, last_target, **kwargs): + output = self._get_output(last_output) + logits, pt_lengths = output['logits'], output['pt_lengths'] + pt_text, pt_scores, pt_lengths_ = self.decode(logits) + assert (pt_lengths == pt_lengths_).all(), f'{pt_lengths} != {pt_lengths_} for {pt_text}' + last_output = self._update_output(last_output, {'pt_text':pt_text, 'pt_scores':pt_scores}) + + pt_text = [self.charset.trim(t) for t in pt_text] + label = last_target[0] + if label.dim() == 3: label = label.argmax(dim=-1) # one-hot label + gt_text = [self.charset.get_text(l, trim=True) for l in label] + + for i in range(len(gt_text)): + if not self.case_sensitive: + gt_text[i], pt_text[i] = gt_text[i].lower(), pt_text[i].lower() + distance = ed.eval(gt_text[i], pt_text[i]) + self.total_ed += distance + self.total_ned += float(distance) / max(len(gt_text[i]), 1) + + if gt_text[i] == pt_text[i]: + self.correct_num_word += 1 + self.total_num_word += 1 + + for j in range(min(len(gt_text[i]), len(pt_text[i]))): + if gt_text[i][j] == pt_text[i][j]: + self.correct_num_char += 1 + self.total_num_char += len(gt_text[i]) + + return {'last_output': last_output} + + def on_epoch_end(self, last_metrics, **kwargs): + mets = [self.correct_num_char / self.total_num_char, + self.correct_num_word / self.total_num_word, + self.total_ed, + self.total_ned, + self.total_ed / self.total_num_word] + return add_metrics(last_metrics, mets) + + def decode(self, logit): + """ Greed decode """ + # TODO: test running time and decode on GPU + out = F.softmax(logit, dim=2) + pt_text, pt_scores, pt_lengths = [], [], [] + for o in out: + text = self.charset.get_text(o.argmax(dim=1), padding=False, trim=False) + text = text.split(self.charset.null_char)[0] # end at end-token + pt_text.append(text) + pt_scores.append(o.max(dim=1)[0]) + pt_lengths.append(min(len(text) + 1, self.max_length)) # one for end-token + pt_scores = torch.stack(pt_scores) + pt_lengths = pt_scores.new_tensor(pt_lengths, dtype=torch.long) + return pt_text, pt_scores, pt_lengths + + +class TopKTextAccuracy(TextAccuracy): + _names = ['ccr', 'cwr'] + def __init__(self, k, charset_path, max_length, case_sensitive, model_eval): + self.k = k + self.charset_path = charset_path + self.max_length = max_length + self.case_sensitive = case_sensitive + self.charset = CharsetMapper(charset_path, self.max_length) + self.names = self._names + + def on_epoch_begin(self, **kwargs): + self.total_num_char = 0. + self.total_num_word = 0. + self.correct_num_char = 0. + self.correct_num_word = 0. 
+ + def on_batch_end(self, last_output, last_target, **kwargs): + logits, pt_lengths = last_output['logits'], last_output['pt_lengths'] + gt_labels, gt_lengths = last_target[:] + + for logit, pt_length, label, length in zip(logits, pt_lengths, gt_labels, gt_lengths): + word_flag = True + for i in range(length): + char_logit = logit[i].topk(self.k)[1] + char_label = label[i].argmax(-1) + if char_label in char_logit: self.correct_num_char += 1 + else: word_flag = False + self.total_num_char += 1 + if pt_length == length and word_flag: + self.correct_num_word += 1 + self.total_num_word += 1 + + def on_epoch_end(self, last_metrics, **kwargs): + mets = [self.correct_num_char / self.total_num_char, + self.correct_num_word / self.total_num_word, + 0., 0., 0.] + return add_metrics(last_metrics, mets) + + +class DumpPrediction(LearnerCallback): + + def __init__(self, learn, dataset, charset_path, model_eval, image_only=False, debug=False): + super().__init__(learn=learn) + self.debug = debug + self.model_eval = model_eval or 'alignment' + self.image_only = image_only + assert self.model_eval in ['vision', 'language', 'alignment'] + + self.dataset, self.root = dataset, Path(self.learn.path)/f'{dataset}-{self.model_eval}' + self.attn_root = self.root/'attn' + self.charset = CharsetMapper(charset_path) + if self.root.exists(): shutil.rmtree(self.root) + self.root.mkdir(), self.attn_root.mkdir() + + self.pil = transforms.ToPILImage() + self.tensor = transforms.ToTensor() + size = self.learn.data.img_h, self.learn.data.img_w + self.resize = transforms.Resize(size=size, interpolation=0) + self.c = 0 + + def on_batch_end(self, last_input, last_output, last_target, **kwargs): + if isinstance(last_output, (tuple, list)): + for res in last_output: + if res['name'] == self.model_eval: pt_text = res['pt_text'] + if res['name'] == 'vision': attn_scores = res['attn_scores'].detach().cpu() + if res['name'] == self.model_eval: logits = res['logits'] + else: + pt_text = last_output['pt_text'] + attn_scores = last_output['attn_scores'].detach().cpu() + logits = last_output['logits'] + + images = last_input[0] if isinstance(last_input, (tuple, list)) else last_input + images = images.detach().cpu() + pt_text = [self.charset.trim(t) for t in pt_text] + gt_label = last_target[0] + if gt_label.dim() == 3: gt_label = gt_label.argmax(dim=-1) # one-hot label + gt_text = [self.charset.get_text(l, trim=True) for l in gt_label] + + prediction, false_prediction = [], [] + for gt, pt, image, attn, logit in zip(gt_text, pt_text, images, attn_scores, logits): + prediction.append(f'{gt}\t{pt}\n') + if gt != pt: + if self.debug: + scores = torch.softmax(logit, dim=-1)[:max(len(pt), len(gt)) + 1] + logging.info(f'{self.c} gt {gt}, pt {pt}, logit {logit.shape}, scores {scores.topk(5, dim=-1)}') + false_prediction.append(f'{gt}\t{pt}\n') + + image = self.learn.data.denorm(image) + if not self.image_only: + image_np = np.array(self.pil(image)) + attn_pil = [self.pil(a) for a in attn[:, None, :, :]] + attn = [self.tensor(self.resize(a)).repeat(3, 1, 1) for a in attn_pil] + attn_sum = np.array([np.array(a) for a in attn_pil[:len(pt)]]).sum(axis=0) + blended_sum = self.tensor(blend_mask(image_np, attn_sum)) + blended = [self.tensor(blend_mask(image_np, np.array(a))) for a in attn_pil] + save_image = torch.stack([image] + attn + [blended_sum] + blended) + save_image = save_image.view(2, -1, *save_image.shape[1:]) + save_image = save_image.permute(1, 0, 2, 3, 4).flatten(0, 1) + vutils.save_image(save_image, 
self.attn_root/f'{self.c}_{gt}_{pt}.jpg', + nrow=2, normalize=True, scale_each=True) + else: + self.pil(image).save(self.attn_root/f'{self.c}_{gt}_{pt}.jpg') + self.c += 1 + + with open(self.root/f'{self.model_eval}.txt', 'a') as f: f.writelines(prediction) + with open(self.root/f'{self.model_eval}-false.txt', 'a') as f: f.writelines(false_prediction) diff --git a/configs/pretrain_itervm.yaml b/configs/pretrain_itervm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b31f92c57b8ce0defc97e26ecc97e3bde2f16220 --- /dev/null +++ b/configs/pretrain_itervm.yaml @@ -0,0 +1,60 @@ +global: + name: pretrain-itervm + phase: train + stage: pretrain-vision + workdir: workdir + seed: ~ + +dataset: + train: { + roots: ['data/training/MJ/MJ_train/', + 'data/training/MJ/MJ_test/', + 'data/training/MJ/MJ_valid/', + 'data/training/ST'], + batch_size: 384 + } + test: { + roots: ['data/evaluation/IIIT5k_3000', + 'data/evaluation/SVT', + 'data/evaluation/SVTP', + 'data/evaluation/IC13_857', + 'data/evaluation/IC15_1811', + 'data/evaluation/CUTE80'], + batch_size: 384 + } + data_aug: True + multiscales: False + num_workers: 14 + +training: + epochs: 8 + show_iters: 50 + eval_iters: 3000 + save_iters: 3000 + +optimizer: + type: Adam + true_wd: False + wd: 0.0 + bn_wd: False + clip_grad: 20 + lr: 0.0001 + args: { + betas: !!python/tuple [0.9, 0.999], # for default Adam + } + scheduler: { + periods: [6, 2], + gamma: 0.1, + } + +model: + name: 'modules.model_vision.BaseIterVision' + checkpoint: ~ + vision: { + loss_weight: 1., + attention: 'position', + backbone: 'transformer', + backbone_ln: 3, + iter_size: 3, + backbone_alpha_d: 0.5, + } diff --git a/configs/pretrain_language_model.yaml b/configs/pretrain_language_model.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7238f8698835850dda362d0ebcf663f124a6ab7a --- /dev/null +++ b/configs/pretrain_language_model.yaml @@ -0,0 +1,45 @@ +global: + name: pretrain-language-model + phase: train + stage: pretrain-language + workdir: workdir + seed: ~ + +dataset: + train: { + roots: ['data/WikiText-103.csv'], + batch_size: 4096 + } + test: { + roots: ['data/WikiText-103_eval_d1.csv'], + batch_size: 4096 + } + +training: + epochs: 80 + show_iters: 50 + eval_iters: 6000 + save_iters: 3000 + +optimizer: + type: Adam + true_wd: False + wd: 0.0 + bn_wd: False + clip_grad: 20 + lr: 0.0001 + args: { + betas: !!python/tuple [0.9, 0.999], # for default Adam + } + scheduler: { + periods: [70, 10], + gamma: 0.1, + } + +model: + name: 'modules.model_language.BCNLanguage' + language: { + num_layers: 4, + loss_weight: 1., + use_self_attn: False + } diff --git a/configs/pretrain_vm.yaml b/configs/pretrain_vm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a686f80cb4e9bbcc2b0824fecc876ed20c05e9af --- /dev/null +++ b/configs/pretrain_vm.yaml @@ -0,0 +1,51 @@ +global: + name: pretrain-vm + phase: train + stage: pretrain-vision + workdir: workdir + seed: ~ + +dataset: + train: { + roots: ['output_tbell_dataset/'], + batch_size: 20 + } + test: { + roots: ['output_tbell_dataset/'], + batch_size: 20 + } + data_aug: True + multiscales: False + num_workers: 1 + +training: + epochs: 8 + show_iters: 50 + eval_iters: 50 + # save_iters: 3000 + +optimizer: + type: Adam + true_wd: False + wd: 0.0 + bn_wd: False + clip_grad: 20 + lr: 0.0001 + args: { + betas: !!python/tuple [0.9, 0.999], # for default Adam + } + scheduler: { + periods: [6, 2], + gamma: 0.1, + } + +model: + name: 'modules.model_vision.BaseVision' + checkpoint: ~ + 
vision: { + loss_weight: 1., + attention: 'position', + backbone: 'transformer', + backbone_ln: 3, + backbone_alpha_d: 0.5, + } diff --git a/configs/template.yaml b/configs/template.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f1f997737702fd81d1420b2a497c81d9700ec8af --- /dev/null +++ b/configs/template.yaml @@ -0,0 +1,67 @@ +global: + name: exp + phase: train + stage: pretrain-vision + workdir: /tmp/workdir + seed: ~ + +dataset: + train: { + roots: ['data/training/MJ/MJ_train/', + 'data/training/MJ/MJ_test/', + 'data/training/MJ/MJ_valid/', + 'data/training/ST'], + batch_size: 128 + } + test: { + roots: ['data/evaluation/IIIT5k_3000', + 'data/evaluation/SVT', + 'data/evaluation/SVTP', + 'data/evaluation/IC13_857', + 'data/evaluation/IC15_1811', + 'data/evaluation/CUTE80'], + batch_size: 128 + } + charset_path: data/charset_36.txt + num_workers: 4 + max_length: 25 # 30 + image_height: 32 + image_width: 128 + case_sensitive: False + eval_case_sensitive: False + data_aug: True + multiscales: False + pin_memory: True + smooth_label: False + smooth_factor: 0.1 + one_hot_y: True + use_sm: False + +training: + epochs: 6 + show_iters: 50 + eval_iters: 3000 + save_iters: 20000 + start_iters: 0 + stats_iters: 100000 + +optimizer: + type: Adadelta # Adadelta, Adam + true_wd: False + wd: 0. # 0.001 + bn_wd: False + args: { + # betas: !!python/tuple [0.9, 0.99], # betas=(0.9,0.99) for AdamW + # betas: !!python/tuple [0.9, 0.999], # for default Adam + } + clip_grad: 20 + lr: [1.0, 1.0, 1.0] # lr: [0.005, 0.005, 0.005] + scheduler: { + periods: [3, 2, 1], + gamma: 0.1, + } + +model: + name: 'modules.model_iternet.IterNet' + checkpoint: ~ + strict: True \ No newline at end of file diff --git a/configs/train_iternet.yaml b/configs/train_iternet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f6747eb4419ea11977479b9adabf1ae0aeaff472 --- /dev/null +++ b/configs/train_iternet.yaml @@ -0,0 +1,65 @@ +global: + name: train-iternet + phase: train + stage: train-super + workdir: workdir + seed: ~ + +dataset: + train: { + roots: ['output_pixelplanet_dataset/'], + batch_size: 20 + } + test: { + roots: ['output_pixelplanet_dataset/'], + batch_size: 20 + } + data_aug: True + multiscales: False + num_workers: 8 + +training: + epochs: 1000 + show_iters: 500 + eval_iters: 500 + # save_iters: 1 + +optimizer: + type: Adam + true_wd: False + wd: 0.0 + bn_wd: False + clip_grad: 20 + lr: 0.0001 + args: { + betas: !!python/tuple [0.9, 0.999], # for default Adam + } + scheduler: { + periods: [6, 4], + gamma: 0.1, + } + +model: + name: 'modules.model_iternet.IterNet' + iter_size: 3 + ensemble: '' + use_vision: False + vision: { + checkpoint: workdir/train-iternet/best-train-iternet.pth, + loss_weight: 1., + attention: 'position', + backbone: 'transformer', + backbone_ln: 3, + iter_size: 3, + backbone_alpha_d: 0.5, + } + # language: { + # checkpoint: workdir/pretrain-language-model/pretrain-language-model.pth, + # num_layers: 4, + # loss_weight: 1., + # detach: True, + # use_self_attn: False + # } + alignment: { + loss_weight: 1., + } diff --git a/dataset.py b/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e424cb2134ba0d992515b2446302e1a758a3db66 --- /dev/null +++ b/dataset.py @@ -0,0 +1,278 @@ +import logging +import re + +import cv2 +import lmdb +import six +from fastai.vision import * +from torchvision import transforms + +from transforms import CVColorJitter, CVDeterioration, CVGeometry +from utils import CharsetMapper, onehot + + +class 
ImageDataset(Dataset): + "`ImageDataset` read data from LMDB database." + + def __init__(self, + path:PathOrStr, + is_training:bool=True, + img_h:int=32, + img_w:int=100, + max_length:int=25, + check_length:bool=True, + case_sensitive:bool=False, + charset_path:str='data/charset_36.txt', + convert_mode:str='RGB', + data_aug:bool=True, + deteriorate_ratio:float=0., + multiscales:bool=True, + one_hot_y:bool=True, + return_idx:bool=False, + return_raw:bool=False, + **kwargs): + self.path, self.name = Path(path), Path(path).name + assert self.path.is_dir() and self.path.exists(), f"{path} is not a valid directory." + self.convert_mode, self.check_length = convert_mode, check_length + self.img_h, self.img_w = img_h, img_w + self.max_length, self.one_hot_y = max_length, one_hot_y + self.return_idx, self.return_raw = return_idx, return_raw + self.case_sensitive, self.is_training = case_sensitive, is_training + self.data_aug, self.multiscales = data_aug, multiscales + self.charset = CharsetMapper(charset_path, max_length=max_length+1) + self.c = self.charset.num_classes + + self.env = lmdb.open(str(path), readonly=True, lock=False, readahead=False, meminit=False) + assert self.env, f'Cannot open LMDB dataset from {path}.' + with self.env.begin(write=False) as txn: + self.length = int(txn.get('num-samples'.encode())) + + if self.is_training and self.data_aug: + self.augment_tfs = transforms.Compose([ + CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.), shear=(45, 15), distortion=0.5, p=0.5), + CVDeterioration(var=20, degrees=6, factor=4, p=0.25), + CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25) + ]) + self.totensor = transforms.ToTensor() + + def __len__(self): return self.length + + def _next_image(self, index): + next_index = random.randint(0, len(self) - 1) + return self.get(next_index) + + def _check_image(self, x, pixels=6): + if x.size[0] <= pixels or x.size[1] <= pixels: return False + else: return True + + def resize_multiscales(self, img, borderType=cv2.BORDER_CONSTANT): + def _resize_ratio(img, ratio, fix_h=True): + if ratio * self.img_w < self.img_h: + if fix_h: trg_h = self.img_h + else: trg_h = int(ratio * self.img_w) + trg_w = self.img_w + else: trg_h, trg_w = self.img_h, int(self.img_h / ratio) + img = cv2.resize(img, (trg_w, trg_h)) + pad_h, pad_w = (self.img_h - trg_h) / 2, (self.img_w - trg_w) / 2 + top, bottom = math.ceil(pad_h), math.floor(pad_h) + left, right = math.ceil(pad_w), math.floor(pad_w) + img = cv2.copyMakeBorder(img, top, bottom, left, right, borderType) + return img + + if self.is_training: + if random.random() < 0.5: + base, maxh, maxw = self.img_h, self.img_h, self.img_w + h, w = random.randint(base, maxh), random.randint(base, maxw) + return _resize_ratio(img, h/w) + else: return _resize_ratio(img, img.shape[0] / img.shape[1]) # keep aspect ratio + else: return _resize_ratio(img, img.shape[0] / img.shape[1]) # keep aspect ratio + + def resize(self, img): + if self.multiscales: return self.resize_multiscales(img, cv2.BORDER_REPLICATE) + else: return cv2.resize(img, (self.img_w, self.img_h)) + + def get(self, idx): + with self.env.begin(write=False) as txn: + image_key, label_key = f'image-{idx+1:09d}', f'label-{idx+1:09d}' + try: + label = str(txn.get(label_key.encode()), 'utf-8') # label + label = re.sub('[^0-9a-zA-Z]+', '', label) + if self.check_length and self.max_length > 0: + if len(label) > self.max_length or len(label) <= 0: + #logging.info(f'Long or short text image is found: {self.name}, {idx}, {label}, 
{len(label)}') + return self._next_image(idx) + label = label[:self.max_length] + + imgbuf = txn.get(image_key.encode()) # image + buf = six.BytesIO() + buf.write(imgbuf) + buf.seek(0) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) # EXIF warning from TiffPlugin + image = PIL.Image.open(buf).convert(self.convert_mode) + if self.is_training and not self._check_image(image): + #logging.info(f'Invalid image is found: {self.name}, {idx}, {label}, {len(label)}') + return self._next_image(idx) + except: + import traceback + traceback.print_exc() + logging.info(f'Corrupted image is found: {self.name}, {idx}, {label}, {len(label)}') + return self._next_image(idx) + return image, label, idx + + def _process_training(self, image): + if self.data_aug: image = self.augment_tfs(image) + image = self.resize(np.array(image)) + return image + + def _process_test(self, image): + return self.resize(np.array(image)) # TODO:move is_training to here + + def __getitem__(self, idx): + image, text, idx_new = self.get(idx) + if not self.is_training: assert idx == idx_new, f'idx {idx} != idx_new {idx_new} during testing.' + + if self.is_training: image = self._process_training(image) + else: image = self._process_test(image) + if self.return_raw: return image, text + image = self.totensor(image) + + length = tensor(len(text) + 1).to(dtype=torch.long) # one for end token + label = self.charset.get_labels(text, case_sensitive=self.case_sensitive) + label = tensor(label).to(dtype=torch.long) + if self.one_hot_y: label = onehot(label, self.charset.num_classes) + + if self.return_idx: y = [label, length, idx_new] + else: y = [label, length] + return image, y + + +class TextDataset(Dataset): + def __init__(self, + path:PathOrStr, + delimiter:str='\t', + max_length:int=25, + charset_path:str='data/charset_36.txt', + case_sensitive=False, + one_hot_x=True, + one_hot_y=True, + is_training=True, + smooth_label=False, + smooth_factor=0.2, + use_sm=False, + **kwargs): + self.path = Path(path) + self.case_sensitive, self.use_sm = case_sensitive, use_sm + self.smooth_factor, self.smooth_label = smooth_factor, smooth_label + self.charset = CharsetMapper(charset_path, max_length=max_length+1) + self.one_hot_x, self.one_hot_y, self.is_training = one_hot_x, one_hot_y, is_training + if self.is_training and self.use_sm: self.sm = SpellingMutation(charset=self.charset) + + dtype = {'inp': str, 'gt': str} + self.df = pd.read_csv(self.path, dtype=dtype, delimiter=delimiter, na_filter=False) + self.inp_col, self.gt_col = 0, 1 + + def __len__(self): return len(self.df) + + def __getitem__(self, idx): + text_x = self.df.iloc[idx, self.inp_col] + text_x = re.sub('[^0-9a-zA-Z]+', '', text_x) + if not self.case_sensitive: text_x = text_x.lower() + if self.is_training and self.use_sm: text_x = self.sm(text_x) + + length_x = tensor(len(text_x) + 1).to(dtype=torch.long) # one for end token + label_x = self.charset.get_labels(text_x, case_sensitive=self.case_sensitive) + label_x = tensor(label_x) + if self.one_hot_x: + label_x = onehot(label_x, self.charset.num_classes) + if self.is_training and self.smooth_label: + label_x = torch.stack([self.prob_smooth_label(l) for l in label_x]) + x = [label_x, length_x] + + text_y = self.df.iloc[idx, self.gt_col] + text_y = re.sub('[^0-9a-zA-Z]+', '', text_y) + if not self.case_sensitive: text_y = text_y.lower() + length_y = tensor(len(text_y) + 1).to(dtype=torch.long) # one for end token + label_y = self.charset.get_labels(text_y, case_sensitive=self.case_sensitive) + 
label_y = tensor(label_y) + if self.one_hot_y: label_y = onehot(label_y, self.charset.num_classes) + y = [label_y, length_y] + + return x, y + + def prob_smooth_label(self, one_hot): + one_hot = one_hot.float() + delta = torch.rand([]) * self.smooth_factor + num_classes = len(one_hot) + noise = torch.rand(num_classes) + noise = noise / noise.sum() * delta + one_hot = one_hot * (1 - delta) + noise + return one_hot + + +class SpellingMutation(object): + def __init__(self, pn0=0.7, pn1=0.85, pn2=0.95, pt0=0.7, pt1=0.85, charset=None): + """ + Args: + pn0: the prob of not modifying characters is (pn0) + pn1: the prob of modifying one characters is (pn1 - pn0) + pn2: the prob of modifying two characters is (pn2 - pn1), + and three (1 - pn2) + pt0: the prob of replacing operation is pt0. + pt1: the prob of inserting operation is (pt1 - pt0), + and deleting operation is (1 - pt1) + """ + super().__init__() + self.pn0, self.pn1, self.pn2 = pn0, pn1, pn2 + self.pt0, self.pt1 = pt0, pt1 + self.charset = charset + logging.info(f'the probs: pn0={self.pn0}, pn1={self.pn1} ' + + f'pn2={self.pn2}, pt0={self.pt0}, pt1={self.pt1}') + + def is_digit(self, text, ratio=0.5): + length = max(len(text), 1) + digit_num = sum([t in self.charset.digits for t in text]) + if digit_num / length < ratio: return False + return True + + def is_unk_char(self, char): + # return char == self.charset.unk_char + return (char not in self.charset.digits) and (char not in self.charset.alphabets) + + def get_num_to_modify(self, length): + prob = random.random() + if prob < self.pn0: num_to_modify = 0 + elif prob < self.pn1: num_to_modify = 1 + elif prob < self.pn2: num_to_modify = 2 + else: num_to_modify = 3 + + if length <= 1: num_to_modify = 0 + elif length >= 2 and length <= 4: num_to_modify = min(num_to_modify, 1) + else: num_to_modify = min(num_to_modify, length // 2) # smaller than length // 2 + return num_to_modify + + def __call__(self, text, debug=False): + if self.is_digit(text): return text + length = len(text) + num_to_modify = self.get_num_to_modify(length) + if num_to_modify <= 0: return text + + chars = [] + index = np.arange(0, length) + random.shuffle(index) + index = index[: num_to_modify] + if debug: self.index = index + for i, t in enumerate(text): + if i not in index: chars.append(t) + elif self.is_unk_char(t): chars.append(t) + else: + prob = random.random() + if prob < self.pt0: # replace + chars.append(random.choice(self.charset.alphabets)) + elif prob < self.pt1: # insert + chars.append(random.choice(self.charset.alphabets)) + chars.append(t) + else: # delete + continue + new_text = ''.join(chars[: self.charset.max_length-1]) + return new_text if len(new_text) >= 1 else text \ No newline at end of file diff --git a/demo.py b/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..4fe941bac7b3cc17fa4d390f3c878b2698e70b67 --- /dev/null +++ b/demo.py @@ -0,0 +1,109 @@ +import argparse +import logging +import os +import glob +import tqdm +import torch +import PIL +import cv2 +import numpy as np +import torch.nn.functional as F +from torchvision import transforms +from utils import Config, Logger, CharsetMapper + +def get_model(config): + import importlib + names = config.model_name.split('.') + module_name, class_name = '.'.join(names[:-1]), names[-1] + cls = getattr(importlib.import_module(module_name), class_name) + model = cls(config) + logging.info(model) + model = model.eval() + return model + +def preprocess(img, width, height): + img = cv2.resize(np.array(img), (width, height)) + 
img = transforms.ToTensor()(img).unsqueeze(0) + mean = torch.tensor([0.485, 0.456, 0.406]) + std = torch.tensor([0.229, 0.224, 0.225]) + return (img-mean[...,None,None]) / std[...,None,None] + +def postprocess(output, charset, model_eval): + def _get_output(last_output, model_eval): + if isinstance(last_output, (tuple, list)): + for res in last_output: + if res['name'] == model_eval: output = res + else: output = last_output + return output + + def _decode(logit): + """ Greed decode """ + out = F.softmax(logit, dim=2) + pt_text, pt_scores, pt_lengths = [], [], [] + for o in out: + text = charset.get_text(o.argmax(dim=1), padding=False, trim=False) + text = text.split(charset.null_char)[0] # end at end-token + pt_text.append(text) + pt_scores.append(o.max(dim=1)[0]) + pt_lengths.append(min(len(text) + 1, charset.max_length)) # one for end-token + return pt_text, pt_scores, pt_lengths + + output = _get_output(output, model_eval) + logits, pt_lengths = output['logits'], output['pt_lengths'] + pt_text, pt_scores, pt_lengths_ = _decode(logits) + + return pt_text, pt_scores, pt_lengths_ + +def load(model, file, device=None, strict=True): + if device is None: device = 'cpu' + elif isinstance(device, int): device = torch.device('cuda', device) + assert os.path.isfile(file) + state = torch.load(file, map_location=device) + if set(state.keys()) == {'model', 'opt'}: + state = state['model'] + model.load_state_dict(state, strict=strict) + return model + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, default='configs/train_iternet.yaml', + help='path to config file') + parser.add_argument('--input', type=str, default='figures/demo') + parser.add_argument('--cuda', type=int, default=-1) + parser.add_argument('--checkpoint', type=str, default='workdir/train-iternet/best-train-iternet.pth') + parser.add_argument('--model_eval', type=str, default='alignment', + choices=['alignment', 'vision', 'language']) + args = parser.parse_args() + config = Config(args.config) + if args.checkpoint is not None: config.model_checkpoint = args.checkpoint + if args.model_eval is not None: config.model_eval = args.model_eval + config.global_phase = 'test' + config.model_vision_checkpoint, config.model_language_checkpoint = None, None + device = 'cpu' if args.cuda < 0 else f'cuda:{args.cuda}' + + Logger.init(config.global_workdir, config.global_name, config.global_phase) + Logger.enable_file() + logging.info(config) + + logging.info('Construct model.') + model = get_model(config).to(device) + model = load(model, config.model_checkpoint, device=device) + charset = CharsetMapper(filename=config.dataset_charset_path, + max_length=config.dataset_max_length + 1) + + if os.path.isdir(args.input): + paths = [os.path.join(args.input, fname) for fname in os.listdir(args.input)] + else: + paths = glob.glob(os.path.expanduser(args.input)) + assert paths, "The input path(s) was not found" + paths = sorted(paths) + for path in tqdm.tqdm(paths): + img = PIL.Image.open(path).convert('RGB') + img = preprocess(img, config.dataset_image_width, config.dataset_image_height) + img = img.to(device) + res = model(img) + pt_text, _, __ = postprocess(res, charset, config.model_eval) + logging.info(f'{path}: {pt_text[0]}') + +if __name__ == '__main__': + main() diff --git a/figures/demo/096314.png b/figures/demo/096314.png new file mode 100644 index 0000000000000000000000000000000000000000..10a55eab2418ebe44f12a04f5e808614b7d22081 Binary files /dev/null and b/figures/demo/096314.png differ diff --git 
a/figures/demo/096314_.png b/figures/demo/096314_.png new file mode 100644 index 0000000000000000000000000000000000000000..e5222b2f5066ab4cf58754720091ba03dda18921 Binary files /dev/null and b/figures/demo/096314_.png differ diff --git a/figures/demo/096314__.png b/figures/demo/096314__.png new file mode 100644 index 0000000000000000000000000000000000000000..d2f8accccf67135cc40b7ad362234f8b69e8b1eb Binary files /dev/null and b/figures/demo/096314__.png differ diff --git a/figures/demo/096314___.png b/figures/demo/096314___.png new file mode 100644 index 0000000000000000000000000000000000000000..99df3d57a7d6cad846a668ed809f8453a2b2fa50 Binary files /dev/null and b/figures/demo/096314___.png differ diff --git a/figures/demo/327806.png b/figures/demo/327806.png new file mode 100644 index 0000000000000000000000000000000000000000..e815bae919c463e233b81c4fdf5e71736bbac457 Binary files /dev/null and b/figures/demo/327806.png differ diff --git a/figures/demo/327806_.png b/figures/demo/327806_.png new file mode 100644 index 0000000000000000000000000000000000000000..650a9294f9765cbdcdf596622489d69ff76f830f Binary files /dev/null and b/figures/demo/327806_.png differ diff --git a/figures/demo/327806__.png b/figures/demo/327806__.png new file mode 100644 index 0000000000000000000000000000000000000000..9d2d34a2a07d07658ad331cc8dac28cf7c2a9dba Binary files /dev/null and b/figures/demo/327806__.png differ diff --git a/figures/demo/327806___.png b/figures/demo/327806___.png new file mode 100644 index 0000000000000000000000000000000000000000..2acf4a147934a799ade27bc29356f2225167f1e5 Binary files /dev/null and b/figures/demo/327806___.png differ diff --git a/figures/demo/365560__.png b/figures/demo/365560__.png new file mode 100644 index 0000000000000000000000000000000000000000..271eb9bac58b558f86be22179bde444790d50b2c Binary files /dev/null and b/figures/demo/365560__.png differ diff --git a/figures/demo/365560___.png b/figures/demo/365560___.png new file mode 100644 index 0000000000000000000000000000000000000000..4cc49b0e94ab40c1139471d7c84e4872a9149676 Binary files /dev/null and b/figures/demo/365560___.png differ diff --git a/figures/demo/418760__.png b/figures/demo/418760__.png new file mode 100644 index 0000000000000000000000000000000000000000..014191711ae2d57da55d079d72efe2bbda2e18b1 Binary files /dev/null and b/figures/demo/418760__.png differ diff --git a/figures/demo/418760___.png b/figures/demo/418760___.png new file mode 100644 index 0000000000000000000000000000000000000000..14ce56730126b4ce88e7b84219938932f4f0dd05 Binary files /dev/null and b/figures/demo/418760___.png differ diff --git a/figures/demo/444825.png b/figures/demo/444825.png new file mode 100644 index 0000000000000000000000000000000000000000..229bb0cb267d001c9756c46883d9868ebc099fee Binary files /dev/null and b/figures/demo/444825.png differ diff --git a/figures/demo/444825_.png b/figures/demo/444825_.png new file mode 100644 index 0000000000000000000000000000000000000000..b4beadee90864053eda30f0a80ab4922c09d861c Binary files /dev/null and b/figures/demo/444825_.png differ diff --git a/figures/demo/444825__.png b/figures/demo/444825__.png new file mode 100644 index 0000000000000000000000000000000000000000..9d08786bed43af3da60ca24cd8c5781320019c16 Binary files /dev/null and b/figures/demo/444825__.png differ diff --git a/figures/demo/444825___.png b/figures/demo/444825___.png new file mode 100644 index 0000000000000000000000000000000000000000..096e755f7fd8ca1012af7619d82e99dee44dbb7e Binary files /dev/null and 
b/figures/demo/444825___.png differ diff --git a/figures/demo/451010.png b/figures/demo/451010.png new file mode 100644 index 0000000000000000000000000000000000000000..3bbf31f3a0dfbc1d2efaf8ab8de30fa86c332f5d Binary files /dev/null and b/figures/demo/451010.png differ diff --git a/figures/demo/451010_.png b/figures/demo/451010_.png new file mode 100644 index 0000000000000000000000000000000000000000..2133be2d472f10f2bae445cab7c3bdcbea23ee28 Binary files /dev/null and b/figures/demo/451010_.png differ diff --git a/figures/demo/451010__.png b/figures/demo/451010__.png new file mode 100644 index 0000000000000000000000000000000000000000..e7890504af082b56f90659025f9890f8ad82c474 Binary files /dev/null and b/figures/demo/451010__.png differ diff --git a/figures/demo/451010___.png b/figures/demo/451010___.png new file mode 100644 index 0000000000000000000000000000000000000000..5c9bcc4b79bc645fd9921c3e62d017e56bb8270e Binary files /dev/null and b/figures/demo/451010___.png differ diff --git a/figures/demo/502850.png b/figures/demo/502850.png new file mode 100644 index 0000000000000000000000000000000000000000..ebbf94181632dd517a02b2c4708bcb6a36884630 Binary files /dev/null and b/figures/demo/502850.png differ diff --git a/figures/demo/502850_.png b/figures/demo/502850_.png new file mode 100644 index 0000000000000000000000000000000000000000..0370f762e6f19c85ab6f3cfeb1e6844d094764ff Binary files /dev/null and b/figures/demo/502850_.png differ diff --git a/figures/demo/502850__.png b/figures/demo/502850__.png new file mode 100644 index 0000000000000000000000000000000000000000..0812320005f09c4d69571e260b5d9aaf44e08c6e Binary files /dev/null and b/figures/demo/502850__.png differ diff --git a/figures/demo/502850___.png b/figures/demo/502850___.png new file mode 100644 index 0000000000000000000000000000000000000000..58555540f97c0ae3521c3e0d51b2d3d4f134e74a Binary files /dev/null and b/figures/demo/502850___.png differ diff --git a/figures/demo/534995.png b/figures/demo/534995.png new file mode 100644 index 0000000000000000000000000000000000000000..001bd952e9f2658ddabbd122651d7458fc69947b Binary files /dev/null and b/figures/demo/534995.png differ diff --git a/figures/demo/534995_.png b/figures/demo/534995_.png new file mode 100644 index 0000000000000000000000000000000000000000..8ccac172958adb66ca83577096fa88c106a919e9 Binary files /dev/null and b/figures/demo/534995_.png differ diff --git a/figures/demo/534995__.png b/figures/demo/534995__.png new file mode 100644 index 0000000000000000000000000000000000000000..013432051744dd4280d98ebe512cf7a61939960a Binary files /dev/null and b/figures/demo/534995__.png differ diff --git a/figures/demo/534995___.png b/figures/demo/534995___.png new file mode 100644 index 0000000000000000000000000000000000000000..da1527707c86b18db89bf5ac5944a04c4dd38087 Binary files /dev/null and b/figures/demo/534995___.png differ diff --git a/figures/demo/542200__.png b/figures/demo/542200__.png new file mode 100644 index 0000000000000000000000000000000000000000..1580719de948806e01400861b1c0682fb18cfd0c Binary files /dev/null and b/figures/demo/542200__.png differ diff --git a/figures/demo/542200___.png b/figures/demo/542200___.png new file mode 100644 index 0000000000000000000000000000000000000000..4609a78c05e66db019464defe81bfb4f2c0b349e Binary files /dev/null and b/figures/demo/542200___.png differ diff --git a/figures/demo/605564.png b/figures/demo/605564.png new file mode 100644 index 0000000000000000000000000000000000000000..0ba24dc0e03373952fd3cc860e390571c8521dad Binary files 
/dev/null and b/figures/demo/605564.png differ diff --git a/figures/demo/605564_.png b/figures/demo/605564_.png new file mode 100644 index 0000000000000000000000000000000000000000..2dea455b3135428c36fe9b5fcd97141ea2e10230 Binary files /dev/null and b/figures/demo/605564_.png differ diff --git a/figures/demo/605564__.png b/figures/demo/605564__.png new file mode 100644 index 0000000000000000000000000000000000000000..1c3a82d2d15d9a518093dcc1131f138cb385cd88 Binary files /dev/null and b/figures/demo/605564__.png differ diff --git a/figures/demo/605564___.png b/figures/demo/605564___.png new file mode 100644 index 0000000000000000000000000000000000000000..cafe9f4ceccfa2aa1a271581b3093947e78f6cc6 Binary files /dev/null and b/figures/demo/605564___.png differ diff --git a/figures/demo/614224.png b/figures/demo/614224.png new file mode 100644 index 0000000000000000000000000000000000000000..6ea836e0a2953a2c812913fd60d7923e51120b83 Binary files /dev/null and b/figures/demo/614224.png differ diff --git a/figures/demo/614224_.png b/figures/demo/614224_.png new file mode 100644 index 0000000000000000000000000000000000000000..98b31f03341d6d6c338fa664a030a727ef40f332 Binary files /dev/null and b/figures/demo/614224_.png differ diff --git a/figures/demo/614224__.png b/figures/demo/614224__.png new file mode 100644 index 0000000000000000000000000000000000000000..52c1a3e903a31d0c67c6b9eab56a93b65aafbb87 Binary files /dev/null and b/figures/demo/614224__.png differ diff --git a/figures/demo/614224___.png b/figures/demo/614224___.png new file mode 100644 index 0000000000000000000000000000000000000000..715495a61cff044221627c49aca4374be2eefcb1 Binary files /dev/null and b/figures/demo/614224___.png differ diff --git a/figures/demo/684224.png b/figures/demo/684224.png new file mode 100644 index 0000000000000000000000000000000000000000..0358d1e6ba1e214b2a25d8775e4cd48844c4e92f Binary files /dev/null and b/figures/demo/684224.png differ diff --git a/figures/demo/684224_.png b/figures/demo/684224_.png new file mode 100644 index 0000000000000000000000000000000000000000..b336d196899f401029d787f10196796fc0e20ce0 Binary files /dev/null and b/figures/demo/684224_.png differ diff --git a/figures/demo/684224__.png b/figures/demo/684224__.png new file mode 100644 index 0000000000000000000000000000000000000000..089dde10b9856b4690c52c0080c0dbb89d4832eb Binary files /dev/null and b/figures/demo/684224__.png differ diff --git a/figures/demo/769452.png b/figures/demo/769452.png new file mode 100644 index 0000000000000000000000000000000000000000..36aa677a86a29199ec7b319107f487401c07a4d9 Binary files /dev/null and b/figures/demo/769452.png differ diff --git a/figures/demo/769452_.png b/figures/demo/769452_.png new file mode 100644 index 0000000000000000000000000000000000000000..b57319de0864d6d08ca8d1a0f318675dd80ee996 Binary files /dev/null and b/figures/demo/769452_.png differ diff --git a/figures/demo/769452__.png b/figures/demo/769452__.png new file mode 100644 index 0000000000000000000000000000000000000000..422daddb367d8967415390daccb7f2224224b7e0 Binary files /dev/null and b/figures/demo/769452__.png differ diff --git a/figures/demo/769452___.png b/figures/demo/769452___.png new file mode 100644 index 0000000000000000000000000000000000000000..c5d71e44b8f3d1245ad924bc0d838e1b90b83a73 Binary files /dev/null and b/figures/demo/769452___.png differ diff --git a/figures/demo/802650.png b/figures/demo/802650.png new file mode 100644 index 0000000000000000000000000000000000000000..25fb2adc4a0869c3e8e6c8240a2821503ed843f2 Binary 
files /dev/null and b/figures/demo/802650.png differ diff --git a/figures/demo/802650_.png b/figures/demo/802650_.png new file mode 100644 index 0000000000000000000000000000000000000000..e319c2f4e9ef18436163e0964f3ecdf994367b45 Binary files /dev/null and b/figures/demo/802650_.png differ diff --git a/figures/demo/802650__.png b/figures/demo/802650__.png new file mode 100644 index 0000000000000000000000000000000000000000..dbff9e4cd6f223738986cbc9375c6665fd051b0e Binary files /dev/null and b/figures/demo/802650__.png differ diff --git a/figures/demo/802650___.png b/figures/demo/802650___.png new file mode 100644 index 0000000000000000000000000000000000000000..dc4950c64b68a95560b634ca398db7e8948c099a Binary files /dev/null and b/figures/demo/802650___.png differ diff --git a/figures/demo/826827___.png b/figures/demo/826827___.png new file mode 100644 index 0000000000000000000000000000000000000000..70a03788fcaefbb0ce85ba938e4737d3eaf8f311 Binary files /dev/null and b/figures/demo/826827___.png differ diff --git a/figures/demo/909414.png b/figures/demo/909414.png new file mode 100644 index 0000000000000000000000000000000000000000..fe559d37728237e08343729d04035ddd08858ed1 Binary files /dev/null and b/figures/demo/909414.png differ diff --git a/figures/demo/909414_.png b/figures/demo/909414_.png new file mode 100644 index 0000000000000000000000000000000000000000..47981c7c0ecd89740ebbcca240a58424e42e7396 Binary files /dev/null and b/figures/demo/909414_.png differ diff --git a/figures/demo/909414__.png b/figures/demo/909414__.png new file mode 100644 index 0000000000000000000000000000000000000000..36c3e527d8ffa726646bfd2284fbf868555538c9 Binary files /dev/null and b/figures/demo/909414__.png differ diff --git a/figures/demo/909414___.png b/figures/demo/909414___.png new file mode 100644 index 0000000000000000000000000000000000000000..93be23ef724688e3d953e3a6c92271d30989cbd9 Binary files /dev/null and b/figures/demo/909414___.png differ diff --git a/figures/demo/bB5TV5.png b/figures/demo/bB5TV5.png new file mode 100644 index 0000000000000000000000000000000000000000..0e47037edb39112eee33f92915716b411e7f6477 Binary files /dev/null and b/figures/demo/bB5TV5.png differ diff --git a/figures/demo/dxynb4.png b/figures/demo/dxynb4.png new file mode 100644 index 0000000000000000000000000000000000000000..341f7e5ea0a33dd9dd840986d4fb8b1dc882e24f Binary files /dev/null and b/figures/demo/dxynb4.png differ diff --git a/figures/demo/dxynb4_.png b/figures/demo/dxynb4_.png new file mode 100644 index 0000000000000000000000000000000000000000..21c68613f74af5ff37a7d244e78afc9bd80bb3d8 Binary files /dev/null and b/figures/demo/dxynb4_.png differ diff --git a/figures/demo/dxynb4__.png b/figures/demo/dxynb4__.png new file mode 100644 index 0000000000000000000000000000000000000000..eef989f91691c7aa91700cddc5c9d97ba37b7376 Binary files /dev/null and b/figures/demo/dxynb4__.png differ diff --git a/figures/demo/dxynb4___.png b/figures/demo/dxynb4___.png new file mode 100644 index 0000000000000000000000000000000000000000..9a745c631131a641da3a9c111e601b43787914c5 Binary files /dev/null and b/figures/demo/dxynb4___.png differ diff --git a/figures/framework.png b/figures/framework.png new file mode 100644 index 0000000000000000000000000000000000000000..5a953281c912b1343f8d4c6e8e9a2108830fa848 Binary files /dev/null and b/figures/framework.png differ diff --git a/losses.py b/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..eea99b5dc280b2e4719afe0b3bda0b3faf316327 --- /dev/null +++ b/losses.py @@ -0,0 
+from fastai.vision import *
+
+from modules.model import Model
+
+
+class MultiLosses(nn.Module):
+    def __init__(self, one_hot=True):
+        super().__init__()
+        self.ce = SoftCrossEntropyLoss() if one_hot else torch.nn.CrossEntropyLoss()
+        self.bce = torch.nn.BCELoss()
+
+    @property
+    def last_losses(self):
+        return self.losses
+
+    def _flatten(self, sources, lengths):
+        return torch.cat([t[:l] for t, l in zip(sources, lengths)])
+
+    def _merge_list(self, all_res):
+        if not isinstance(all_res, (list, tuple)):
+            return all_res
+        def merge(items):
+            if isinstance(items[0], torch.Tensor): return torch.cat(items, dim=0)
+            else: return items[0]
+        res = dict()
+        for key in all_res[0].keys():
+            items = [r[key] for r in all_res]
+            res[key] = merge(items)
+        return res
+
+    def _ce_loss(self, output, gt_labels, gt_lengths, idx=None, record=True):
+        loss_name = output.get('name')
+        pt_logits, weight = output['logits'], output['loss_weight']
+
+        assert pt_logits.shape[0] % gt_labels.shape[0] == 0
+        iter_size = pt_logits.shape[0] // gt_labels.shape[0]
+        if iter_size > 1:
+            gt_labels = gt_labels.repeat(iter_size, 1, 1)
+            gt_lengths = gt_lengths.repeat(iter_size)
+        flat_gt_labels = self._flatten(gt_labels, gt_lengths)
+        flat_pt_logits = self._flatten(pt_logits, gt_lengths)
+
+        nll = output.get('nll')
+        if nll is not None:
+            loss = self.ce(flat_pt_logits, flat_gt_labels, softmax=False) * weight
+        else:
+            loss = self.ce(flat_pt_logits, flat_gt_labels) * weight
+        if record and loss_name is not None: self.losses[f'{loss_name}_loss'] = loss
+
+        return loss
+
+    def forward(self, outputs, *args):
+        self.losses = {}
+        if isinstance(outputs, (tuple, list)):
+            outputs = [self._merge_list(o) for o in outputs]
+            return sum([self._ce_loss(o, *args) for o in outputs if o['loss_weight'] > 0.])
+        else:
+            return self._ce_loss(outputs, *args, record=False)
+
+
+class SoftCrossEntropyLoss(nn.Module):
+    def __init__(self, reduction="mean"):
+        super().__init__()
+        self.reduction = reduction
+
+    def forward(self, input, target, softmax=True):
+        if softmax: log_prob = F.log_softmax(input, dim=-1)
+        else: log_prob = torch.log(input)
+        loss = -(target * log_prob).sum(dim=-1)
+        if self.reduction == "mean": return loss.mean()
+        elif self.reduction == "sum": return loss.sum()
+        else: return loss
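
`MultiLosses` is the training criterion wired up in `main.py` below: each model output is a dict carrying `logits` and `loss_weight` (plus an optional `name` and `nll` flag), labels are one-hot tensors, and the lengths give the number of valid characters per sample. A minimal sketch of driving it in isolation, run from the repository root; the batch size, sequence length, and charset size of 37 are illustrative assumptions, not values taken from the configs:

```
import torch
from losses import MultiLosses

criterion = MultiLosses(one_hot=True)
B, T, C = 4, 26, 37                            # batch, max length, charset size (assumed)
logits = torch.randn(B, T, C)                  # raw logits; SoftCrossEntropyLoss applies log_softmax itself
labels = torch.nn.functional.one_hot(torch.randint(0, C, (B, T)), C).float()
lengths = torch.randint(1, T + 1, (B,))        # valid characters per sample
output = {'name': 'vision', 'logits': logits, 'loss_weight': 1.0}
loss = criterion(output, labels, lengths)      # a single dict goes through the un-recorded branch
print(loss, criterion.last_losses)             # last_losses is only populated when a list of outputs is passed
```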
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..def41731ed4cbe77051e496caf2b2d37dd95611f
--- /dev/null
+++ b/main.py
@@ -0,0 +1,246 @@
+import argparse
+import logging
+import os
+import random
+
+import torch
+from fastai.callbacks.general_sched import GeneralScheduler, TrainingPhase
+from fastai.distributed import *
+from fastai.vision import *
+from torch.backends import cudnn
+
+from callbacks import DumpPrediction, IterationCallback, TextAccuracy, TopKTextAccuracy
+from dataset import ImageDataset, TextDataset
+from losses import MultiLosses
+from utils import Config, Logger, MyDataParallel, MyConcatDataset
+
+
+def _set_random_seed(seed):
+    if seed is not None:
+        random.seed(seed)
+        torch.manual_seed(seed)
+        cudnn.deterministic = True
+        logging.warning('You have chosen to seed training. '
+                        'This will slow down your training!')
+
+def _get_training_phases(config, n):
+    lr = np.array(config.optimizer_lr)
+    periods = config.optimizer_scheduler_periods
+    sigma = [config.optimizer_scheduler_gamma ** i for i in range(len(periods))]
+    phases = [TrainingPhase(n * periods[i]).schedule_hp('lr', lr * sigma[i])
+              for i in range(len(periods))]
+    return phases
+
+def _get_dataset(ds_type, paths, is_training, config, **kwargs):
+    kwargs.update({
+        'img_h': config.dataset_image_height,
+        'img_w': config.dataset_image_width,
+        'max_length': config.dataset_max_length,
+        'case_sensitive': config.dataset_case_sensitive,
+        'charset_path': config.dataset_charset_path,
+        'data_aug': config.dataset_data_aug,
+        'deteriorate_ratio': config.dataset_deteriorate_ratio,
+        'is_training': is_training,
+        'multiscales': config.dataset_multiscales,
+        'one_hot_y': config.dataset_one_hot_y,
+    })
+    datasets = [ds_type(p, **kwargs) for p in paths]
+    if len(datasets) > 1: return MyConcatDataset(datasets)
+    else: return datasets[0]
+
+
+def _get_language_databaunch(config):
+    kwargs = {
+        'max_length': config.dataset_max_length,
+        'case_sensitive': config.dataset_case_sensitive,
+        'charset_path': config.dataset_charset_path,
+        'smooth_label': config.dataset_smooth_label,
+        'smooth_factor': config.dataset_smooth_factor,
+        'one_hot_y': config.dataset_one_hot_y,
+        'use_sm': config.dataset_use_sm,
+    }
+    train_ds = TextDataset(config.dataset_train_roots[0], is_training=True, **kwargs)
+    valid_ds = TextDataset(config.dataset_test_roots[0], is_training=False, **kwargs)
+    data = DataBunch.create(
+        path=train_ds.path,
+        train_ds=train_ds,
+        valid_ds=valid_ds,
+        bs=config.dataset_train_batch_size,
+        val_bs=config.dataset_test_batch_size,
+        num_workers=config.dataset_num_workers,
+        pin_memory=config.dataset_pin_memory)
+    logging.info(f'{len(data.train_ds)} training items found.')
+    if not data.empty_val:
+        logging.info(f'{len(data.valid_ds)} valid items found.')
+    return data
+
+def _get_databaunch(config):
+    # An awkward way to reduce loadding data time during test
+    if config.global_phase == 'test': config.dataset_train_roots = config.dataset_test_roots
+    train_ds = _get_dataset(ImageDataset, config.dataset_train_roots, True, config)
+    valid_ds = _get_dataset(ImageDataset, config.dataset_test_roots, False, config)
+    data = ImageDataBunch.create(
+        train_ds=train_ds,
+        valid_ds=valid_ds,
+        bs=config.dataset_train_batch_size,
+        val_bs=config.dataset_test_batch_size,
+        num_workers=config.dataset_num_workers,
+        pin_memory=config.dataset_pin_memory).normalize(imagenet_stats)
+    ar_tfm = lambda x: ((x[0], x[1]), x[1])  # auto-regression only for dtd
+    data.add_tfm(ar_tfm)
+
+    logging.info(f'{len(data.train_ds)} training items found.')
+    if not data.empty_val:
+        logging.info(f'{len(data.valid_ds)} valid items found.')
+
+    return data
+
+def _get_model(config):
+    import importlib
+    names = config.model_name.split('.')
+    module_name, class_name = '.'.join(names[:-1]), names[-1]
+    cls = getattr(importlib.import_module(module_name), class_name)
+    model = cls(config)
+    logging.info(model)
+    return model
+
+
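
`_get_model` resolves the model class dynamically from the dotted path in `config.model_name`, so this single entry point can instantiate any architecture shipped under `modules/`. Roughly, the lookup amounts to the following; the concrete class path shown is only an assumed example, in practice it comes from the YAML config:

```
import importlib

model_name = 'modules.model_iternet.IterNet'   # assumed example value; read from the config in practice
module_name, class_name = model_name.rsplit('.', 1)
cls = getattr(importlib.import_module(module_name), class_name)
# model = cls(config)                          # instantiated with the parsed Config object
```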
+def _get_learner(config, data, model, local_rank=None):
+    strict = ifnone(config.model_strict, True)
+    if config.global_stage == 'pretrain-language':
+        metrics = [TopKTextAccuracy(
+            k=ifnone(config.model_k, 5),
+            charset_path=config.dataset_charset_path,
+            max_length=config.dataset_max_length + 1,
+            case_sensitive=config.dataset_eval_case_sensisitves,
+            model_eval=config.model_eval)]
+    else:
+        metrics = [TextAccuracy(
+            charset_path=config.dataset_charset_path,
+            max_length=config.dataset_max_length + 1,
+            case_sensitive=config.dataset_eval_case_sensisitves,
+            model_eval=config.model_eval)]
+    opt_type = getattr(torch.optim, config.optimizer_type)
+    learner = Learner(data, model, silent=True, model_dir='.',
+                      true_wd=config.optimizer_true_wd,
+                      wd=config.optimizer_wd,
+                      bn_wd=config.optimizer_bn_wd,
+                      path=config.global_workdir,
+                      metrics=metrics,
+                      opt_func=partial(opt_type, **config.optimizer_args or dict()),
+                      loss_func=MultiLosses(one_hot=config.dataset_one_hot_y))
+    learner.split(lambda m: children(m))
+
+    if config.global_phase == 'train':
+        num_replicas = 1 if local_rank is None else torch.distributed.get_world_size()
+        phases = _get_training_phases(config, len(learner.data.train_dl)//num_replicas)
+        learner.callback_fns += [
+            partial(GeneralScheduler, phases=phases),
+            partial(GradientClipping, clip=config.optimizer_clip_grad),
+            partial(IterationCallback, name=config.global_name,
+                    show_iters=config.training_show_iters,
+                    eval_iters=config.training_eval_iters,
+                    save_iters=config.training_save_iters,
+                    start_iters=config.training_start_iters,
+                    stats_iters=config.training_stats_iters)]
+    else:
+        learner.callbacks += [
+            DumpPrediction(learn=learner,
+                           dataset='-'.join([Path(p).name for p in config.dataset_test_roots]),
+                           charset_path=config.dataset_charset_path,
+                           model_eval=config.model_eval,
+                           debug=config.global_debug,
+                           image_only=config.global_image_only)]
+
+    learner.rank = local_rank
+    if local_rank is not None:
+        logging.info(f'Set model to distributed with rank {local_rank}.')
+        learner.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(learner.model)
+        learner.model.to(local_rank)
+        learner = learner.to_distributed(local_rank)
+
+    if torch.cuda.device_count() > 1 and local_rank is None:
+        logging.info(f'Use {torch.cuda.device_count()} GPUs.')
+        learner.model = MyDataParallel(learner.model)
+
+    if config.model_checkpoint:
+        if Path(config.model_checkpoint).exists():
+            with open(config.model_checkpoint, 'rb') as f:
+                buffer = io.BytesIO(f.read())
+            learner.load(buffer, strict=strict)
+        else:
+            from distutils.dir_util import copy_tree
+            src = Path('/data/fangsc/model')/config.global_name
+            trg = Path('/output')/config.global_name
+            if src.exists(): copy_tree(str(src), str(trg))
+            learner.load(config.model_checkpoint, strict=strict)
+        logging.info(f'Read model from {config.model_checkpoint}')
+    elif config.global_phase == 'test':
+        learner.load(f'best-{config.global_name}', strict=strict)
+        logging.info(f'Read model from best-{config.global_name}')
+
+    if learner.opt_func.func.__name__ == 'Adadelta': # fastai bug, fix after 1.0.60
+        learner.fit(epochs=0, lr=config.optimizer_lr)
+        learner.opt.mom = 0.
+
+    return learner
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--config', type=str, required=True,
+                        help='path to config file')
+    parser.add_argument('--phase', type=str, default=None, choices=['train', 'test'])
+    parser.add_argument('--name', type=str, default=None)
+    parser.add_argument('--checkpoint', type=str, default=None)
+    parser.add_argument('--test_root', type=str, default=None)
+    parser.add_argument("--local_rank", type=int, default=None)
+    parser.add_argument('--debug', action='store_true', default=None)
+    parser.add_argument('--image_only', action='store_true', default=None)
+    parser.add_argument('--model_strict', action='store_false', default=None)
+    parser.add_argument('--model_eval', type=str, default=None,
+                        choices=['alignment', 'vision', 'language'])
+    args = parser.parse_args()
+    config = Config(args.config)
+    if args.name is not None: config.global_name = args.name
+    if args.phase is not None: config.global_phase = args.phase
+    if args.test_root is not None: config.dataset_test_roots = [args.test_root]
+    if args.checkpoint is not None: config.model_checkpoint = args.checkpoint
+    if args.debug is not None: config.global_debug = args.debug
+    if args.image_only is not None: config.global_image_only = args.image_only
+    if args.model_eval is not None: config.model_eval = args.model_eval
+    if args.model_strict is not None: config.model_strict = args.model_strict
+
+    Logger.init(config.global_workdir, config.global_name, config.global_phase)
+    Logger.enable_file()
+    _set_random_seed(config.global_seed)
+    logging.info(config)
+
+    if args.local_rank is not None:
+        logging.info(f'Init distribution training at device {args.local_rank}.')
+        torch.cuda.set_device(args.local_rank)
+        torch.distributed.init_process_group(backend='nccl', init_method='env://')
+
+    logging.info('Construct dataset.')
+    if config.global_stage == 'pretrain-language': data = _get_language_databaunch(config)
+    else: data = _get_databaunch(config)
+
+    logging.info('Construct model.')
+    model = _get_model(config)
+
+    logging.info('Construct learner.')
+    learner = _get_learner(config, data, model, args.local_rank)
+
+    if config.global_phase == 'train':
+        logging.info('Start training.')
+        learner.fit(epochs=config.training_epochs,
+                    lr=config.optimizer_lr)
+    else:
+        logging.info('Start validate')
+        last_metrics = learner.validate()
+        log_str = f'eval loss = {last_metrics[0]:6.3f}, ' \
+                  f'ccr = {last_metrics[1]:6.3f}, cwr = {last_metrics[2]:6.3f}, ' \
+                  f'ted = {last_metrics[3]:6.3f}, ned = {last_metrics[4]:6.0f}, ' \
+                  f'ted/w = {last_metrics[5]:6.3f}, '
+        logging.info(log_str)
+
+if __name__ == '__main__':
+    main()
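
`main.py` is the single entry point for both phases; each flag above overrides the corresponding field of the YAML config. Typical invocations look like the following (the config and checkpoint paths are placeholders, substitute the files you actually use):

```
# single-GPU training
python main.py --config configs/train_iternet.yaml --phase train

# evaluate a trained checkpoint with the fused (alignment) prediction
python main.py --config configs/train_iternet.yaml --phase test \
    --checkpoint workdir/train-iternet/best-train-iternet.pth --model_eval alignment

# multi-GPU training via torch.distributed.launch, which supplies --local_rank
python -m torch.distributed.launch --nproc_per_node=4 main.py --config configs/train_iternet.yaml
```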
diff --git a/transforms.py b/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a7042f3368bc832566d5c22d1e18abe5d8547f5
--- /dev/null
+++ b/transforms.py
@@ -0,0 +1,329 @@
+import math
+import numbers
+import random
+
+import cv2
+import numpy as np
+from PIL import Image
+from torchvision import transforms
+from torchvision.transforms import Compose
+
+
+def sample_asym(magnitude, size=None):
+    return np.random.beta(1, 4, size) * magnitude
+
+def sample_sym(magnitude, size=None):
+    return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude
+
+def sample_uniform(low, high, size=None):
+    return np.random.uniform(low, high, size=size)
+
+def get_interpolation(type='random'):
+    if type == 'random':
+        choice = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA]
+        interpolation = choice[random.randint(0, len(choice)-1)]
+    elif type == 'nearest': interpolation = cv2.INTER_NEAREST
+    elif type == 'linear': interpolation = cv2.INTER_LINEAR
+    elif type == 'cubic': interpolation = cv2.INTER_CUBIC
+    elif type == 'area': interpolation = cv2.INTER_AREA
+    else: raise TypeError('Interpolation types only nearest, linear, cubic, area are supported!')
+    return interpolation
+
+class CVRandomRotation(object):
+    def __init__(self, degrees=15):
+        assert isinstance(degrees, numbers.Number), "degree should be a single number."
+        assert degrees >= 0, "degree must be positive."
+        self.degrees = degrees
+
+    @staticmethod
+    def get_params(degrees):
+        return sample_sym(degrees)
+
+    def __call__(self, img):
+        angle = self.get_params(self.degrees)
+        src_h, src_w = img.shape[:2]
+        M = cv2.getRotationMatrix2D(center=(src_w/2, src_h/2), angle=angle, scale=1.0)
+        abs_cos, abs_sin = abs(M[0,0]), abs(M[0,1])
+        dst_w = int(src_h * abs_sin + src_w * abs_cos)
+        dst_h = int(src_h * abs_cos + src_w * abs_sin)
+        M[0, 2] += (dst_w - src_w)/2
+        M[1, 2] += (dst_h - src_h)/2
+
+        flags = get_interpolation()
+        return cv2.warpAffine(img, M, (dst_w, dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE)
+
+class CVRandomAffine(object):
+    def __init__(self, degrees, translate=None, scale=None, shear=None):
+        assert isinstance(degrees, numbers.Number), "degree should be a single number."
+        assert degrees >= 0, "degree must be positive."
+        self.degrees = degrees
+
+        if translate is not None:
+            assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
+                "translate should be a list or tuple and it must be of length 2."
+            for t in translate:
+                if not (0.0 <= t <= 1.0):
+                    raise ValueError("translation values should be between 0 and 1")
+        self.translate = translate
+
+        if scale is not None:
+            assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
+                "scale should be a list or tuple and it must be of length 2."
+            for s in scale:
+                if s <= 0:
+                    raise ValueError("scale values should be positive")
+        self.scale = scale
+
+        if shear is not None:
+            if isinstance(shear, numbers.Number):
+                if shear < 0:
+                    raise ValueError("If shear is a single number, it must be positive.")
+                self.shear = [shear]
+            else:
+                assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \
+                    "shear should be a list or tuple and it must be of length 2."
+                self.shear = shear
+        else:
+            self.shear = shear
+
+    def _get_inverse_affine_matrix(self, center, angle, translate, scale, shear):
+        # https://github.com/pytorch/vision/blob/v0.4.0/torchvision/transforms/functional.py#L717
+        from numpy import sin, cos, tan
+
+        if isinstance(shear, numbers.Number):
+            shear = [shear, 0]
+
+        if not isinstance(shear, (tuple, list)) and len(shear) == 2:
+            raise ValueError(
+                "Shear should be a single value or a tuple/list containing " +
Got {}".format(shear)) + + rot = math.radians(angle) + sx, sy = [math.radians(s) for s in shear] + + cx, cy = center + tx, ty = translate + + # RSS without scaling + a = cos(rot - sy) / cos(sy) + b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot) + c = sin(rot - sy) / cos(sy) + d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot) + + # Inverted rotation matrix with scale and shear + # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 + M = [d, -b, 0, + -c, a, 0] + M = [x / scale for x in M] + + # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 + M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty) + M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty) + + # Apply center translation: C * RSS^-1 * C^-1 * T^-1 + M[2] += cx + M[5] += cy + return M + + @staticmethod + def get_params(degrees, translate, scale_ranges, shears, height): + angle = sample_sym(degrees) + if translate is not None: + max_dx = translate[0] * height + max_dy = translate[1] * height + translations = (np.round(sample_sym(max_dx)), np.round(sample_sym(max_dy))) + else: + translations = (0, 0) + + if scale_ranges is not None: + scale = sample_uniform(scale_ranges[0], scale_ranges[1]) + else: + scale = 1.0 + + if shears is not None: + if len(shears) == 1: + shear = [sample_sym(shears[0]), 0.] + elif len(shears) == 2: + shear = [sample_sym(shears[0]), sample_sym(shears[1])] + else: + shear = 0.0 + + return angle, translations, scale, shear + + + def __call__(self, img): + src_h, src_w = img.shape[:2] + angle, translate, scale, shear = self.get_params( + self.degrees, self.translate, self.scale, self.shear, src_h) + + M = self._get_inverse_affine_matrix((src_w/2, src_h/2), angle, (0, 0), scale, shear) + M = np.array(M).reshape(2,3) + + startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1), (0, src_h - 1)] + project = lambda x, y, a, b, c: int(a*x + b*y + c) + endpoints = [(project(x, y, *M[0]), project(x, y, *M[1])) for x, y in startpoints] + + rect = cv2.minAreaRect(np.array(endpoints)) + bbox = cv2.boxPoints(rect).astype(dtype=np.int) + max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max() + min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min() + + dst_w = int(max_x - min_x) + dst_h = int(max_y - min_y) + M[0, 2] += (dst_w - src_w) / 2 + M[1, 2] += (dst_h - src_h) / 2 + + # add translate + dst_w += int(abs(translate[0])) + dst_h += int(abs(translate[1])) + if translate[0] < 0: M[0, 2] += abs(translate[0]) + if translate[1] < 0: M[1, 2] += abs(translate[1]) + + flags = get_interpolation() + return cv2.warpAffine(img, M, (dst_w , dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE) + +class CVRandomPerspective(object): + def __init__(self, distortion=0.5): + self.distortion = distortion + + def get_params(self, width, height, distortion): + offset_h = sample_asym(distortion * height / 2, size=4).astype(dtype=np.int) + offset_w = sample_asym(distortion * width / 2, size=4).astype(dtype=np.int) + topleft = ( offset_w[0], offset_h[0]) + topright = (width - 1 - offset_w[1], offset_h[1]) + botright = (width - 1 - offset_w[2], height - 1 - offset_h[2]) + botleft = ( offset_w[3], height - 1 - offset_h[3]) + + startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)] + endpoints = [topleft, topright, botright, botleft] + return np.array(startpoints, dtype=np.float32), np.array(endpoints, dtype=np.float32) + + def __call__(self, img): + height, width = img.shape[:2] + startpoints, endpoints = self.get_params(width, height, self.distortion) + M = 
+        M = cv2.getPerspectiveTransform(startpoints, endpoints)
+
+        # TODO: more robust way to crop image
+        rect = cv2.minAreaRect(endpoints)
+        bbox = cv2.boxPoints(rect).astype(dtype=np.int)
+        max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
+        min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
+        min_x, min_y = max(min_x, 0), max(min_y, 0)
+
+        flags = get_interpolation()
+        img = cv2.warpPerspective(img, M, (max_x, max_y), flags=flags, borderMode=cv2.BORDER_REPLICATE)
+        img = img[min_y:, min_x:]
+        return img
+
+class CVRescale(object):
+
+    def __init__(self, factor=4, base_size=(128, 512)):
+        """ Define image scales using gaussian pyramid and rescale image to target scale.
+
+        Args:
+            factor: the decayed factor from base size, factor=4 keeps target scale by default.
+            base_size: base size the build the bottom layer of pyramid
+        """
+        if isinstance(factor, numbers.Number):
+            self.factor = round(sample_uniform(0, factor))
+        elif isinstance(factor, (tuple, list)) and len(factor) == 2:
+            self.factor = round(sample_uniform(factor[0], factor[1]))
+        else:
+            raise Exception('factor must be number or list with length 2')
+        # assert factor is valid
+        self.base_h, self.base_w = base_size[:2]
+
+    def __call__(self, img):
+        if self.factor == 0: return img
+        src_h, src_w = img.shape[:2]
+        cur_w, cur_h = self.base_w, self.base_h
+        scale_img = cv2.resize(img, (cur_w, cur_h), interpolation=get_interpolation())
+        for _ in range(self.factor):
+            scale_img = cv2.pyrDown(scale_img)
+        scale_img = cv2.resize(scale_img, (src_w, src_h), interpolation=get_interpolation())
+        return scale_img
+
+class CVGaussianNoise(object):
+    def __init__(self, mean=0, var=20):
+        self.mean = mean
+        if isinstance(var, numbers.Number):
+            self.var = max(int(sample_asym(var)), 1)
+        elif isinstance(var, (tuple, list)) and len(var) == 2:
+            self.var = int(sample_uniform(var[0], var[1]))
+        else:
+            raise Exception('degree must be number or list with length 2')
+
+    def __call__(self, img):
+        noise = np.random.normal(self.mean, self.var**0.5, img.shape)
+        img = np.clip(img + noise, 0, 255).astype(np.uint8)
+        return img
+
+class CVMotionBlur(object):
+    def __init__(self, degrees=12, angle=90):
+        if isinstance(degrees, numbers.Number):
+            self.degree = max(int(sample_asym(degrees)), 1)
+        elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
+            self.degree = int(sample_uniform(degrees[0], degrees[1]))
+        else:
+            raise Exception('degree must be number or list with length 2')
+        self.angle = sample_uniform(-angle, angle)
+
+    def __call__(self, img):
+        M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2), self.angle, 1)
+        motion_blur_kernel = np.zeros((self.degree, self.degree))
+        motion_blur_kernel[self.degree // 2, :] = 1
+        motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (self.degree, self.degree))
+        motion_blur_kernel = motion_blur_kernel / self.degree
+        img = cv2.filter2D(img, -1, motion_blur_kernel)
+        img = np.clip(img, 0, 255).astype(np.uint8)
+        return img
+
+class CVGeometry(object):
+    def __init__(self, degrees=15, translate=(0.3, 0.3), scale=(0.5, 2.),
+                 shear=(45, 15), distortion=0.5, p=0.5):
+        self.p = p
+        type_p = random.random()
+        if type_p < 0.33:
+            self.transforms = CVRandomRotation(degrees=degrees)
+        elif type_p < 0.66:
+            self.transforms = CVRandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear)
+        else:
+            self.transforms = CVRandomPerspective(distortion=distortion)
+
+    def __call__(self, img):
+        if random.random() < self.p:
+            img = np.array(img)
+            return Image.fromarray(self.transforms(img))
+        else: return img
+
+class CVDeterioration(object):
+    def __init__(self, var, degrees, factor, p=0.5):
+        self.p = p
+        transforms = []
+        if var is not None:
+            transforms.append(CVGaussianNoise(var=var))
+        if degrees is not None:
+            transforms.append(CVMotionBlur(degrees=degrees))
+        if factor is not None:
+            transforms.append(CVRescale(factor=factor))
+
+        random.shuffle(transforms)
+        transforms = Compose(transforms)
+        self.transforms = transforms
+
+    def __call__(self, img):
+        if random.random() < self.p:
+            img = np.array(img)
+            return Image.fromarray(self.transforms(img))
+        else: return img
+
+
+class CVColorJitter(object):
+    def __init__(self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.5):
+        self.p = p
+        self.transforms = transforms.ColorJitter(brightness=brightness, contrast=contrast,
+                                                 saturation=saturation, hue=hue)
+
+    def __call__(self, img):
+        if random.random() < self.p: return self.transforms(img)
+        else: return img
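
`CVGeometry`, `CVDeterioration`, and `CVColorJitter` are the wrappers the dataset code chains for training-time augmentation when `dataset_data_aug` is enabled; each one fires with probability `p` and expects a PIL image. A minimal sketch of such a pipeline, run with the pinned requirements (the transforms still rely on the deprecated `np.int` alias); the parameter values are illustrative, the ones actually used during training come from `dataset.py`, not from this file:

```
from PIL import Image
from torchvision.transforms import Compose
from transforms import CVColorJitter, CVDeterioration, CVGeometry

augment = Compose([
    CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.0), shear=(45, 15), distortion=0.5, p=0.5),
    CVDeterioration(var=20, degrees=6, factor=4, p=0.25),
    CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25),
])

img = augment(Image.open('figures/demo/909414.png').convert('RGB'))
```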
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b7b5db1bc1dd191191c31b3e72228ccd1c4f7a1
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,304 @@
+import logging
+import os
+import time
+
+import cv2
+import numpy as np
+import torch
+import yaml
+from matplotlib import colors
+from matplotlib import pyplot as plt
+from torch import Tensor, nn
+from torch.utils.data import ConcatDataset
+
+class CharsetMapper(object):
+    """A simple class to map ids into strings.
+
+    It works only when the character set is 1:1 mapping between individual
+    characters and individual ids.
+    """
+
+    def __init__(self,
+                 filename='',
+                 max_length=30,
+                 null_char=u'\u2591'):
+        """Creates a lookup table.
+
+        Args:
+            filename: Path to charset file which maps characters to ids.
+            max_sequence_length: The max length of ids and string.
+            null_char: A unicode character used to replace '