diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..b4142b465b0247a43021d923f87e69b7d262c4c6
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,36 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.json filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..15c628e670d8824ca2c557bda6edc885c77a34b1
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,26 @@
+*.pyc
+*__pycache__*
+*core.*
+_ext
+tmp
+*.o*
+*~
+*.idea
+*.mp4
+*.avi
+*.h5
+*.pt
+*.egg-info
+
+# Build
+/build
+/dist
+
+# Virtual environment
+**/venv
+
+# Dataset images
+/spiga/data/databases
+venv/
+flagged
+assets/
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9a3b249e02e555ca4470135003163d1212eb2939
--- /dev/null
+++ b/README.md
@@ -0,0 +1,12 @@
+---
+title: SPIGA Face Alignment Headpose Estimator
+emoji: π
+colorFrom: pink
+colorTo: green
+sdk: gradio
+sdk_version: 3.22.1
+app_file: app.py
+pinned: false
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/SPIGA/.gitignore b/SPIGA/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..8746724fe1de020bf5068d4f2f0d18cab2b9660e
--- /dev/null
+++ b/SPIGA/.gitignore
@@ -0,0 +1,23 @@
+*.pyc
+*__pycache__*
+*core.*
+_ext
+tmp
+*.o*
+*~
+*.idea
+*.mp4
+*.avi
+*.h5
+*.pt
+*.egg-info
+
+# Build
+/build
+/dist
+
+# Virtual environment
+**/venv
+
+# Dataset images
+/spiga/data/databases
diff --git a/SPIGA/LICENSE b/SPIGA/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..d06b4d08ab920bff3c8e3bb8693cc4ba3ab03075
--- /dev/null
+++ b/SPIGA/LICENSE
@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) 2022, aprados
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/SPIGA/MANIFEST.in b/SPIGA/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..fb5d16ae38bbce073e0bf8fc6a722fff99198128
--- /dev/null
+++ b/SPIGA/MANIFEST.in
@@ -0,0 +1,5 @@
+include spiga/data/annotations/**/db_info.json
+include spiga/data/models3D/*.txt
+include spiga/data/readme.md
+include spiga/eval/benchmark/readme.md
+
diff --git a/SPIGA/README.md b/SPIGA/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1d3c7b7b94ff8293ef4de605af25c437303ee65a
--- /dev/null
+++ b/SPIGA/README.md
@@ -0,0 +1,217 @@
+# SPIGA: Shape Preserving Facial Landmarks with Graph Attention Networks.
+
+[](https://bmvc2022.mpi-inf.mpg.de/155/)
+[](https://arxiv.org/abs/2210.07233)
+[](https://badge.fury.io/py/spiga)
+[](LICENSE)
+[](https://colab.research.google.com/github/andresprados/SPIGA/blob/main/colab_tutorials/video_demo.ipynb)
+
+This repository contains the source code of **SPIGA, a face alignment and headpose estimator** that takes advantage of the complementary benefits from CNN and GNN architectures producing plausible face shapes in presence of strong appearance changes.
+
+
+
+
+
+**It achieves top-performing results in:**
+
+[](https://paperswithcode.com/sota/pose-estimation-on-300w-full?p=shape-preserving-facial-landmarks-with-graph)
+[](https://paperswithcode.com/sota/head-pose-estimation-on-wflw?p=shape-preserving-facial-landmarks-with-graph)
+[](https://paperswithcode.com/sota/pose-estimation-on-merl-rav?p=shape-preserving-facial-landmarks-with-graph)
+[](https://paperswithcode.com/sota/face-alignment-on-merl-rav?p=shape-preserving-facial-landmarks-with-graph)
+[](https://paperswithcode.com/sota/face-alignment-on-wflw?p=shape-preserving-facial-landmarks-with-graph)
+[](https://paperswithcode.com/sota/face-alignment-on-300w-split-2?p=shape-preserving-facial-landmarks-with-graph)
+[](https://paperswithcode.com/sota/face-alignment-on-cofw-68?p=shape-preserving-facial-landmarks-with-graph)
+[](https://paperswithcode.com/sota/face-alignment-on-300w?p=shape-preserving-facial-landmarks-with-graph)
+
+
+## Setup
+The repository has been tested on Ubuntu 20.04 with CUDA 11.4, the latest version of cuDNN, Python 3.8 and Pytorch 1.12.1.
+To run the video analyzer demo or evaluate the algorithm, install the repository from the source code:
+
+```
+# Best practices:
+# 1. Create a virtual environment.
+# 2. Install Pytorch according to your CUDA version.
+# 3. Install SPIGA from source code:
+
+git clone https://github.com/andresprados/SPIGA.git
+cd SPIGA
+pip install -e .
+
+# To run the video analyzer demo install the extra requirements.
+pip install -e .[demo]
+```
+**Models:** By default, model weights are automatically downloaded on demand and stored at ```./spiga/models/weights/```.
+You can also download them from [Google Drive](https://drive.google.com/drive/folders/1olrkoiDNK_NUCscaG9BbO3qsussbDi7I?usp=sharing).
+
+***Note:*** All the callable scripts provide a detailed argument parser that describes the behaviour of the program and its inputs. Check the available operational modes by passing the ```--help``` flag.
+
+## Inference and Demo
+We provide an inference framework for SPIGA available at ```./spiga/inference```. The models can be easily deployed
+in third-party projects by adding a few lines of code. Check out our inference and application tutorials
+for more information:
+
+
+
+Tutorials | Notebook |
+:---: | :---: |
+Image Inference Example | [](https://colab.research.google.com/github/andresprados/SPIGA/blob/main/colab_tutorials/image_demo.ipynb) |
+Face Video Analyzer Demo | [](https://colab.research.google.com/github/andresprados/SPIGA/blob/main/colab_tutorials/video_demo.ipynb) |
+
+
+
+### Face Video Analyzer Demo:
+The demo application provides a general framework for tracking, detecting and extracting features of human faces in images or videos.
+You can use the following commands to run the demo:
+
+```
+python ./spiga/demo/app.py \
+ [--input] \ # Webcam ID or Video Path. Dft: Webcam '0'.
+ [--dataset] \ # SPIGA pretrained weights per dataset. Dft: 'wflw'.
+ [--tracker] \ # Tracker name. Dft: 'RetinaSort'.
+ [--show] \ # Select the attributes of the face to be displayed. Dft: ['fps', 'face_id', 'landmarks', 'headpose']
+ [--save] \ # Save record.
+ [--noview] \ # Do not visualize window.
+ [--outpath] \ # Recorded output directory. Dft: './spiga/demo/outputs'
+ [--fps] \ # Frames per second.
+ [--shape] \ # Visualizer shape (W,H).
+```
+
+
+
+
+
+
+
+
+
+
+
+
+***Note:*** For more information check the [Demo Readme](spiga/demo/readme.md) or call the app parser ```--help```.
+
+
+## Dataloaders and Benchmarks
+This repository provides general-use tools for the task of face alignment and headpose estimation:
+
+* **Dataloaders:** Training and inference dataloaders are available at ```./spiga/data```.
+Including the data augmentation tools used for training SPIGA and data-visualizer to analyze the dataset images and features.
+For more information check the [Data Readme](spiga/data/readme.md) .
+
+* **Benchmark:** A common benchmark framework to test any algorithm in the task of face alignment and headpose estimation
+is available at ```./spiga/eval/benchmark```. For more information check the following Evaluation Section and the [Benchmark Readme](spiga/eval/benchmark/readme.md).
+
+**Datasets:** To run the data visualizers or the evaluation benchmark please download the dataset images from the official websites
+([300W](https://ibug.doc.ic.ac.uk/resources/facial-point-annotations/),
+[AFLW](https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/aflw/),
+[WFLW](https://wywu.github.io/projects/LAB/WFLW.html), [COFW](http://www.vision.caltech.edu/xpburgos/ICCV13/)).
+By default they should be saved following the next folder structure:
+```
+./spiga/data/databases/ # Default path can be updated by modifying 'db_img_path' in ./spiga/data/loaders/dl_config.py
+|
+└───/300w
+│    └─── /images
+│          |   /private
+│          |   /test
+|          └   /train
+|
+└───/cofw
+│    └─── /images
+|
+└───/aflw
+│    └─── /data
+|          └   /flickr
+|
+└───/wflw
+     └─── /images
+```
+**Annotations:** For simplicity, we have stored the dataset annotations directly in ```./spiga/data/annotations```. We strongly recommend moving them out of the repository if you plan to use it as a git directory.
+
+**Results:** As with the annotations, we have stored the SPIGA results in ```./spiga/eval/results/```. Remove them if needed.
+
+## Evaluation
+The models evaluation is divided in two scripts:
+
+**Results generation**: The script extracts the data alignments and headpose estimation from the desired ```<dataset_name>``` trained network, generating a ```./spiga/eval/results/results_<dataset_name>_test.json``` file which follows the same data structure defined by the dataset annotations.
+
+```
+python ./spiga/eval/results_gen.py <dataset_name>
+```
+
+**Benchmark metrics**: The script generates the desired landmark or headpose estimation metrics. We have implemented a useful benchmark which allows you to test any model using a results file as input.
+
+```
+python ./spiga/eval/benchmark/evaluator.py /path/to/<results_file.json> --eval lnd pose -s
+```
+
+***Note:*** You will have to interactively select the NME_norm and other parameters in the terminal window.
+
+### Results Sum-up
+
+ WFLW Dataset
+
+|[](https://paperswithcode.com/sota/face-alignment-on-wflw?p=shape-preserving-facial-landmarks-with-graph)|NME_ioc|AUC_10|FR_10|NME_P90|NME_P95|NME_P99|
+|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
+|full|4.060|60.558|2.080|6.766|8.199|13.071|
+|pose|7.141|35.312|11.656|10.684|13.334|26.890|
+|expression|4.457|57.968|2.229|7.023|8.148|22.388|
+|illumination|4.004|61.311|1.576|6.528|7.919|11.090|
+|makeup|3.809|62.237|1.456|6.320|8.289|11.564|
+|occlusion|4.952|53.310|4.484|8.091|9.929|16.439|
+|blur|4.650|55.310|2.199|7.311|8.693|14.421|
+
+
+
+ MERLRAV Dataset
+
+|[](https://paperswithcode.com/sota/face-alignment-on-merl-rav?p=shape-preserving-facial-landmarks-with-graph)|NME_bbox|AUC_7|FR_7|NME_P90|NME_P95|NME_P99|
+|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
+|full|1.509|78.474|0.052|2.163|2.468|3.456|
+|frontal|1.616|76.964|0.091|2.246|2.572|3.621|
+|half_profile|1.683|75.966|0.000|2.274|2.547|3.397|
+|profile|1.191|82.990|0.000|1.735|2.042|2.878|
+
+
+
+ 300W Private Dataset
+
+|[](https://paperswithcode.com/sota/face-alignment-on-300w-split-2?p=shape-preserving-facial-landmarks-with-graph)|NME_bbox|AUC_7|FR_7|NME_P90|NME_P95|NME_P99|
+|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
+|full|2.031|71.011|0.167|2.788|3.078|3.838|
+|indoor|2.035|70.959|0.333|2.726|3.007|3.712|
+|outdoor|2.027|37.174|0.000|2.824|3.217|3.838|
+
+
+
+ COFW68 Dataset
+
+|[](https://paperswithcode.com/sota/face-alignment-on-cofw-68?p=shape-preserving-facial-landmarks-with-graph)|NME_bbox|AUC_7|FR_7|NME_P90|NME_P95|NME_P99|
+|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
+|full|2.517|64.050|0.000|3.439|4.066|5.558|
+
+
+
+ 300W Public Dataset
+
+|[](https://paperswithcode.com/sota/face-alignment-on-300w?p=shape-preserving-facial-landmarks-with-graph)|NME_ioc|AUC_8|FR_8|NME_P90|NME_P95|NME_P99|
+|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
+|full|2.994|62.726|0.726|4.667|5.436|7.320|
+|common|2.587|44.201|0.000|3.710|4.083|5.215|
+|challenge|4.662|42.449|3.704|6.626|7.390|10.095|
+
+
+
+
+## BibTeX Citation
+```
+@inproceedings{Prados-Torreblanca_2022_BMVC,
+ author = {Andrés Prados-Torreblanca and José M Buenaposada and Luis Baumela},
+ title = {Shape Preserving Facial Landmarks with Graph Attention Networks},
+ booktitle = {33rd British Machine Vision Conference 2022, {BMVC} 2022, London, UK, November 21-24, 2022},
+ publisher = {{BMVA} Press},
+ year = {2022},
+ url = {https://bmvc2022.mpi-inf.mpg.de/0155.pdf}
+}
+```
+
+
diff --git a/SPIGA/colab_tutorials/image_demo.ipynb b/SPIGA/colab_tutorials/image_demo.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..847f4fe6d0dcbdc79f09a691302c127254b7305b
--- /dev/null
+++ b/SPIGA/colab_tutorials/image_demo.ipynb
@@ -0,0 +1,197 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "authorship_tag": "ABX9TyO+yWmNPw3eBl9Z5zvQYH17"
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU",
+ "gpuClass": "standard"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# SPIGA: Shape Preserving Facial Landmarks with Graph Attention Networks.\n",
+ "\n",
+ "[](https://github.com/andresprados/SPIGA)\n",
+ "[](https://bmvc2022.mpi-inf.mpg.de/155/)\n",
+ "[](https://arxiv.org/abs/2210.07233)\n",
+ "[](https://badge.fury.io/py/spiga)\n",
+ "[](LICENSE)\n",
+ "[](https://colab.research.google.com/github/andresprados/SPIGA/blob/main/colab_tutorials/image_demo.ipynb)\n",
+ "\n",
+ "**SPIGA is a face alignment and headpose estimator** that takes advantage of the complementary benefits from CNN and GNN architectures producing plausible face shapes in presence of strong appearance changes. "
+ ],
+ "metadata": {
+ "id": "zYVrcsnLp7D0"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Image Inference Example:\n",
+ "SPIGA provides an inference framework that can be easily deployed \n",
+ "in third-party projects by adding a few lines of code.\n"
+ ],
+ "metadata": {
+ "id": "6VGcBElYwZQM"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Setup the repository and load image:"
+ ],
+ "metadata": {
+ "id": "Pxj4tBskykGV"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Clone and setup the repository\n",
+ "!git clone https://github.com/andresprados/SPIGA.git\n",
+ "%cd SPIGA/\n",
+ "!pip install -e ."
+ ],
+ "metadata": {
+ "id": "N6Mvu13ZBg92"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import cv2\n",
+ "import json\n",
+ "import numpy as np\n",
+ "from google.colab.patches import cv2_imshow\n",
+ "\n",
+ "# Load image and bbox\n",
+ "image = cv2.imread(\"/content/SPIGA/assets/colab/image_sportsfan.jpg\")\n",
+ "with open('/content/SPIGA/assets/colab/bbox_sportsfan.json') as jsonfile:\n",
+ " bbox = json.load(jsonfile)['bbox']"
+ ],
+ "metadata": {
+ "id": "Lf0RwuFW9cSC"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Inference and visualization:"
+ ],
+ "metadata": {
+ "id": "E4u69_ssyxkY"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from spiga.inference.config import ModelConfig\n",
+ "from spiga.inference.framework import SPIGAFramework\n",
+ "\n",
+ "# Process image\n",
+ "dataset = 'wflw'\n",
+ "processor = SPIGAFramework(ModelConfig(dataset))\n",
+ "features = processor.inference(image, [bbox])\n"
+ ],
+ "metadata": {
+ "id": "Mck9eHXKYUxd"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import copy\n",
+ "from spiga.demo.visualize.plotter import Plotter\n",
+ "\n",
+ "# Prepare variables\n",
+ "x0,y0,w,h = bbox\n",
+ "canvas = copy.deepcopy(image)\n",
+ "landmarks = np.array(features['landmarks'][0])\n",
+ "headpose = np.array(features['headpose'][0])\n",
+ "\n",
+ "# Plot features\n",
+ "plotter = Plotter()\n",
+ "canvas = plotter.landmarks.draw_landmarks(canvas, landmarks)\n",
+ "canvas = plotter.hpose.draw_headpose(canvas, [x0,y0,x0+w,y0+h], headpose[:3], headpose[3:], euler=True)\n",
+ "\n",
+ "# Show image results\n",
+ "(h, w) = canvas.shape[:2]\n",
+ "canvas = cv2.resize(canvas, (512, int(h*512/w)))\n",
+ "cv2_imshow(canvas)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 401
+ },
+ "id": "I6o6piO-Dju_",
+ "outputId": "123db829-e9f6-4257-e6e3-ff9f4be8e18c"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ ""
+ ],
+ "image/png": "\n"
+ },
+ "metadata": {}
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Citation\n",
+ "If you like our work don't forget to cite us!\n",
+ "\n",
+ "```\n",
+ "@inproceedings{Prados-Torreblanca_2022_BMVC,\n",
+ " author = {Andrés Prados-Torreblanca and José M Buenaposada and Luis Baumela},\n",
+ " title = {Shape Preserving Facial Landmarks with Graph Attention Networks},\n",
+ " booktitle = {33rd British Machine Vision Conference 2022, {BMVC} 2022, London, UK, November 21-24, 2022},\n",
+ " publisher = {{BMVA} Press},\n",
+ " year = {2022},\n",
+ " url = {https://bmvc2022.mpi-inf.mpg.de/0155.pdf}\n",
+ "}\n",
+ "```"
+ ],
+ "metadata": {
+ "id": "QyWfeEzECTQg"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Related Notebooks\n",
+ "\n",
+ "Tutorials | Notebook |\n",
+ ":---| :---: |\n",
+ "Image Inference Example | [](https://colab.research.google.com/github/andresprados/SPIGA/blob/main/colab_tutorials/image_demo.ipynb) |\n",
+ "Face Video Analyzer Demo | [](https://colab.research.google.com/github/andresprados/SPIGA/blob/main/colab_tutorials/video_demo.ipynb) |"
+ ],
+ "metadata": {
+ "id": "9cEL-x2MEHrW"
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/SPIGA/colab_tutorials/video_demo.ipynb b/SPIGA/colab_tutorials/video_demo.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..d5c34e7ecf80035250f65d1336aeb7f374850887
--- /dev/null
+++ b/SPIGA/colab_tutorials/video_demo.ipynb
@@ -0,0 +1,215 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "authorship_tag": "ABX9TyOF6rc4QkOxYUF0EKnbRAyL"
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU",
+ "gpuClass": "standard"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# SPIGA: Shape Preserving Facial Landmarks with Graph Attention Networks.\n",
+ "\n",
+ "[](https://github.com/andresprados/SPIGA)\n",
+ "[](https://bmvc2022.mpi-inf.mpg.de/155/)\n",
+ "[](https://arxiv.org/abs/2210.07233)\n",
+ "[](https://badge.fury.io/py/spiga)\n",
+ "[](LICENSE)\n",
+ "[](https://colab.research.google.com/github/andresprados/SPIGA/blob/main/colab_tutorials/video_demo.ipynb)\n",
+ "\n",
+ "**SPIGA is a face alignment and headpose estimator** that takes advantage of the complementary benefits from CNN and GNN architectures producing plausible face shapes in presence of strong appearance changes. \n"
+ ],
+ "metadata": {
+ "id": "zYVrcsnLp7D0"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Video Analyzer Demo:\n",
+ "SPIGA provides a general framework for tracking, detecting and extracting features of human faces in images or videos."
+ ],
+ "metadata": {
+ "id": "6VGcBElYwZQM"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Setup the repository:\n"
+ ],
+ "metadata": {
+ "id": "Pxj4tBskykGV"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Clone and setup the repository\n",
+ "!git clone https://github.com/andresprados/SPIGA.git\n",
+ "%cd SPIGA/\n",
+ "!pip install -e .[demo]"
+ ],
+ "metadata": {
+ "id": "N6Mvu13ZBg92"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Record Webcam Video:"
+ ],
+ "metadata": {
+ "id": "E4u69_ssyxkY"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import colab_tutorials.video_tools.record as vid_util\n",
+ "\n",
+ "webcam_video_path = '/content/test.mp4'\n",
+ "vid_util.record_video(webcam_video_path)"
+ ],
+ "metadata": {
+ "id": "Mck9eHXKYUxd"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Process Video with SPIGA Framework:\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "JefR-lla9xEs"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import os\n",
+ "from spiga.demo.app import video_app\n",
+ "\n",
+ "# MP4 input path: Webcam recorded video or uploaded one.\n",
+ "# video_path = '/content/'\n",
+ "video_path = webcam_video_path\n",
+ "output_path= '/content/output' # Processed video storage\n",
+ "\n",
+ "# Process video\n",
+ "video_app(video_path,\n",
+ " spiga_dataset='wflw', # Choices=['wflw', '300wpublic', '300wprivate', 'merlrav']\n",
+ " tracker='RetinaSort', # Choices=['RetinaSort', 'RetinaSort_Res50']\n",
+ " save=True,\n",
+ " output_path=output_path,\n",
+ " visualize=False,\n",
+ " plot=['fps', 'face_id', 'landmarks', 'headpose'])\n",
+ "\n",
+ "\n",
+ "# Convert Opencv video to Colab readable format\n",
+ "video_name = video_path.split('/')[-1]\n",
+ "video_output_path = os.path.join(output_path, video_name)\n",
+ "video_colab_path = os.path.join(output_path, video_name[:-4]+'_colab.mp4')\n",
+ "!ffmpeg -i '{video_output_path}' '{video_colab_path}'"
+ ],
+ "metadata": {
+ "id": "wffMRr2T8Yvk"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Results Visualization:"
+ ],
+ "metadata": {
+ "id": "v0k3qZ3YDlEw"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import colab_tutorials.video_tools.record as vid_util\n",
+ "\n",
+ "# Display video\n",
+ "vid_util.show_video(video_colab_path)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 471
+ },
+ "id": "I6o6piO-Dju_",
+ "outputId": "6f520b13-c6e7-4abe-8660-51d32f66524a"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ ""
+ ],
+ "text/html": [
+ " "
+ ]
+ },
+ "metadata": {},
+ "execution_count": 10
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Citation\n",
+ "If you like our work don't forget to cite us!\n",
+ "\n",
+ "```\n",
+ "@inproceedings{Prados-Torreblanca_2022_BMVC,\n",
+ " author = {Andrés Prados-Torreblanca and José M Buenaposada and Luis Baumela},\n",
+ " title = {Shape Preserving Facial Landmarks with Graph Attention Networks},\n",
+ " booktitle = {33rd British Machine Vision Conference 2022, {BMVC} 2022, London, UK, November 21-24, 2022},\n",
+ " publisher = {{BMVA} Press},\n",
+ " year = {2022},\n",
+ " url = {https://bmvc2022.mpi-inf.mpg.de/0155.pdf}\n",
+ "}\n",
+ "```"
+ ],
+ "metadata": {
+ "id": "ZvuQLJPDGejs"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Related Notebooks\n",
+ "\n",
+ "Tutorials | Notebook |\n",
+ ":---| :---: |\n",
+ "Image Inference Example | [](https://colab.research.google.com/github/andresprados/SPIGA/blob/main/colab_tutorials/image_demo.ipynb) |\n",
+ "Face Video Analyzer Demo | [](https://colab.research.google.com/github/andresprados/SPIGA/blob/main/colab_tutorials/video_demo.ipynb) |"
+ ],
+ "metadata": {
+ "id": "aQ_G5unzGmto"
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/SPIGA/colab_tutorials/video_tools/record.py b/SPIGA/colab_tutorials/video_tools/record.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5a0168c201fb2867d6d08fc850fd9ccef7f46b7
--- /dev/null
+++ b/SPIGA/colab_tutorials/video_tools/record.py
@@ -0,0 +1,78 @@
+from IPython.display import display, Javascript, HTML
+from google.colab.output import eval_js
+from base64 import b64decode, b64encode
+
+
+def record_video(filename):
+ js = Javascript("""
+ async function recordVideo() {
+ const options = { mimeType: "video/webm; codecs=vp9" };
+ const div = document.createElement('div');
+ const capture = document.createElement('button');
+ const stopCapture = document.createElement("button");
+
+ capture.textContent = "Start Recording";
+ capture.style.background = "orange";
+ capture.style.color = "white";
+
+ stopCapture.textContent = "Stop Recording";
+ stopCapture.style.background = "red";
+ stopCapture.style.color = "white";
+ div.appendChild(capture);
+
+ const video = document.createElement('video');
+ const recordingVid = document.createElement("video");
+ video.style.display = 'block';
+
+ const stream = await navigator.mediaDevices.getUserMedia({audio:true, video: true});
+
+ let recorder = new MediaRecorder(stream, options);
+ document.body.appendChild(div);
+ div.appendChild(video);
+
+ video.srcObject = stream;
+ video.muted = true;
+
+ await video.play();
+
+ google.colab.output.setIframeHeight(document.documentElement.scrollHeight, true);
+
+ await new Promise((resolve) => {
+ capture.onclick = resolve;
+ });
+ recorder.start();
+ capture.replaceWith(stopCapture);
+
+ await new Promise((resolve) => stopCapture.onclick = resolve);
+ recorder.stop();
+ let recData = await new Promise((resolve) => recorder.ondataavailable = resolve);
+ let arrBuff = await recData.data.arrayBuffer();
+
+ // stop the stream and remove the video element
+ stream.getVideoTracks()[0].stop();
+ div.remove();
+
+ let binaryString = "";
+ let bytes = new Uint8Array(arrBuff);
+ bytes.forEach((byte) => {
+ binaryString += String.fromCharCode(byte);
+ })
+ return btoa(binaryString);
+ }
+ """)
+ try:
+ display(js)
+ data = eval_js('recordVideo({})')
+ binary = b64decode(data)
+ with open(filename, "wb") as video_file:
+ video_file.write(binary)
+ print(f"Finished recording video at:{filename}")
+ except Exception as err:
+ print(str(err))
+
+
+def show_video(video_path, video_width=600):
+ video_file = open(video_path, "r+b").read()
+
+ video_url = f"data:video/mp4;base64,{b64encode(video_file).decode()}"
+ return HTML(f""" """)
\ No newline at end of file
diff --git a/SPIGA/colab_tutorials/video_tools/utils.py b/SPIGA/colab_tutorials/video_tools/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..924bccf06cd0db5b14024b45828c4448a6fce823
--- /dev/null
+++ b/SPIGA/colab_tutorials/video_tools/utils.py
@@ -0,0 +1,52 @@
+import numpy as np
+import PIL
+import io
+import cv2
+from base64 import b64decode, b64encode
+
+
+def js_to_image(js_reply):
+ """
+ Convert the JavaScript object into an OpenCV image.
+
+ @param js_reply: JavaScript object containing image from webcam
+ @return img: OpenCV BGR image
+ """
+ # decode base64 image
+ image_bytes = b64decode(js_reply.split(',')[1])
+ # convert bytes to numpy array
+ jpg_as_np = np.frombuffer(image_bytes, dtype=np.uint8)
+ # decode numpy array into OpenCV BGR image
+ img = cv2.imdecode(jpg_as_np, flags=1)
+
+ return img
+
+
+def bbox_to_bytes(bbox_array):
+ """
+ Convert OpenCV Rectangle bounding box image into base64 byte string to be overlayed on video stream.
+
+ @param bbox_array: Numpy array (pixels) containing rectangle to overlay on video stream.
+ @return bbox_bytes: Base64 image byte string
+ """
+ # convert array into PIL image
+ bbox_PIL = PIL.Image.fromarray(bbox_array, 'RGBA')
+ iobuf = io.BytesIO()
+ # format bbox into png for return
+ bbox_PIL.save(iobuf, format='png')
+ # format return string
+ bbox_bytes = 'data:image/png;base64,{}'.format((str(b64encode(iobuf.getvalue()), 'utf-8')))
+ return bbox_bytes
+
+
+def image_to_bytes(image):
+ """
+ Convert OpenCV image into base64 byte string to be overlayed on video stream.
+
+ @param image: Input image.
+ @return img_bytes: Base64 image byte string.
+ """
+ ret, buffer = cv2.imencode('.jpg', image)
+ jpg_as_text = b64encode(buffer).decode('utf-8')
+ img_bytes = f'data:image/jpeg;base64,{jpg_as_text}'
+ return img_bytes
\ No newline at end of file
diff --git a/SPIGA/pyproject.toml b/SPIGA/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..2a84a692b49a5c2e16d71982286cf10c3aa68d37
--- /dev/null
+++ b/SPIGA/pyproject.toml
@@ -0,0 +1,7 @@
+[build-system]
+requires = [
+ "setuptools>=64.0.0",
+ "wheel",
+]
+build-backend = "setuptools.build_meta"
+
diff --git a/SPIGA/requirements.txt b/SPIGA/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b3c1378c0b679aee593be8c766a0517c11c0ba9c
--- /dev/null
+++ b/SPIGA/requirements.txt
@@ -0,0 +1,9 @@
+matplotlib>=3.2.1
+numpy>=1.18.2
+opencv-python>=4.2.0.32
+Pillow>=7.0.0
+torch>=1.4.0
+torchvision>=0.5.0
+torchaudio
+scipy
+scikit-learn
diff --git a/SPIGA/setup.cfg b/SPIGA/setup.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..d5cd0f1c07c21389776337c61f5b5772c4cad950
--- /dev/null
+++ b/SPIGA/setup.cfg
@@ -0,0 +1,74 @@
+# Configuration of the Python project
+
+# Configure setup.py
+[metadata]
+name = spiga
+version = 0.0.6
+author = Andres Prados Torreblanca
+author_email = andresprator@gmail.com
+description = SPIGA: Shape Preserving Facial Landmarks with Graph Attention Networks
+long_description = file: README.md
+long_description_content_type= text/markdown
+license = BSD-3-Clause
+license_files = LICENSE
+url = https://github.com/andresprados/SPIGA
+project_urls =
+ Homepage = https://bmvc2022.mpi-inf.mpg.de/155/
+ SPIGA Paper = https://bmvc2022.mpi-inf.mpg.de/0155.pdf
+ Bug Tracker = https://github.com/andresprados/SPIGA/issues
+keywords =
+ Computer Vision
+ Face Alignment
+ Head Pose Estimation
+ Pytorch
+ CNN
+ GNN
+ BMVC2022
+ WFLW
+ 300W
+ Merlrav
+ COFW
+
+classifiers =
+ Development Status :: 4 - Beta
+ Intended Audience :: Developers
+ Intended Audience :: Science/Research
+ Intended Audience :: Education
+ Operating System :: OS Independent
+ Environment :: GPU
+ Environment :: Console
+ Programming Language :: Python :: 3
+ Programming Language :: Python :: 3.7
+ Programming Language :: Python :: 3.8
+ Programming Language :: Python :: 3.9
+ Programming Language :: Python :: 3.10
+ Topic :: Scientific/Engineering :: Artificial Intelligence
+ Topic :: Scientific/Engineering :: Image Processing
+ Topic :: Software Development :: Libraries
+ Topic :: Software Development :: Libraries :: Python Modules
+
+
+[options]
+packages = find:
+include_package_data = True
+python_requires = >= 3.6
+install_requires =
+ matplotlib>=3.2.1
+ numpy>=1.18.2
+ opencv-python>=4.2.0.32
+ Pillow>=7.0.0
+ torch>=1.4.0
+ torchvision>=0.5.0
+ torchaudio
+ scipy
+ scikit-learn
+
+[options.extras_require]
+demo =
+ retinaface-py>=0.0.2
+ sort-tracker-py>= 1.0.2
+
+[options.packages.find]
+exclude =
+ spiga.eval.results*
+ colab_tutorials*
diff --git a/SPIGA/spiga/__init__.py b/SPIGA/spiga/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/data/__init__.py b/SPIGA/spiga/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/data/annotations/300wprivate/db_info.json b/SPIGA/spiga/data/annotations/300wprivate/db_info.json
new file mode 100644
index 0000000000000000000000000000000000000000..026b8df7468581a8451fda952adbf7c95cec15d9
--- /dev/null
+++ b/SPIGA/spiga/data/annotations/300wprivate/db_info.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f43605118e19366321fa01e24711cad6a5eefdc752bc4015ecdaf4f63018f477
+size 8581
diff --git a/SPIGA/spiga/data/annotations/300wprivate/test.json b/SPIGA/spiga/data/annotations/300wprivate/test.json
new file mode 100644
index 0000000000000000000000000000000000000000..69cf8472bc518128058fe42f683e8825fd38dbd3
--- /dev/null
+++ b/SPIGA/spiga/data/annotations/300wprivate/test.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3aaa393fae27480c6779405082dee93241de5cfee8990c88d2f336d877c14753
+size 1383019
diff --git a/SPIGA/spiga/data/annotations/300wprivate/train.json b/SPIGA/spiga/data/annotations/300wprivate/train.json
new file mode 100644
index 0000000000000000000000000000000000000000..1d9835845de2ec726aa41a23d5986597d9e26141
--- /dev/null
+++ b/SPIGA/spiga/data/annotations/300wprivate/train.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cdadec43ccfe4b228ce46ff79c6b09c3b9a1f167294e021b6c9fcb592e1f909f
+size 9940676
diff --git a/SPIGA/spiga/data/annotations/300wpublic/db_info.json b/SPIGA/spiga/data/annotations/300wpublic/db_info.json
new file mode 100644
index 0000000000000000000000000000000000000000..d78d653c7ac7bb2b90d2a7296248a6e49277d5b1
--- /dev/null
+++ b/SPIGA/spiga/data/annotations/300wpublic/db_info.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8612f49ab40b9d9e863d4ac3b2f0b84eb13ea38a885137c966eeccd6150dbb79
+size 11680
diff --git a/SPIGA/spiga/data/annotations/300wpublic/test.json b/SPIGA/spiga/data/annotations/300wpublic/test.json
new file mode 100644
index 0000000000000000000000000000000000000000..609f31098509c704532bcaa79602bc5d076e7da1
--- /dev/null
+++ b/SPIGA/spiga/data/annotations/300wpublic/test.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c4147e9924e7c040e647c7f39a50f56324a51383d3c52d4d056694211f36d0b0
+size 1865129
diff --git a/SPIGA/spiga/data/annotations/300wpublic/train.json b/SPIGA/spiga/data/annotations/300wpublic/train.json
new file mode 100644
index 0000000000000000000000000000000000000000..8bd4124c34132ea7d812933c30a87cf16d5298f3
--- /dev/null
+++ b/SPIGA/spiga/data/annotations/300wpublic/train.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d432d230fc093e87cc402b33c23acf943b1552ef17453405b75f83db01b275fa
+size 8157361
diff --git a/SPIGA/spiga/data/annotations/cofw68/db_info.json b/SPIGA/spiga/data/annotations/cofw68/db_info.json
new file mode 100644
index 0000000000000000000000000000000000000000..5fdbe527879e9fd6e41d1295dc692ab62abb6b79
--- /dev/null
+++ b/SPIGA/spiga/data/annotations/cofw68/db_info.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2087f472b64f914ba8db2c50f2c79b22af96b376abdb01bfce87d0571ba3c235
+size 5611
diff --git a/SPIGA/spiga/data/annotations/cofw68/test.json b/SPIGA/spiga/data/annotations/cofw68/test.json
new file mode 100644
index 0000000000000000000000000000000000000000..d91f26676c4f5819aa8585dad85ad3d386618a64
--- /dev/null
+++ b/SPIGA/spiga/data/annotations/cofw68/test.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:87325157ca578b28fa8117b9e02fd09768c1db588a0878c546ced9a266e51061
+size 1897814
diff --git a/SPIGA/spiga/data/annotations/merlrav/db_info.json b/SPIGA/spiga/data/annotations/merlrav/db_info.json
new file mode 100644
index 0000000000000000000000000000000000000000..74e87085d428290d8e8a6589ec3faa74625a0bec
--- /dev/null
+++ b/SPIGA/spiga/data/annotations/merlrav/db_info.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ad7bf54f27fe816c0500f9318aa1df8a4ae90758faebbf1ba64d3ea037185382
+size 37851
diff --git a/SPIGA/spiga/data/annotations/merlrav/test.json b/SPIGA/spiga/data/annotations/merlrav/test.json
new file mode 100644
index 0000000000000000000000000000000000000000..26d627957734ae99ad43c3fba290d23ac933066f
--- /dev/null
+++ b/SPIGA/spiga/data/annotations/merlrav/test.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cd1e358a29f95f869fa5fd09cb37c335a2220b5b478565459221921540aeb236
+size 11312227
diff --git a/SPIGA/spiga/data/annotations/merlrav/train.json b/SPIGA/spiga/data/annotations/merlrav/train.json
new file mode 100644
index 0000000000000000000000000000000000000000..df943ed607c514a8389f2225a030e4be23d1fe18
--- /dev/null
+++ b/SPIGA/spiga/data/annotations/merlrav/train.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7e8bfb36c27776d8386dfbee682af65332b1dad03faaedb9533b6b39acb230e9
+size 43323888
diff --git a/SPIGA/spiga/data/annotations/wflw/db_info.json b/SPIGA/spiga/data/annotations/wflw/db_info.json
new file mode 100644
index 0000000000000000000000000000000000000000..1f60eef6cf3b8c4d63c3ca73eb95992b2a56c348
--- /dev/null
+++ b/SPIGA/spiga/data/annotations/wflw/db_info.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a0d623edbf37625b180d596b71f753263f2b7c81e44af0b10023097f04e56958
+size 26829
diff --git a/SPIGA/spiga/data/annotations/wflw/test.json b/SPIGA/spiga/data/annotations/wflw/test.json
new file mode 100644
index 0000000000000000000000000000000000000000..7d4be9a6ec48158c03234cf7f54e521363a0c712
--- /dev/null
+++ b/SPIGA/spiga/data/annotations/wflw/test.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:306005740c464241d50a0aff00821215c2edbd913c739e4666f5e0dbca6f031c
+size 9755013
diff --git a/SPIGA/spiga/data/annotations/wflw/train.json b/SPIGA/spiga/data/annotations/wflw/train.json
new file mode 100644
index 0000000000000000000000000000000000000000..0e156062bd61dbd879355517ecb9fd05ffd4179c
--- /dev/null
+++ b/SPIGA/spiga/data/annotations/wflw/train.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:95dcb0a2479e759da5e7051ad5c62eae90af3987614cfb0731802139f24e210a
+size 28374603
diff --git a/SPIGA/spiga/data/loaders/__init__.py b/SPIGA/spiga/data/loaders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/data/loaders/alignments.py b/SPIGA/spiga/data/loaders/alignments.py
new file mode 100644
index 0000000000000000000000000000000000000000..9467ec9308475e876346efe199e1973eb154cd88
--- /dev/null
+++ b/SPIGA/spiga/data/loaders/alignments.py
@@ -0,0 +1,158 @@
+import os
+import json
+import cv2
+import numpy as np
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+from spiga.data.loaders.transforms import get_transformers
+
+
class AlignmentsDataset(Dataset):
    """Dataset of face images annotated with landmarks, visibilities and bounding boxes.

    Annotations are loaded once from a JSON file in the constructor; images are
    read with OpenCV each time a sample is requested.
    """

    def __init__(self,
                 database,
                 json_file,
                 images_dir,
                 image_size=(128, 128),
                 transform=None,
                 indices=None,
                 debug=False):
        """
        :param database: database description object; its ``num_landmarks`` and
            ``ldm_ids`` attributes are read when building a sample.

        :param json_file: path to the JSON annotation file. Every entry is read
            through the keys 'imgpath', 'ids', 'landmarks', 'visible', 'bbox'
            and 'headpose'.

        :param images_dir: directory the 'imgpath' entries are joined to.

        :param image_size: tuple like e.g. (128, 128). Stored on the instance;
            not used by this class itself.

        :param transform: callable applied to the sample dict before it is returned.

        :param indices: if it is a list of indices, the dataset exposes only the
            subset of items specified by the list. If None, the whole set is used.

        :param debug: if True, each sample also carries '*_ori' copies of the
            untransformed annotations.
        """

        self.database = database
        self.images_dir = images_dir
        self.transform = transform
        self.image_size = image_size
        self.indices = indices
        self._imgs_dict = None
        self.debug = debug

        with open(json_file) as jsonfile:
            self.data = json.load(jsonfile)

    def __len__(self):
        """Return the number of items exposed (subset size when ``indices`` is set)."""
        if self.indices is None:
            return len(self.data)
        return len(self.indices)

    @staticmethod
    def _bbox_from_landmarks(landmarks, vis):
        """Return the (x, y, w, h) box that encloses the visible landmarks."""
        candidates = landmarks[vis == 1.0]
        if len(candidates) == 0:
            # No landmark flagged visible: enclose all of them rather than fail on min().
            candidates = landmarks
        xy = candidates[:, :2]
        top_left = xy.min(axis=0)
        extent = xy.max(axis=0) - top_left
        return np.concatenate((top_left, extent)).astype(float)

    def __getitem__(self, sample_idx):
        """Build the sample dict for position ``sample_idx``.

        :param sample_idx: position in the dataset (or in ``indices`` when a subset is used).
        :return: dict with the RGB PIL image and its annotations, after ``transform``
            (if any) has been applied.
        :raises OSError: if the image file cannot be read by OpenCV.
        """

        # To allow work with a subset
        if self.indices is not None:
            sample_idx = self.indices[sample_idx]
        anns = self.data[sample_idx]

        # Load sample image
        img_name = os.path.join(self.images_dir, anns['imgpath'])
        if not self._imgs_dict:
            image_cv = cv2.imread(img_name)
            if image_cv is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise OSError('Could not read image: {}'.format(img_name))
        else:
            image_cv = self._imgs_dict[sample_idx]

        # Some images are B&W. We make sure that any image has three channels.
        if len(image_cv.shape) == 2:
            image_cv = np.repeat(image_cv[:, :, np.newaxis], 3, axis=-1)

        # Some images have alpha channel
        image_cv = image_cv[:, :, :3]

        image_cv = cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image_cv)

        # Load sample anns
        ids = np.array(anns['ids'])
        landmarks = np.array(anns['landmarks'])
        vis = np.array(anns['visible'])
        headpose = anns.get('headpose')

        # Generate the bbox from the landmarks when the annotation has none.
        # The check must happen on the raw value: np.array(None) is never None.
        raw_bbox = anns.get('bbox')
        if raw_bbox is None:
            bbox = self._bbox_from_landmarks(landmarks, vis)
        else:
            bbox = np.array(raw_bbox)

        # Clean and mask landmarks: reorder them to the database id order and
        # flag the ids that this annotation does not provide.
        num_landmarks = self.database.num_landmarks
        mask_ldm = np.ones(num_landmarks)
        if not self.database.ldm_ids == ids.tolist():
            new_ldm = np.zeros((num_landmarks, 2))
            new_vis = np.zeros(num_landmarks)
            xyv = np.hstack((landmarks, vis[np.newaxis, :].T))
            ids_dict = dict(zip(ids.astype(int).astype(str), xyv))

            for pos, identifier in enumerate(self.database.ldm_ids):
                key = str(identifier)
                if key in ids_dict:
                    x, y, v = ids_dict[key]
                    new_ldm[pos] = [x, y]
                    new_vis[pos] = v
                else:
                    mask_ldm[pos] = 0
            landmarks = new_ldm
            vis = new_vis

        sample = {'image': image,
                  'sample_idx': sample_idx,
                  'imgpath': img_name,
                  'ids_ldm': np.array(self.database.ldm_ids),
                  'bbox': bbox,
                  # Independent copy so in-place edits of 'bbox' cannot alter the raw box.
                  'bbox_raw': bbox.copy(),
                  'landmarks': landmarks,
                  'visible': vis.astype(np.float64),
                  'mask_ldm': mask_ldm,
                  'imgpath_local': anns['imgpath'],
                  }

        if self.debug:
            # Copies, so the '*_ori' entries keep the untransformed values.
            sample['landmarks_ori'] = landmarks.copy()
            sample['visible_ori'] = vis.astype(np.float64)
            sample['mask_ldm_ori'] = mask_ldm.copy()
            if headpose is not None:
                sample['headpose_ori'] = np.array(headpose)

        if self.transform:
            sample = self.transform(sample)

        return sample
+
+
def get_dataset(data_config, pretreat=None, debug=False):
    """Create an AlignmentsDataset from a data configuration object.

    :param data_config: configuration whose ``database``, ``anns_file``,
        ``image_dir``, ``image_size`` and ``ids`` attributes are forwarded to
        the dataset; it is also handed to ``get_transformers``.
    :param pretreat: optional extra transform appended after the augmentors.
    :param debug: forwarded to the dataset's ``debug`` flag.
    :return: the configured AlignmentsDataset.
    """
    pipeline = get_transformers(data_config)
    if pretreat is not None:
        pipeline.append(pretreat)

    return AlignmentsDataset(
        data_config.database,
        data_config.anns_file,
        data_config.image_dir,
        image_size=data_config.image_size,
        transform=transforms.Compose(pipeline),
        indices=data_config.ids,
        debug=debug,
    )
diff --git a/SPIGA/spiga/data/loaders/augmentors/__init__.py b/SPIGA/spiga/data/loaders/augmentors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/data/loaders/augmentors/boundary.py b/SPIGA/spiga/data/loaders/augmentors/boundary.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c10ec7f13c8ac12c68114ddd7b1e4b25a545689
--- /dev/null
+++ b/SPIGA/spiga/data/loaders/augmentors/boundary.py
@@ -0,0 +1,122 @@
+import numpy as np
+from scipy import interpolate
+import cv2
+
+
class AddBoundary(object):
    """Add smoothed facial boundary maps to a sample under the key 'boundary'.

    Each boundary is a named, ordered group of landmarks. A spline is fitted
    through the group, rasterised on a (height, width) map and blurred with a
    Gaussian profile of standard deviation ``sigma``.
    """

    # (name, landmark indices) in output-channel order, 68-landmark layout.
    _BOUNDARIES_68 = (
        ('cheek', list(range(0, 17))),
        ('left_eyebrow', list(range(17, 22))),
        ('right_eyebrow', list(range(22, 27))),
        ('nose', list(range(27, 31))),
        ('nose_bot', list(range(31, 36))),
        ('upper_left_eyelid', list(range(36, 40))),
        ('lower_left_eyelid', [36, 41, 40, 39]),
        ('upper_right_eyelid', list(range(42, 46))),
        ('lower_right_eyelid', [42, 47, 46, 45]),
        ('upper_outer_lip', list(range(48, 55))),
        ('lower_outer_lip', [48, 59, 58, 57, 56, 55, 54]),
        ('upper_inner_lip', [60, 61, 62, 63, 64]),
        ('lower_inner_lip', [60, 67, 66, 65, 64]),
    )

    # (name, landmark indices) in output-channel order, 98-landmark layout.
    _BOUNDARIES_98 = (
        ('cheek', list(range(0, 33))),
        ('upper_left_eyebrow', list(range(33, 38))),
        ('lower_left_eyebrow', [33, 41, 40, 39, 38]),
        ('upper_right_eyebrow', list(range(42, 47))),
        ('lower_right_eyebrow', list(range(46, 51))),
        ('nose', list(range(51, 55))),
        ('nose_bot', list(range(55, 60))),
        ('upper_left_eyelid', list(range(60, 65))),
        ('lower_left_eyelid', [60, 67, 66, 65, 64]),
        ('upper_right_eyelid', list(range(68, 73))),
        ('lower_right_eyelid', [68, 75, 74, 73, 72]),
        ('upper_outer_lip', list(range(76, 83))),
        ('lower_outer_lip', [76, 87, 86, 85, 84, 83, 82]),
        ('upper_inner_lip', [88, 89, 90, 91, 92]),
        ('lower_inner_lip', [88, 95, 94, 93, 92]),
    )

    def __init__(self, num_landmarks=68, map_size=64, sigma=1, min_dpi=64):
        """
        :param num_landmarks: landmark layout, 68 or 98. Any other value yields no boundaries.
        :param map_size: output map size, either an int (square map) or a
            (width, height) tuple/list.
        :param sigma: standard deviation of the Gaussian smoothing, in pixels.
        :param min_dpi: lower bound for the number of samples taken along each spline.
        """
        self.num_landmarks = num_landmarks
        self.sigma = sigma

        if isinstance(map_size, (tuple, list)):
            self.width = map_size[0]
            self.height = map_size[1]
        else:
            self.width = map_size
            self.height = map_size

        # Computed from width/height so an integer map_size works too
        # (max() on a bare int raises TypeError).
        self.dpi = max(self.width, self.height, min_dpi)

        self.fig_size = [self.height/self.dpi, self.width/self.dpi]

    @staticmethod
    def _drop_consecutive_duplicates(points):
        """Return ``points`` without rows that repeat the row just before them."""
        keep = np.ones(len(points), dtype=bool)
        keep[1:] = np.any(points[1:] != points[:-1], axis=1)
        return points[keep]

    def __call__(self, sample):
        """Compute the boundary maps for ``sample``.

        Reads 'landmarks_float' and 'mask_ldm_float' and stores an array of
        shape (num_boundaries, height, width) in ``sample['boundary']``.
        Channel ``i`` always corresponds to the ``i``-th boundary name; a
        boundary with no usable landmark leaves its channel at zero.

        :param sample: sample dict, updated in place.
        :return: the same sample dict.
        """
        landmarks = sample['landmarks_float']
        mask_lnd = sample['mask_ldm_float']
        boundaries = self.get_dataset_boundaries(landmarks, mask_lnd)

        sigma = self.sigma
        # Spline parameter values at which every boundary is sampled.
        xnew = np.arange(0, 1, 1/self.dpi)
        boundary_maps = np.zeros((len(boundaries), self.height, self.width))

        for i_map, points in enumerate(boundaries.values()):
            if len(points) == 0:
                continue

            points = self._drop_consecutive_duplicates(points)
            if points.shape[0] == 1:
                # A spline needs two distinct points: add a slightly shifted copy.
                points = np.concatenate((points, points+0.001), axis=0)
            # Spline degree is limited both to cubic and by the number of points.
            k = min(4, points.shape[0])
            tck, _ = interpolate.splprep([points[:, 0], points[:, 1]], k=k-1, s=0)

            out = interpolate.splev(xnew, tck, der=0)
            out = np.round(out).astype(int).transpose()
            # Keep only samples that fall on the map (column 0 is x, column 1 is y).
            # Negative values are dropped too, otherwise they would wrap around.
            on_map = ((out[:, 0] >= 0) & (out[:, 0] < self.width) &
                      (out[:, 1] >= 0) & (out[:, 1] < self.height))
            out = out[on_map]

            boundary_map = np.zeros((self.height, self.width))
            boundary_map[out[:, 1], out[:, 0]] = 255

            # Smooth: Gaussian profile of the distance to the curve, cut at 3 sigma.
            temp = 255 - boundary_map.astype(np.uint8)
            temp = cv2.distanceTransform(temp, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
            temp = temp.astype(np.float32)
            temp = np.where(temp < 3*sigma, np.exp(-(temp*temp)/(2*sigma*sigma)), 0)
            boundary_maps[i_map] = temp

        sample['boundary'] = boundary_maps
        return sample

    def get_dataset_boundaries(self, landmarks, mask_lnd):
        """Group landmarks into named boundaries.

        :param landmarks: array of landmark coordinates, one row per landmark.
        :param mask_lnd: per-landmark mask; for the 68-landmark layout only
            landmarks with mask > 0 are kept. The 98-landmark layout does not
            use the mask.
        :return: dict mapping boundary name to the array of its points, in
            channel order. Empty for an unsupported ``num_landmarks``.
        """
        boundaries = {}
        landmarks = np.asarray(landmarks)

        if self.num_landmarks == 68:
            mask_lnd = np.asarray(mask_lnd)
            for name, idx in self._BOUNDARIES_68:
                boundaries[name] = landmarks[idx][mask_lnd[idx] > 0]

        elif self.num_landmarks == 98:
            for name, idx in self._BOUNDARIES_98:
                boundaries[name] = landmarks[idx]

        return boundaries
diff --git a/SPIGA/spiga/data/loaders/augmentors/heatmaps.py b/SPIGA/spiga/data/loaders/augmentors/heatmaps.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cc3ede432e5ca3ad2d0fd8a5972426bf59e9c16
--- /dev/null
+++ b/SPIGA/spiga/data/loaders/augmentors/heatmaps.py
@@ -0,0 +1,39 @@
+import numpy as np
+
+
class Heatmaps:
    """Render one 2D Gaussian heatmap per landmark into ``sample['heatmap2D']``."""

    def __init__(self, num_maps, map_size, sigma, stride=1, norm=True):
        """
        :param num_maps: number of heatmaps; the last ``num_maps`` landmarks are used.
        :param map_size: int (square map) or (width, height) tuple/list.
        :param sigma: standard deviation of the Gaussians.
        :param stride: spacing between consecutive map cells, in landmark units.
        :param norm: if True, divide the maps by ``2 * pi * sigma**2``.
        """
        self.num_maps = num_maps
        self.sigma = sigma
        self.double_sigma_pw2 = 2*sigma*sigma
        self.doublepi_sigma_pw2 = self.double_sigma_pw2 * np.pi
        self.stride = stride
        self.norm = norm

        is_pair = isinstance(map_size, (tuple, list))
        self.width = map_size[0] if is_pair else map_size
        self.height = map_size[1] if is_pair else map_size

        # Coordinates of the cell centres, replicated once per map.
        centres_x = np.arange(self.width) * stride + stride / 2 - 0.5
        centres_y = np.arange(self.height) * stride + stride / 2 - 0.5
        self.grid_x = np.tile(centres_x, (self.num_maps, 1))
        self.grid_y = np.tile(centres_y, (self.num_maps, 1))

    def __call__(self, sample):
        """Add 'heatmap2D' (num_maps, height, width) to ``sample`` and return it."""
        points = sample['landmarks'][-self.num_maps:]
        col_x = points[:, 0].reshape(-1, 1)
        col_y = points[:, 1].reshape(-1, 1)

        # The 2D Gaussian is separable: build the 1D profiles, then take their outer product.
        profile_x = np.exp(-np.square(self.grid_x - col_x) / self.double_sigma_pw2)
        profile_y = np.exp(-np.square(self.grid_y - col_y) / self.double_sigma_pw2)
        maps = np.matmul(profile_y.reshape(self.num_maps, self.height, 1),
                         profile_x.reshape(self.num_maps, 1, self.width))

        if self.norm:
            maps = maps/self.doublepi_sigma_pw2

        sample['heatmap2D'] = maps
        return sample
diff --git a/SPIGA/spiga/data/loaders/augmentors/landmarks.py b/SPIGA/spiga/data/loaders/augmentors/landmarks.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1d17dcf9b86bf183bfe974b305e5fc0f6ea6aab
--- /dev/null
+++ b/SPIGA/spiga/data/loaders/augmentors/landmarks.py
@@ -0,0 +1,307 @@
+import random
+import cv2
+import numpy as np
+from PIL import Image
+from torchvision import transforms
+
+# My libs
+import spiga.data.loaders.augmentors.utils as dlu
+
+
class HorizontalFlipAug:
    """Randomly mirror the image horizontally together with its annotations."""

    def __init__(self, ldm_flip_order, prob=0.5):
        """
        :param ldm_flip_order: index order used to reorder landmarks, mask and
            visibilities after the flip.
        :param prob: probability of applying the flip.
        """
        self.prob = prob
        self.ldm_flip_order = ldm_flip_order

    def __call__(self, sample):
        """Flip ``sample`` with probability ``prob`` and return it."""
        picture = sample['image']
        points = sample['landmarks']
        ldm_mask = sample['mask_ldm']
        visibility = sample['visible']
        box = sample['bbox']

        if not random.random() < self.prob:
            return sample

        order = self.ldm_flip_order
        img_width = picture.size[0]

        # Mirror x around the image width (x -> width - x); y is untouched.
        mirrored = (points[order] - (img_width, 0)) * (-1, 1)

        # The left edge of the mirrored box is the old right edge, mirrored.
        x, y, w, h = box
        mirrored_box = np.array((img_width - x - w, y, w, h))

        sample['image'] = transforms.functional.hflip(picture)
        sample['landmarks'] = mirrored
        sample['mask_ldm'] = ldm_mask[order]
        sample['visible'] = visibility[order]
        sample['bbox'] = mirrored_box

        return sample
+
+
class GeometryBaseAug:
    """Base class with the helpers shared by the affine (geometric) augmentors."""

    def __call__(self, sample):
        """Subclasses must implement the augmentation."""
        raise NotImplementedError('Inheritance __call__ not defined')

    def map_affine_transformation(self, sample, affine_transf, new_size=None):
        """Apply ``affine_transf`` to the image, the bbox and, if present, the landmarks.

        :param sample: sample dict, updated in place.
        :param affine_transf: 2x3 affine matrix.
        :param new_size: output image size; defaults to the current image size.
        :return: the same sample dict.
        """
        sample['image'] = self._image_affine_trans(sample['image'], affine_transf, new_size)
        sample['bbox'] = self._bbox_affine_trans(sample['bbox'], affine_transf)
        if 'landmarks' in sample.keys():
            sample['landmarks'] = self._landmarks_affine_trans(sample['landmarks'], affine_transf)
        return sample

    def clean_outbbox_landmarks(self, shape, landmarks, mask):
        """Mask out the landmarks that lie outside the box ``shape`` = (x, y, w, h).

        :return: tuple (new_mask, new_landmarks); landmarks outside the box (or
            already masked) are zeroed, the rest are truncated to whole numbers.
        """
        left, top = shape[0], shape[1]
        right, bottom = shape[0] + shape[2], shape[1] + shape[3]
        xs, ys = landmarks[:, 0], landmarks[:, 1]

        within_x = np.logical_and(xs >= left, xs < right)
        within_y = np.logical_and(ys >= top, ys < bottom)
        within_box = np.logical_and(within_x, within_y)

        new_mask = mask*within_box
        kept = (landmarks.T * new_mask).T
        return new_mask, kept.astype(int).astype(float)

    def _image_affine_trans(self, image, affine_transf, new_size=None):
        """Warp a PIL image with ``affine_transf``."""
        target_size = new_size if new_size else image.size

        # PIL's transform takes the output -> input mapping, i.e. the inverse transform.
        backward = dlu.get_inverse_transf(affine_transf)
        return image.transform(target_size, Image.AFFINE, backward.flatten())

    def _bbox_affine_trans(self, bbox, affine_transf):
        """Return the axis-aligned (x, y, w, h) box enclosing the transformed bbox corners."""
        x, y, w, h = bbox
        corners = ([x, y, 1], [x + w, y, 1], [x, y + h, 1], [x + w, y + h, 1])
        moved = np.array([affine_transf.dot(corner) for corner in corners])

        lowest = np.min(moved, axis=0)
        highest = np.max(moved, axis=0)

        new_x, new_y = lowest
        new_w, new_h = highest - lowest
        return np.array((new_x, new_y, new_w, new_h))

    def _landmarks_affine_trans(self, landmarks, affine_transf):
        """Apply ``affine_transf`` to every landmark."""
        homogeneous = dlu.affine2homogeneous(landmarks)
        return affine_transf.dot(homogeneous.T).T
+
+
class RSTAug(GeometryBaseAug):
    """Random rotation, scale and translation around the face bounding box."""

    def __init__(self, angle_range=45., scale_min=-0.15, scale_max=0.15, trl_ratio=0.05):
        """
        :param angle_range: the rotation angle is drawn from [-angle_range, angle_range].
        :param scale_min: lower bound of the random offset added to a scale of 1.
        :param scale_max: upper bound of the random offset added to a scale of 1.
        :param trl_ratio: maximum bbox translation, as a fraction of its width/height.
        """
        self.scale_max = scale_max
        self.scale_min = scale_min
        self.angle_range = angle_range
        self.trl_ratio = trl_ratio

    def __call__(self, sample):
        """Apply a random similarity transform to ``sample`` and return it."""
        x, y, w, h = sample['bbox']

        x0, y0 = x + w/2, y + h/2  # center of the face, which will be the center of the rotation

        # Bbox translation
        rnd_Tx = np.random.uniform(-self.trl_ratio, self.trl_ratio) * w
        rnd_Ty = np.random.uniform(-self.trl_ratio, self.trl_ratio) * h
        # Shift a float copy instead of the caller's array: the incoming bbox may be
        # shared with other sample entries, and an integer array would truncate the shift.
        shifted_bbox = np.array(sample['bbox'], dtype=float)
        shifted_bbox[0] += rnd_Tx
        shifted_bbox[1] += rnd_Ty
        sample['bbox'] = shifted_bbox

        scale = 1 + np.random.uniform(self.scale_min, self.scale_max)
        angle = np.random.uniform(-self.angle_range, self.angle_range)

        similarity = dlu.get_similarity_matrix(angle, scale, center=(x0, y0))
        new_sample = self.map_affine_transformation(sample, similarity)
        return new_sample
+
+
class TargetCropAug(GeometryBaseAug):
    """Crop an enlarged square around the bbox and resize it to a fixed image size.

    When the feature-map size differs from the image size, landmarks are
    additionally rescaled to feature-map coordinates.
    """

    def __init__(self, img_new_size=128, map_new_size=128, target_dist=1.3):
        """
        :param img_new_size: output image size, int or (x, y) tuple/list.
        :param map_new_size: feature-map size, int or (x, y) tuple/list.
        :param target_dist: enlargement factor applied to the longest bbox side.
        """
        self.target_dist = target_dist
        self.new_size_x, self.new_size_y = self._convert_shapes(img_new_size)
        self.map_size_x, self.map_size_y = self._convert_shapes(map_new_size)
        self.img2map_scale = False

        # Mismatch between img shape and featuremap shape
        if self.map_size_x != self.new_size_x or self.map_size_y != self.new_size_y:
            self.img2map_scale = True
            self.map_scale_x = self.map_size_x / self.new_size_x
            self.map_scale_y = self.map_size_y / self.new_size_y
            self.map_scale_xx = self.map_scale_x * self.map_scale_x
            self.map_scale_xy = self.map_scale_x * self.map_scale_y
            self.map_scale_yy = self.map_scale_y * self.map_scale_y

    def _convert_shapes(self, new_size):
        """Return ``new_size`` as an (x, y) pair, duplicating a scalar size."""
        if isinstance(new_size, (tuple, list)):
            return new_size[0], new_size[1]
        return new_size, new_size

    def __call__(self, sample):
        """Crop and resize ``sample`` around its bbox and return it.

        When landmarks are present, the sub-pixel values are kept in
        'landmarks_float' / 'mask_ldm_float', while 'landmarks' / 'mask_ldm'
        hold the rounded values with out-of-image landmarks masked out.
        """
        x, y, w, h = sample['bbox']
        # The area taken around the bounding box is enlarged, so its corner has to be
        # moved accordingly. Note this will NOT be the new bounding box!
        # Square crops are returned so that all images share the same size
        # and can be stacked into batches.
        side = max(w, h) * self.target_dist
        x -= (side - w) / 2
        y -= (side - h) / 2

        # center of the enlarged bounding box
        x0, y0 = x + side/2, y + side/2
        # homothety factors, chosen so the enlarged box fills the new image
        mu_x = self.new_size_x / side
        mu_y = self.new_size_y / side

        new_w = self.new_size_x
        new_h = self.new_size_y
        new_x0, new_y0 = new_w / 2, new_h / 2

        # dilatation + translation
        affine_transf = np.array([[mu_x, 0, new_x0 - mu_x * x0],
                                  [0, mu_y, new_y0 - mu_y * y0]])

        sample = self.map_affine_transformation(sample, affine_transf, (new_w, new_h))
        if 'landmarks' in sample.keys():
            img_shape = np.array([0, 0, self.new_size_x, self.new_size_y])
            sample['landmarks_float'] = sample['landmarks']
            sample['mask_ldm_float'] = sample['mask_ldm']
            sample['landmarks'] = np.round(sample['landmarks'])
            sample['mask_ldm'], sample['landmarks'] = self.clean_outbbox_landmarks(img_shape, sample['landmarks'],
                                                                                   sample['mask_ldm'])

            # Rescaling needs the landmark entries created just above.
            if self.img2map_scale:
                sample = self._rescale_map(sample)
        return sample

    def _rescale_map(self, sample):
        """Rescale landmarks from image coordinates to feature-map coordinates.

        Updates 'landmarks_float' and 'landmarks' and stores the scale factors
        in 'img2map_scale'.
        """
        # Rescale
        scale = np.array([self.map_scale_x, self.map_scale_y])
        lnd_float = sample['landmarks_float'] * scale

        # Rounding can push a coordinate onto the map size itself; clamp it to the
        # last cell. Each axis is clamped on its own so that an overflow in x
        # does not overwrite y (and vice versa).
        lnd = np.round(lnd_float)
        lnd[:, 0] = np.minimum(lnd[:, 0], self.map_size_x - 1)
        lnd[:, 1] = np.minimum(lnd[:, 1], self.map_size_y - 1)
        new_lnd = (lnd.T * sample['mask_ldm']).T
        new_lnd = new_lnd.astype(int).astype(float)

        sample['landmarks_float'] = lnd_float
        sample['landmarks'] = new_lnd
        sample['img2map_scale'] = [self.map_scale_x, self.map_scale_y]
        return sample
+
+
+
class OcclusionAug:
    """Paint a random uniformly coloured rectangle inside the bbox and hide the landmarks under it."""

    def __init__(self, min_length=0.1, max_length=0.4, num_maps=1):
        """
        :param min_length: minimum rectangle side, as a fraction of the bbox side.
        :param max_length: maximum rectangle side, as a fraction of the bbox side.
        :param num_maps: stored on the instance; not used by this class.
        """
        self.min_length = min_length
        self.max_length = max_length
        self.num_maps = num_maps

    def __call__(self, sample):
        """Occlude part of ``sample['image']`` and update ``sample['visible']``."""
        x, y, w, h = sample['bbox']
        image = sample['image']
        landmarks = sample['landmarks']
        vis = sample['visible']

        # Rectangle size, drawn independently for each axis.
        occ_w = np.random.randint(int(w * self.min_length), int(w * self.max_length))
        occ_h = np.random.randint(int(h * self.min_length), int(h * self.max_length))

        # Rectangle position: [left, right) x [top, bottom), kept inside the bbox.
        left = int(x + np.random.randint(0, w - occ_w))
        right = int(left + occ_w)
        top = int(y + np.random.randint(0, h - occ_h))
        bottom = int(top + occ_h)

        canvas = np.array(image)
        canvas[top:bottom, left:right, :] = np.random.uniform(0, 255, size=3)
        sample['image'] = Image.fromarray(canvas)

        # Update visibilities: a landmark under the rectangle is no longer visible.
        xs, ys = landmarks[:, 0], landmarks[:, 1]
        covered_x = np.logical_and(xs >= left, xs < right)
        covered_y = np.logical_and(ys >= top, ys < bottom)
        uncovered = np.logical_not(np.logical_and(covered_x, covered_y))
        sample['visible'] = vis * uncovered
        return sample
+
+
class LightingAug:
    """Randomly scale the hue, saturation and value channels of the image."""

    def __init__(self, hsv_range_min=(-0.5, -0.5, -0.5), hsv_range_max=(0.5, 0.5, 0.5)):
        """
        :param hsv_range_min: lower bounds of the random offsets added to a gain of 1 (H, S, V).
        :param hsv_range_max: upper bounds of the random offsets added to a gain of 1 (H, S, V).
        """
        self.hsv_range_min = hsv_range_min
        self.hsv_range_max = hsv_range_max

    def __call__(self, sample):
        """Replace ``sample['image']`` with a colour-jittered copy and return the sample."""
        # Convert to HSV colorspace from RGB colorspace
        hsv = cv2.cvtColor(np.array(sample['image']), cv2.COLOR_RGB2HSV)

        # One random gain per channel, drawn in H, S, V order.
        gains = [1 + np.random.uniform(self.hsv_range_min[ch], self.hsv_range_max[ch])
                 for ch in range(3)]

        # Upper clipping bound per channel: 179 for hue, 255 for saturation and value.
        for ch, (gain, upper) in enumerate(zip(gains, (179, 255, 255))):
            hsv[:, :, ch] = np.clip(gain*hsv[:, :, ch], 0, upper)

        # Convert back to RGB colorspace
        sample['image'] = Image.fromarray(cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB))

        return sample
+
+
class BlurAug:
    """Randomly smooth the image with a Gaussian blur."""

    def __init__(self, blur_prob=0.5, blur_kernel_range=(0, 2)):
        """
        :param blur_prob: probability of blurring the image.
        :param blur_kernel_range: inclusive (low, high) range for the integer ``n``
            that defines the odd kernel size ``2 * n + 1``.
        """
        self.blur_prob = blur_prob
        self.kernel_range = blur_kernel_range

    def __call__(self, sample):
        """Blur ``sample['image']`` with probability ``blur_prob`` and return the sample."""
        # Smooth image
        image = np.array(sample['image'])
        if np.random.uniform(0.0, 1.0) < self.blur_prob:
            # randint excludes its upper bound, hence the +1 to keep the range inclusive.
            # It replaces np.random.random_integers, deprecated since NumPy 1.11.
            half_width = np.random.randint(self.kernel_range[0], self.kernel_range[1] + 1)
            kernel = int(half_width) * 2 + 1
            image = cv2.GaussianBlur(image, (kernel, kernel), 0, 0)
            sample['image'] = Image.fromarray(image)

        return sample
+
+
+
+
diff --git a/SPIGA/spiga/data/loaders/augmentors/modern_posit.py b/SPIGA/spiga/data/loaders/augmentors/modern_posit.py
new file mode 100644
index 0000000000000000000000000000000000000000..34042c80ccfa78b92c4a1a59c9053d4ed3ee5a65
--- /dev/null
+++ b/SPIGA/spiga/data/loaders/augmentors/modern_posit.py
@@ -0,0 +1,197 @@
+import os
+import pkg_resources
+import numpy as np
+import cv2
+
+# My libs
+from spiga.data.loaders.augmentors.utils import rotation_matrix_to_euler
+
# Model file nomenclature: path template of the 3D mean-face files located in
# the `data/models3D` resource of the `spiga` package. The `{num_ldm}`
# placeholder is filled in with the number of landmarks by `load_world_shape`.
model_file_dft = pkg_resources.resource_filename('spiga', 'data/models3D') + '/mean_face_3D_{num_ldm}.txt'
+
+
class PositPose:
    """Sample transform that derives a head pose from 2D landmarks with POSIT.

    Reads ``sample['landmarks']``, ``sample['mask_ldm']``, ``sample['image']``
    and, when present, ``sample['img2map_scale']`` (plus ``sample['bbox']``
    when ``fix_bbox`` is False). Writes ``sample['cam_matrix']``,
    ``sample['model3d']``, ``sample['pose']`` and ``sample['model3d_proj']``.
    """

    def __init__(self, ldm_ids, focal_ratio=1, selected_ids=None, max_iter=100,
                 fix_bbox=True, model_file=model_file_dft):

        # 3D face model, ordered like ldm_ids
        world_shape, world_ids = self._load_world_shape(ldm_ids, model_file)

        # Mask with 1 on the landmarks allowed to take part in the pose fit
        if selected_ids is None:
            posit_mask = np.ones(len(ldm_ids))
        else:
            posit_mask = np.zeros(len(ldm_ids))
            for position, model_id in enumerate(world_ids):
                if model_id in selected_ids:
                    posit_mask[position] = 1

        self.ldm_ids = ldm_ids              # Ids from the database
        self.model3d_world = world_shape    # Model data
        self.model3d_ids = world_ids        # Model ids
        self.model3d_mask = posit_mask      # Model mask ids
        self.max_iter = max_iter            # Refinement iterations
        self.focal_ratio = focal_ratio      # Camera matrix focal length ratio
        self.fix_bbox = fix_bbox            # Camera matrix centered on image (False to centered on bbox)

    def __call__(self, sample):
        """Add camera matrix, 3D model, pose and projected model to ``sample``."""
        landmarks = sample['landmarks']
        mask = sample['mask_ldm']

        # Image extent as (rows, cols), optionally rescaled by img2map_scale
        extent = np.array(sample['image'].shape)[0:2]
        if 'img2map_scale' in sample:
            extent = extent * sample['img2map_scale']

        # Camera matrix built from the full image or from the face bbox
        if self.fix_bbox:
            frame = [0, 0, extent[1], extent[0]]  # Shapes given are inverted (y,x)
        else:
            frame = sample['bbox']  # Scale error when ftshape and img_shape mismatch
        cam_matrix = self._camera_matrix(frame)

        # Save intrinsic matrix and 3D model landmarks
        sample['cam_matrix'] = cam_matrix
        sample['model3d'] = self.model3d_world

        world_pts, image_pts = self._set_correspondences(landmarks, mask)

        if image_pts.shape[0] >= 4:
            rot_matrix, trl_matrix = self._modern_posit(world_pts, image_pts, cam_matrix)
        else:
            # Too few correspondences: fall back to identity rotation, zero translation
            print('POSIT does not work without landmarks')
            rot_matrix, trl_matrix = np.eye(3, dtype=float), np.array([0, 0, 0])

        euler = rotation_matrix_to_euler(rot_matrix)
        pose_values = [euler[i] for i in range(3)] + [trl_matrix[i] for i in range(3)]
        sample['pose'] = np.array(pose_values)
        sample['model3d_proj'] = self._project_points(rot_matrix, trl_matrix, cam_matrix, norm=extent)
        return sample

    def _load_world_shape(self, ldm_ids, model_file):
        """Delegate to the module-level ``load_world_shape``."""
        return load_world_shape(ldm_ids, model_file=model_file)

    def _camera_matrix(self, bbox):
        """Return the 3x3 intrinsic matrix for a ``[x, y, w, h]`` box."""
        left, top, width, height = bbox[0], bbox[1], bbox[2], bbox[3]
        fx = width * self.focal_ratio
        fy = height * self.focal_ratio
        cx = left + (width * 0.5)
        cy = top + (height * 0.5)

        return np.array([[fx, 0, cx],
                         [0, fy, cy],
                         [0, 0, 1]])

    def _set_correspondences(self, landmarks, mask):
        """Return ``(world_pts, image_pts)`` for labelled and robust landmarks."""
        usable = np.logical_and(mask, self.model3d_mask).astype(bool)
        image_pts = landmarks[usable]
        world_pts = self.model3d_world[usable]
        return world_pts, image_pts

    def _modern_posit(self, world_pts, image_pts, cam_matrix):
        """Delegate to the module-level ``modern_posit``."""
        return modern_posit(world_pts, image_pts, cam_matrix, self.max_iter)

    def _project_points(self, rot, trl, cam_matrix, norm=None):
        """Project the whole 3D model with a perspective camera.

        When ``norm`` is given, the first projected coordinate is divided by
        ``norm[0]`` and the second by ``norm[1]``.
        """
        # Perspective projection model: K [R | t]
        extrinsics = np.concatenate((rot, np.expand_dims(trl, 1)), 1)
        proj_matrix = np.matmul(cam_matrix, extrinsics)

        # Homogeneous landmarks
        shape3d = self.model3d_world
        unit_col = np.expand_dims(np.ones(shape3d.shape[0]), 1)
        shape3d_hom = np.concatenate((shape3d, unit_col), 1)

        # Project landmarks and divide by the third coordinate (lambda = 1)
        projected = np.matmul(proj_matrix, shape3d_hom.T).T
        projected = projected / np.expand_dims(projected[:, 2], 1)

        if norm is not None:
            projected[:, 0] /= norm[0]
            projected[:, 1] /= norm[1]
        return projected[:, :-1]
+
+
def load_world_shape(db_landmarks, model_file=model_file_dft):
    """Load the 3D mean face and reorder it to follow ``db_landmarks``.

    The file is selected by formatting ``model_file`` with
    ``num_ldm=len(db_landmarks)``. Each row is ``id|c0|c1|c2``; a row is
    stored as ``[c2, -c0, -c1]`` at the position its id has in
    ``db_landmarks``.

    Returns a tuple ``(world_points, world_ids)`` of numpy arrays.
    Raises ValueError when the model file does not exist.
    """
    # Locate the model file for this number of landmarks
    num_ldm = len(db_landmarks)
    filename = model_file.format(num_ldm=num_ldm)
    if not os.path.exists(filename):
        raise ValueError('No 3D model find for %i landmarks' % num_ldm)

    # First column: landmark ids. Remaining three columns: coordinates.
    model_ids = np.genfromtxt(filename, delimiter='|', dtype=int, usecols=0).tolist()
    model_xyz = np.genfromtxt(filename, delimiter='|', dtype=(float, float, float), usecols=(1, 2, 3)).tolist()

    num_rows = len(model_xyz)
    world_all = [None] * num_rows
    index_all = [None] * num_rows

    for model_id, coords in zip(model_ids, model_xyz):
        slot = db_landmarks.index(model_id)
        world_all[slot] = [coords[2], -coords[0], -coords[1]]
        index_all[slot] = model_id

    return np.array(world_all), np.array(index_all)
+
+
def modern_posit(world_pts, image_pts, cam_matrix, max_iters):
    """Iteratively estimate a rotation matrix and a translation vector.

    Arguments
    ---------
    world_pts: array of 3D points, one row per correspondence.
    image_pts: array of 2D points (columns 0 and 1 are used), same row order.
    cam_matrix: intrinsic matrix; only entries [0,0], [0,2] and [1,2] are read.
    max_iters: maximum number of refinement iterations.

    Returns
    -------
    rot_matrix: 3x3 matrix obtained as U * Vt from the SVD of the estimate.
    trl_matrix: np.array([Tx, Ty, Tz]).
    """
    # Homogeneous world points
    num_landmarks = image_pts.shape[0]
    one = np.ones((num_landmarks, 1))
    A = np.concatenate((world_pts, one), axis=1)
    B = np.linalg.pinv(A)  # pseudo-inverse, computed once and reused in every iteration

    # Normalize image points
    # NOTE: the same focal length (cam_matrix[0,0]) divides both coordinates.
    focal_length = cam_matrix[0,0]
    img_center = (cam_matrix[0,2], cam_matrix[1,2])
    centered_pts = np.zeros((num_landmarks,2))
    centered_pts[:,0] = (image_pts[:,0]-img_center[0])/focal_length
    centered_pts[:,1] = (image_pts[:,1]-img_center[1])/focal_length
    Ui = centered_pts[:,0]
    Vi = centered_pts[:,1]

    # POSIT loop
    Tx, Ty, Tz = 0.0, 0.0, 0.0
    r1, r2, r3 = [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]
    for iter in range(0, max_iters):
        # Least-squares solutions for the two (scaled) rotation rows plus offsets
        I = np.dot(B,Ui)
        J = np.dot(B,Vi)

        # Estimate translation vector and rotation matrix
        normI = 1.0 / np.sqrt(I[0] * I[0] + I[1] * I[1] + I[2] * I[2])
        normJ = 1.0 / np.sqrt(J[0] * J[0] + J[1] * J[1] + J[2] * J[2])
        Tz = np.sqrt(normI * normJ)  # geometric average instead of arithmetic average of classicPosit
        r1N = I*Tz
        r2N = J*Tz
        r1 = r1N[0:3]
        r2 = r2N[0:3]
        r1 = np.clip(r1, -1, 1)
        r2 = np.clip(r2, -1, 1)
        r3 = np.cross(r1,r2)  # third row from the cross product of the first two
        r3T = np.concatenate((r3, [Tz]), axis=0)
        Tx = r1N[3]
        Ty = r2N[3]

        # Compute epsilon, update Ui and Vi and check convergence
        eps = np.dot(A, r3T)/Tz
        oldUi = Ui
        oldVi = Vi
        Ui = np.multiply(eps, centered_pts[:,0])
        Vi = np.multiply(eps, centered_pts[:,1])
        deltaUi = Ui - oldUi
        deltaVi = Vi - oldVi
        # Squared change of the corrected image coordinates, scaled by focal_length**2
        delta = focal_length * focal_length * (np.dot(np.transpose(deltaUi), deltaUi) + np.dot(np.transpose(deltaVi), deltaVi))
        if iter > 0 and delta < 0.01:  # converged
            break

    rot_matrix = np.array([np.transpose(r1), np.transpose(r2), np.transpose(r3)])
    trl_matrix = np.array([Tx, Ty, Tz])
    # Convert to the nearest orthogonal rotation matrix
    # (the singular values `w` are discarded and replaced by the identity)
    w, u, vt = cv2.SVDecomp(rot_matrix)  # R = U*D*Vt
    rot_matrix = np.matmul(np.matmul(u, np.eye(3, dtype=float)), vt)
    return rot_matrix, trl_matrix
+
diff --git a/SPIGA/spiga/data/loaders/augmentors/utils.py b/SPIGA/spiga/data/loaders/augmentors/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..30fcc012790a7bb984dc8ceefbeedc2b8a234465
--- /dev/null
+++ b/SPIGA/spiga/data/loaders/augmentors/utils.py
@@ -0,0 +1,147 @@
+import numpy as np
+
+
def affine2homogeneous(points):
    '''Appends a trailing coordinate equal to 1 to every point.
    Arguments
    ---------
    points: np.array of shape (num_points, dim)
    Returns
    -------
    hpoints: np.array of shape (num_points, dim + 1),
    the input points followed by a column of ones'''

    count = points.shape[0]
    # Integer ones, so integer inputs keep an integer dtype after stacking
    unit_column = np.ones((count, 1), dtype=int)
    return np.hstack((points, unit_column))
+
+
def get_similarity_matrix(deg_angle, scale, center):
    '''Similarity (rotation + uniform scale) about a given center.
    Arguments:
    ---------
    deg_angle: rotation angle in degrees
    scale: factor scale
    center: coordinates (x0, y0) of the rotation center

    Returns:
    -------
    matrix: (2, 3) numpy array; the left 2x2 part is the scaled
    rotation and the last column the translation that keeps
    the center fixed.
    '''
    x0, y0 = center
    angle = np.radians(deg_angle)
    # Scaled cosine / sine shared by the linear part and the translation
    sc = scale * np.cos(angle)
    ss = scale * np.sin(angle)

    matrix = np.zeros((2, 3))
    matrix[0] = [sc, -ss, (1 - sc) * x0 + ss * y0]
    matrix[1] = [ss, sc, -ss * x0 + (1 - sc) * y0]
    return matrix
+
+
def get_inverse_similarity_matrix(deg_angle, scale, center):
    '''Inverse of the similarity defined by the same arguments.
    Arguments
    ---------
    deg_angle: angle in degrees of the rotation
    center: iterable of two components (x0, y0),
    center of the rotation
    scale: float, scale factor
    Returns
    -------
    matrix: np.array of shape (2, 3) with the coefficients of
    the inverse of the similarity'''

    x0, y0 = center
    angle = np.radians(deg_angle)
    inv_scale = 1 / scale
    # Inverse-scaled cosine / sine shared by both matrix parts
    ic = inv_scale * np.cos(angle)
    isn = inv_scale * np.sin(angle)

    matrix = np.zeros((2, 3))
    matrix[0] = [ic, isn, (1 - ic) * x0 - isn * y0]
    matrix[1] = [-isn, ic, isn * x0 + (1 - ic) * y0]
    return matrix
+
+
def get_inverse_transf(affine_transf):
    '''Returns the (2, 3) inverse of a (2, 3) affine transform.
    The 2x2 linear part is assumed to be invertible.'''
    linear = affine_transf[0:2, 0:2]
    offset = affine_transf[:, 2]

    inv_linear = np.linalg.inv(linear)  # we assume A invertible!

    inv_affine = np.zeros((2, 3))
    inv_affine[:, 0:2] = inv_linear
    inv_affine[:, 2] = -(inv_linear.dot(offset))
    return inv_affine
+
+
def image2vect(image):
    '''
    Flattens the last two axes into a single one.
    Input:
    image[batch_size, num_channels, im_size_x, im_size_y]
    Output:
    vect[batch_size, num_channels, im_size_x*im_size_y]
    '''
    leading_dims = image.shape[0:-2]
    return image.reshape(*leading_dims, -1)
+
+
def rotation_matrix_to_euler(rot_matrix):
    '''Converts a 3x3 rotation matrix into [yaw, pitch, roll] in degrees,
    expressed in the shifted coordinate system used by the head pose
    annotations and wrapped by one turn when outside [-180, 180].'''
    # http://euclideanspace.com/maths/geometry/rotations/conversions/matrixToEuler/index.htm
    tol = np.finfo(float).eps
    a10 = rot_matrix[1, 0]
    at_north = abs(1.0 - a10) <= tol   # singularity at north pole / special case a10 == 1
    at_south = abs(-1.0 - a10) <= tol  # singularity at south pole / special case a10 == -1

    if at_north or at_south:
        yaw = np.arctan2(rot_matrix[0, 2], rot_matrix[2, 2])
        pitch = np.pi/2.0 if at_north else -np.pi/2.0
        roll = 0
    else:  # standard case
        yaw = np.arctan2(-rot_matrix[2, 0], rot_matrix[0, 0])
        pitch = np.arcsin(a10)
        roll = np.arctan2(-rot_matrix[1, 2], rot_matrix[1, 1])

    # Convert to degrees
    degrees = np.array([yaw, pitch, roll])*(180.0/np.pi)
    # Change coordinates system
    euler = np.array([(-degrees[0])+90, -degrees[1], (-degrees[2])-90])
    # Bring every angle back by a full turn when it left [-180, 180]
    for axis in range(3):
        if euler[axis] > 180:
            euler[axis] -= 360
        elif euler[axis] < -180:
            euler[axis] += 360
    return euler
+
+
def euler_to_rotation_matrix(headpose):
    '''Builds the 3x3 rotation matrix Ry * Rp * Rr from a head pose
    [yaw, pitch, roll] given in degrees.'''
    # http://euclideanspace.com/maths/geometry/rotations/conversions/eulerToMatrix/index.htm
    # Change coordinates system, then convert to radians
    shifted = np.array([-(headpose[0]-90), -headpose[1], -(headpose[2]+90)])
    yaw, pitch, roll = shifted*(np.pi/180.0)

    cy, sy = np.cos(yaw), np.sin(yaw)
    cp, sp = np.cos(pitch), np.sin(pitch)
    cr, sr = np.cos(roll), np.sin(roll)

    yaw_rot = np.array([[cy, 0.0, sy], [0.0, 1.0, 0.0], [-sy, 0.0, cy]])
    pitch_rot = np.array([[cp, -sp, 0.0], [sp, cp, 0.0], [0.0, 0.0, 1.0]])
    roll_rot = np.array([[1.0, 0.0, 0.0], [0.0, cr, -sr], [0.0, sr, cr]])
    return np.matmul(np.matmul(yaw_rot, pitch_rot), roll_rot)
diff --git a/SPIGA/spiga/data/loaders/dataloader.py b/SPIGA/spiga/data/loaders/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..f942510916fffb132762ab5e41d5ac96d6b54e7a
--- /dev/null
+++ b/SPIGA/spiga/data/loaders/dataloader.py
@@ -0,0 +1,41 @@
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+import spiga.data.loaders.alignments as zoo_alignments
+
# Dataset providers queried in order by get_dataset(); each one is called as
# zoo.get_dataset(data_config, pretreat=..., debug=...).
zoos = [zoo_alignments]
+
+
def get_dataset(data_config, pretreat=None, debug=False):
    """Return the first dataset any registered zoo builds for ``data_config``.

    Zoos are tried in the order of the module-level ``zoos`` list and the
    search stops at the first one that does not return ``None``.
    Raises NotImplementedError when every zoo returns ``None``.
    """
    # Lazy generator: later zoos are only queried while earlier ones return None
    candidates = (zoo.get_dataset(data_config, pretreat=pretreat, debug=debug) for zoo in zoos)
    found = next((candidate for candidate in candidates if candidate is not None), None)
    if found is None:
        raise NotImplementedError('Dataset not available')
    return found
+
+
def get_dataloader(batch_size, data_config, pretreat=None, sampler_cfg=None, debug=False):
    """Build the dataset for ``data_config`` and wrap it in a torch DataLoader.

    When ``sampler_cfg`` is given, a DistributedSampler built from its
    ``world_size`` and ``rank`` is used and DataLoader shuffling is disabled.
    Returns the tuple ``(dataloader, dataset)``.
    """
    dataset = get_dataset(data_config, pretreat=pretreat, debug=debug)

    # Drop the last batch only when it would hold a single sample and shuffling is on
    single_leftover = (len(dataset) % batch_size) == 1
    drop_last_batch = bool(single_leftover and data_config.shuffle == True)

    if sampler_cfg is None:
        sampler = None
        shuffle = data_config.shuffle
    else:
        sampler = DistributedSampler(dataset, num_replicas=sampler_cfg.world_size, rank=sampler_cfg.rank)
        shuffle = False

    loader_kwargs = dict(batch_size=batch_size,
                         shuffle=shuffle,
                         num_workers=data_config.num_workers,
                         pin_memory=True,
                         drop_last=drop_last_batch,
                         sampler=sampler)
    dataloader = DataLoader(dataset, **loader_kwargs)

    return dataloader, dataset
diff --git a/SPIGA/spiga/data/loaders/dl_config.py b/SPIGA/spiga/data/loaders/dl_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..962f883d6fb299d66b454366cbd906e2ae32d8d1
--- /dev/null
+++ b/SPIGA/spiga/data/loaders/dl_config.py
@@ -0,0 +1,170 @@
+import os
+import json
+import pkg_resources
+from collections import OrderedDict
+
# Default data paths
# Location of the `data/databases` resource of the `spiga` package.
db_img_path = pkg_resources.resource_filename('spiga', 'data/databases')
# Annotation file template; `{database}` and `{file_name}` are filled in with str.format().
db_anns_path = pkg_resources.resource_filename('spiga', 'data/annotations') + "/{database}/{file_name}.json"
+
class AlignConfig:
    """Configuration of the alignment dataloader.

    Holds the database description, dataloader settings, POSIT parameters and
    the data augmentation parameters. Attributes whose name does not start
    with ``_`` are exported by :meth:`state_dict`.
    """

    def __init__(self, database_name, mode='train'):
        # Dataset
        self.database_name = database_name
        self.working_mode = mode
        self.database = None    # Set at self._update_database()
        self.anns_file = None   # Set at self._update_database()
        self.image_dir = None   # Set at self._update_database()
        self._update_database()
        self.image_size = (256, 256)
        self.ftmap_size = (256, 256)

        # Dataloaders
        self.ids = None         # List of a subset if need it
        self.shuffle = True     # Shuffle samples
        self.num_workers = 4    # Threads

        # Posit
        self.generate_pose = True   # Generate pose parameters from landmarks
        self.focal_ratio = 1.5      # Camera matrix focal length ratio
        self.posit_max_iter = 100   # Refinement iterations

        # Subset of robust ids in the 3D model to use in posit.
        # 'None' to use all the available model landmarks.
        self.posit_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
                          14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]

        # Data augmentation
        # Control augmentations with the following list, crop to self.image_size is mandatory,
        # check target_dist param.
        if mode == 'train':
            self.aug_names = ['flip', 'rotate_scale', 'occlusion', 'lighting', 'blur']
        else:
            self.aug_names = []
            self.shuffle = False

        # Flip
        self.hflip_prob = 0.5
        # Rotation
        self.angle_range = 45.
        # Scale
        self.scale_max = 0.15
        self.scale_min = -0.15
        # Translation
        self.trl_ratio = 0.05   # Translation augmentation
        # Crop target rescale
        self.target_dist = 1.6  # Target distance zoom in/out around face. Default: 1.
        # Occlusion
        self.occluded_max_len = 0.4
        self.occluded_min_len = 0.1
        self.occluded_covar_ratio = 2.25**0.5
        # Lighting
        self.hsv_range_min = [-0.5, -0.5, -0.5]
        self.hsv_range_max = [0.5, 0.5, 0.5]
        # Blur
        self.blur_prob = 0.5
        self.blur_kernel_range = [0, 2]

        # Heatmaps 2D
        self.sigma2D = 1.5
        self.heatmap2D_norm = False

        # Boundaries
        self.sigmaBD = 1

    def update(self, params_dict):
        """Overwrite known options with the values in ``params_dict``.

        Unknown keys are not set; a warning is issued for each of them.
        The database-dependent fields are refreshed afterwards.
        """
        import warnings

        state_dict = self.state_dict()
        for k, v in params_dict.items():
            if k in state_dict or hasattr(self, k):
                setattr(self, k, v)
            else:
                # `Warning(...)` alone only builds an exception object and
                # discards it; warnings.warn actually reports the problem.
                warnings.warn('Unknown option: {}: {}'.format(k, v), stacklevel=2)
        self._update_database()

    def state_dict(self, tojson=False):
        """Return an OrderedDict of the public attributes.

        With ``tojson=True`` the ``database`` entry is left out.
        """
        state_dict = OrderedDict()
        for k in self.__dict__.keys():
            if not k.startswith('_'):
                if tojson and k in ['database']:
                    continue
                state_dict[k] = getattr(self, k)
        return state_dict

    def _update_database(self):
        """Recompute ``database``, ``anns_file`` and ``image_dir``."""
        self.database = DatabaseStruct(self.database_name)
        self.anns_file = db_anns_path.format(database=self.database_name, file_name=self.working_mode)
        self.image_dir = self._get_imgdb_path()

    def _get_imgdb_path(self):
        """Return the image directory of the database, or None if unknown."""
        img_dir = None
        if self.database_name in ['300wpublic', '300wprivate']:
            img_dir = db_img_path + '/300w/'
        elif self.database_name in ['aflw19', 'merlrav']:
            img_dir = db_img_path + '/aflw/data/'
        elif self.database_name in ['cofw', 'cofw68']:
            img_dir = db_img_path + '/cofw/'
        elif self.database_name in ['wflw']:
            img_dir = db_img_path + '/wflw/'
        return img_dir

    def __str__(self):
        state_dict = self.state_dict()
        text = 'Dataloader {\n'
        for k, v in state_dict.items():
            if isinstance(v, DatabaseStruct):
                # The nested description already ends with a newline
                text += '\t{}: {}'.format(k, str(v).expandtabs(12))
            else:
                text += '\t{}: {}\n'.format(k, v)
        text += '\t}\n'
        return text
+
+
class DatabaseStruct:
    """Landmark layout of a database, loaded from its ``db_info`` JSON file."""

    def __init__(self, database_name):

        self.name = database_name
        self.ldm_ids, self.ldm_flip_order, self.ldm_edges_matrix = self._get_database_specifics()
        self.num_landmarks = len(self.ldm_ids)
        # 'ldm_edges_matrix' is optional in db_info: without it there are no edges
        # (indexing None here used to raise a TypeError).
        if self.ldm_edges_matrix is None:
            self.num_edges = 0
        else:
            self.num_edges = len(self.ldm_edges_matrix[0])-1
        self.fields = ['imgpath', 'bbox', 'headpose', 'ids', 'landmarks', 'visible']

    def _get_database_specifics(self):
        '''Returns specifics ids, horizontal flip reorder and edges matrix.

        The edges matrix is None when 'ldm_edges_matrix' is absent from the
        db_info file. Raises ValueError when the db_info file does not exist.'''

        database_name = self.name
        db_info_file = db_anns_path.format(database=database_name, file_name='db_info')

        if not os.path.exists(db_info_file):
            raise ValueError('Database ' + database_name + ' specifics not defined. Missing db_info.json')

        with open(db_info_file) as jsonfile:
            db_info = json.load(jsonfile)

        ldm_ids = db_info['ldm_ids']
        ldm_flip_order = db_info['ldm_flip_order']
        ldm_edges_matrix = db_info.get('ldm_edges_matrix')

        return ldm_ids, ldm_flip_order, ldm_edges_matrix

    def state_dict(self):
        """Return an OrderedDict of the attributes not starting with ``_``."""
        state_dict = OrderedDict()
        for k in self.__dict__.keys():
            if not k.startswith('_'):
                state_dict[k] = getattr(self, k)

        return state_dict

    def __str__(self):
        state_dict = self.state_dict()
        text = 'Database {\n'
        for k, v in state_dict.items():
            text += '\t{}: {}\n'.format(k, v)
        text += '\t}\n'
        return text
+
+
+
diff --git a/SPIGA/spiga/data/loaders/transforms.py b/SPIGA/spiga/data/loaders/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..6371657516cd766c363fd7ea51bf978a65c70376
--- /dev/null
+++ b/SPIGA/spiga/data/loaders/transforms.py
@@ -0,0 +1,84 @@
+import cv2
+import numpy as np
+import torch
+
+from spiga.data.loaders.augmentors.modern_posit import PositPose
+from spiga.data.loaders.augmentors.heatmaps import Heatmaps
+from spiga.data.loaders.augmentors.boundary import AddBoundary
+from spiga.data.loaders.augmentors.landmarks import HorizontalFlipAug, RSTAug, OcclusionAug, \
+ LightingAug, BlurAug, TargetCropAug
+
+
def get_transformers(data_config):
    """Assemble the ordered list of sample transforms for ``data_config``.

    Optional steps are enabled by name through ``data_config.aug_names``.
    The target crop and the OpenCV conversion are always added; the pose
    generator is added when ``data_config.generate_pose`` is true.
    """
    enabled = data_config.aug_names
    pipeline = []

    # Data augmentation
    if 'flip' in enabled:
        pipeline.append(HorizontalFlipAug(data_config.database.ldm_flip_order, data_config.hflip_prob))
    if 'rotate_scale' in enabled:
        pipeline.append(RSTAug(data_config.angle_range, data_config.scale_min,
                               data_config.scale_max, data_config.trl_ratio))
    if 'occlusion' in enabled:
        pipeline.append(OcclusionAug(data_config.occluded_min_len,
                                     data_config.occluded_max_len,
                                     data_config.database.num_landmarks))
    if 'lighting' in enabled:
        pipeline.append(LightingAug(data_config.hsv_range_min, data_config.hsv_range_max))
    if 'blur' in enabled:
        pipeline.append(BlurAug(data_config.blur_prob, data_config.blur_kernel_range))

    # Mandatory steps: crop around the target, then switch to OpenCV style
    pipeline += [TargetCropAug(data_config.image_size, data_config.ftmap_size, data_config.target_dist),
                 ToOpencv()]

    # Gaussian heatmaps
    if 'heatmaps2D' in enabled:
        pipeline.append(Heatmaps(data_config.database.num_landmarks, data_config.ftmap_size,
                                 data_config.sigma2D, norm=data_config.heatmap2D_norm))

    # Boundary maps
    if 'boundaries' in enabled:
        pipeline.append(AddBoundary(num_landmarks=data_config.database.num_landmarks,
                                    map_size=data_config.ftmap_size,
                                    sigma=data_config.sigmaBD))

    # Pose generator
    if data_config.generate_pose:
        pipeline.append(PositPose(data_config.database.ldm_ids,
                                  focal_ratio=data_config.focal_ratio,
                                  selected_ids=data_config.posit_ids,
                                  max_iter=data_config.posit_max_iter))

    return pipeline
+
+
class ToOpencv:
    """Replace ``sample['image']`` by a numpy array with RGB swapped to BGR."""

    def __call__(self, sample):
        # Convert to a numpy array and change the channel order to BGR
        rgb = np.array(sample['image'])
        sample['image'] = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        return sample
+
+
class TargetCrop(TargetCropAug):
    """TargetCropAug that passes ``crop_size`` as both of its size arguments."""

    def __init__(self, crop_size=256, target_dist=1.6):
        super().__init__(crop_size, crop_size, target_dist)
+
+
class AddModel3D(PositPose):
    """Attach a fixed camera matrix and the 3D face model to a sample.

    The camera matrix is computed once from ``ftmap_size``. With
    ``totensor=True`` both the camera matrix and the 3D model are stored as
    float torch tensors.
    """

    def __init__(self, ldm_ids, ftmap_size=(256, 256), focal_ratio=1.5, totensor=False):
        super(AddModel3D, self).__init__(ldm_ids, focal_ratio=focal_ratio)
        img_bbox = [0, 0, ftmap_size[1], ftmap_size[0]]  # Shapes given are inverted (y,x)
        self.cam_matrix = self._camera_matrix(img_bbox)

        if totensor:
            self.cam_matrix = torch.tensor(self.cam_matrix, dtype=torch.float)
            self.model3d_world = torch.tensor(self.model3d_world, dtype=torch.float)

    def __call__(self, sample=None):
        # A `{}` default would be created once and shared by every call made
        # without arguments; build a fresh dict per call instead.
        if sample is None:
            sample = {}
        # Save intrinsic matrix and 3D model landmarks
        sample['cam_matrix'] = self.cam_matrix
        sample['model3d'] = self.model3d_world
        return sample
diff --git a/SPIGA/spiga/data/models3D/__init__.py b/SPIGA/spiga/data/models3D/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/data/models3D/mean_face_3D_68.txt b/SPIGA/spiga/data/models3D/mean_face_3D_68.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5640bd0df65fcb268cbc78c5d0b3eba6221f78c7
--- /dev/null
+++ b/SPIGA/spiga/data/models3D/mean_face_3D_68.txt
@@ -0,0 +1,68 @@
+101|-0.71046061493|0.39334543762|-0.853184236267
+102|-0.685409656726|0.169750713576|-0.878623094747
+103|-0.607672440834|-0.0597969416711|-0.922355056609
+104|-0.545664962607|-0.262499453743|-0.920836599126
+105|-0.446214526601|-0.439045175771|-0.863139209
+106|-0.334435950087|-0.566049889292|-0.728913281016
+107|-0.236168305748|-0.653698223145|-0.585459059996
+108|-0.128307367184|-0.712045321067|-0.488221730877
+24|-0.00354849146401|-0.73799619898|-0.461364289002
+110|0.131199109787|-0.733135962679|-0.466262199442
+111|0.257373771199|-0.666487934012|-0.560781626341
+112|0.359823738861|-0.569594984479|-0.666664889106
+113|0.455878939364|-0.475272786995|-0.762199581604
+114|0.583151236654|-0.255441185628|-0.895445263128
+115|0.635035270608|-0.076593055883|-0.925024981148
+116|0.688562921266|0.127095701882|-0.90624669312
+117|0.711883056933|0.332170073984|-0.901990426153
+7|-0.406302776083|0.360389456414|-0.383328256737
+138|-0.34477414508|0.384294131116|-0.310253684249
+139|-0.261521910706|0.396020592391|-0.286645173372
+8|-0.160787895282|0.371046071664|-0.330243237009
+141|-0.245621576158|0.334421719594|-0.323105980143
+142|-0.328092143154|0.331973054751|-0.339942994735
+11|0.163880511996|0.34074733643|-0.314140711047
+144|0.250844605212|0.392490504074|-0.279030578851
+145|0.340770354713|0.378980645804|-0.285841921386
+12|0.432352000574|0.348665008376|-0.377158168878
+147|0.333467711264|0.328033806608|-0.318530865652
+148|0.246621391603|0.335053488011|-0.313137199532
+1|-0.552168877878|0.483792334207|-0.466456539274
+119|-0.459810924979|0.561780416181|-0.308160933539
+2|-0.360178576451|0.566290707558|-0.216753374881
+121|-0.249770054109|0.530336745189|-0.148694799479
+3|-0.101941089981|0.482471777751|-0.124216132105
+128|0.00843526194116|0.416310824525|-0.139142361862
+129|0.00731217835113|0.267155736618|-0.106471830332
+130|0.00190174006432|0.142359799355|-0.0544539305898
+17|0.0|0.0|0.0
+16|-0.118903311783|-0.0792784542251|-0.23553779698
+133|-0.0504300242939|-0.104665185108|-0.171210015652
+134|0.00312657092554|-0.111845126622|-0.154139340027
+135|0.0687543983424|-0.0995458563329|-0.173275169304
+18|0.148988810614|-0.0737848514841|-0.23851346235
+4|0.0979764884322|0.465935806416|-0.131673602172
+124|0.214281256931|0.497130488239|-0.135680573937
+5|0.373011222167|0.520379035113|-0.223000042464
+126|0.501607216527|0.500371303274|-0.31925158046
+6|0.57362522379|0.457842172776|-0.465084060802
+20|-0.2229832011|-0.277627584817|-0.354306324758
+150|-0.157062857603|-0.248737641674|-0.270218764688
+151|-0.0816354856174|-0.23618498428|-0.225237980338
+22|-0.00391577759471|-0.24499812323|-0.2126098208
+153|0.0629421066443|-0.230774971512|-0.208501270489
+154|0.135507933036|-0.240646328001|-0.235471195707
+21|0.240381217797|-0.283709533358|-0.332281891845
+156|0.143706630542|-0.329461104518|-0.283978333058
+157|0.084139951946|-0.355354733249|-0.256464077311
+23|0.00813704018925|-0.360357397837|-0.250423014186
+159|-0.0778260690655|-0.349041413147|-0.26966136614
+160|-0.157111645515|-0.311971743234|-0.322383820174
+161|-0.182660017757|-0.278377925956|-0.341170846656
+162|-0.075943475389|-0.283519721316|-0.270854173882
+163|-0.0026701322545|-0.286917503123|-0.254428375696
+164|0.0685062003389|-0.285630836758|-0.257845545665
+165|0.194170296105|-0.279641061937|-0.314605542849
+166|0.0689862474833|-0.288076969827|-0.258696545259
+167|-0.00273913414372|-0.28915723747|-0.25622946708
+168|-0.0766989697362|-0.286311968999|-0.273027820051
diff --git a/SPIGA/spiga/data/models3D/mean_face_3D_98.txt b/SPIGA/spiga/data/models3D/mean_face_3D_98.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f5920400bffb6ecc8ae2afc39348ba63ba3805b1
--- /dev/null
+++ b/SPIGA/spiga/data/models3D/mean_face_3D_98.txt
@@ -0,0 +1,98 @@
+100|-0.710460614932|0.393345437621|-0.853184236268
+101|-0.697935135807|0.281548075586|-0.865903665506
+102|-0.685409656727|0.169750713576|-0.878623094748
+103|-0.646541048786|0.0549768859803|-0.90048907572
+104|-0.607672440835|-0.0597969416712|-0.92235505661
+105|-0.576668701754|-0.16114819775|-0.921595827864
+106|-0.545664962608|-0.262499453743|-0.920836599128
+107|-0.495939744597|-0.350772314723|-0.891987904048
+108|-0.446214526602|-0.439045175772|-0.863139209001
+109|-0.390325238349|-0.502547532515|-0.796026245014
+110|-0.334435950088|-0.566049889293|-0.728913281017
+111|-0.285302127936|-0.60987405623|-0.657186170535
+112|-0.236168305748|-0.653698223146|-0.585459059997
+113|-0.182237836481|-0.682871772143|-0.536840395466
+114|-0.128307367184|-0.712045321068|-0.488221730878
+115|-0.065927929287|-0.725020759991|-0.474793009911
+24|-0.00354849146405|-0.737996198981|-0.461364289002
+117|0.0638253091608|-0.735566080866|-0.463813244263
+118|0.131199109787|-0.73313596268|-0.466262199443
+119|0.194286440465|-0.699811948342|-0.513521912927
+120|0.257373771199|-0.666487934013|-0.560781626342
+121|0.308598755049|-0.61804145929|-0.613723257723
+122|0.359823738861|-0.569594984479|-0.666664889106
+123|0.407851339087|-0.522433885778|-0.714432235344
+124|0.455878939366|-0.475272786996|-0.762199581605
+125|0.519515088045|-0.365356986341|-0.828822422342
+126|0.583151236655|-0.255441185628|-0.89544526313
+127|0.609093253652|-0.166017120718|-0.910235122162
+128|0.635035270609|-0.0765930558831|-0.92502498115
+129|0.661799095895|0.0252513229862|-0.915635837119
+130|0.688562921268|0.127095701883|-0.906246693121
+131|0.700222989101|0.229632887961|-0.904118559673
+132|0.711883056935|0.332170073985|-0.901990426154
+1|-0.552168877879|0.483792334207|-0.466456539275
+134|-0.45981092498|0.561780416182|-0.30816093354
+2|-0.360178576451|0.566290707559|-0.216753374882
+136|-0.249770054109|0.53033674519|-0.148694799479
+3|-0.101941089981|0.482471777751|-0.124216132105
+138|-0.100790757462|0.429593539303|-0.13295377507
+139|-0.24861972159|0.477458506742|-0.157432442443
+140|-0.359028243932|0.513412469111|-0.225491017847
+141|-0.458660592461|0.508902177734|-0.316898576504
+4|0.0979764884324|0.465935806417|-0.131673602172
+143|0.214281256932|0.49713048824|-0.135680573937
+5|0.373011222167|0.520379035114|-0.223000042464
+145|0.501607216528|0.500371303275|-0.31925158046
+6|0.573625223791|0.457842172777|-0.465084060803
+147|0.500456884008|0.447493064827|-0.310513937496
+148|0.374161554686|0.467500796666|-0.231737685429
+149|0.215431589451|0.444252249792|-0.144418216902
+150|0.0991268209517|0.413057567969|-0.140411245136
+151|0.00843526194114|0.416310824526|-0.139142361862
+152|0.00731217835118|0.267155736618|-0.106471830333
+153|0.00190174006432|0.142359799356|-0.0544539305899
+17|0.0|0.0|0.0
+16|-0.118903311783|-0.0792784542252|-0.23553779698
+156|-0.0504300242939|-0.104665185108|-0.171210015652
+157|0.00312657092558|-0.111845126622|-0.154139340027
+158|0.0687543983424|-0.0995458563331|-0.173275169304
+18|0.148988810614|-0.0737848514843|-0.23851346235
+7|-0.406302776084|0.360389456414|-0.383328256738
+161|-0.34477414508|0.384294131116|-0.310253684249
+9|-0.303148027928|0.398970401478|-0.298449428843
+163|-0.261521910707|0.396020592392|-0.286645173373
+8|-0.160787895283|0.371046071665|-0.33024323701
+165|-0.245621576158|0.334421719595|-0.323105980144
+10|-0.286856859665|0.324384347457|-0.33152448747
+167|-0.328092143155|0.331973054752|-0.339942994736
+11|0.163880511996|0.340747336431|-0.314140711048
+169|0.250844605212|0.392490504075|-0.279030578852
+13|0.295807479997|0.39454861469|-0.282436250152
+171|0.340770354713|0.378980645804|-0.285841921386
+12|0.432352000574|0.348665008377|-0.377158168879
+173|0.333467711265|0.328033806608|-0.318530865652
+14|0.290044551442|0.322730607606|-0.315834032602
+175|0.246621391603|0.335053488012|-0.313137199533
+20|-0.2229832011|-0.277627584818|-0.354306324759
+177|-0.157062857603|-0.248737641675|-0.270218764688
+178|-0.0816354856176|-0.23618498428|-0.225237980339
+22|-0.00391577759461|-0.24499812323|-0.212609820801
+180|0.0629421066443|-0.230774971513|-0.20850127049
+181|0.135507933036|-0.240646328002|-0.235471195707
+21|0.240381217797|-0.283709533359|-0.332281891846
+183|0.143706630542|-0.329461104519|-0.283978333058
+184|0.0841399519463|-0.35535473325|-0.256464077312
+23|0.00813704018926|-0.360357397838|-0.250423014186
+186|-0.0778260690653|-0.349041413147|-0.269661366141
+187|-0.157111645515|-0.311971743235|-0.322383820174
+188|-0.182660017758|-0.278377925956|-0.341170846657
+189|-0.0759434753887|-0.283519721316|-0.270854173882
+190|-0.00267013225455|-0.286917503124|-0.254428375697
+191|0.068506200339|-0.285630836758|-0.257845545665
+192|0.194170296105|-0.279641061938|-0.31460554285
+193|0.0689862474834|-0.288076969827|-0.25869654526
+194|-0.00273913414368|-0.28915723747|-0.25622946708
+195|-0.0766989697367|-0.286311968999|-0.273027820052
+196|-0.291183407723|0.363024170951|-0.328919887677
+197|0.294656095888|0.353995131537|-0.31463990759
diff --git a/SPIGA/spiga/data/models3D/visualization.py b/SPIGA/spiga/data/models3D/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..49bf883d7194039d260ea1dfb70eecd2454c8fd1
--- /dev/null
+++ b/SPIGA/spiga/data/models3D/visualization.py
@@ -0,0 +1,37 @@
+import argparse
+import numpy as np
+import matplotlib.pyplot as plt
+
+
def main():
    """Parse the command line and display the requested 3D model file."""
    # Input arguments control
    parser = argparse.ArgumentParser(description='3D model visualization')
    parser.add_argument('file', type=str, help='File txt path')
    visualize_3Dmodel(parser.parse_args().file)
+
+
def visualize_3Dmodel(input_file):
    """Show a 3D scatter plot of a model file.

    Each non-empty line of ``input_file`` holds '|'-separated numbers; the
    first one is skipped and the remaining columns are plotted as x, y, z
    (z is shifted by +0.8 for display).
    """
    with open(input_file) as f:
        lines = f.readlines()

    model = []
    for line in lines:
        # strip() instead of line[:-1]: the last line may lack a trailing
        # newline, in which case slicing would cut off its final digit.
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        line_split = line.split('|')
        values = np.array(line_split, dtype=float)
        model.append(values)

    model = np.array(model)
    model_xyz = model[:, 1:]

    # Show model
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(model_xyz[:, 0], model_xyz[:, 1], model_xyz[:, 2]+0.8)
    plt.show()
+
+
# Run the command-line entry point only when executed as a script
if __name__ == '__main__':
    main()
diff --git a/SPIGA/spiga/data/readme.md b/SPIGA/spiga/data/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..a30d1e2503c1f8a62d186dd183bdf769724e225d
--- /dev/null
+++ b/SPIGA/spiga/data/readme.md
@@ -0,0 +1,72 @@
+# SPIGA: Dataloaders
+Currently, the repository provides a PyTorch-based image dataloader implementation for the tasks of:
+* **Facial Landmarks estimation**.
+* **Headpose estimation**.
+* **Facial landmarks visibilities**.
+
+The dataloader can be used for training or testing the networks, and it includes both general and
+task-specific data augmentations, such as partial image occlusions
+or headpose generation from facial landmarks.
+
+In addition, the framework provides extensive benchmark software to evaluate the
+different tasks across the following databases:
+* **WFLW**.
+* **MERLRAV (AFLW 68)**
+* **COFW68**.
+* **300W Public, Private**.
+
+ Coming soon...
+
+* AFLW, AFLW19, AFLW2000 (test only).
+* Menpo and 3D Menpo.
+* COFW.
+* 300WLP
+* 300W Masked
+
+
+
+***Note:*** All the callable files provide a detailed parser that describes the behaviour of
+the program and its inputs. Please check the operational modes by using the ```--help``` option.
+
+## Training/Testing
+The dataloader structure can be found in ```./data/loaders/alignments.py``` and it can be
+manually controlled by instantiating the class ```AlignmentsDataset()``` or by using
+the ```data_config``` structure available in ```./data/loaders/dl_config.py```.
+
+Each image sample will follow the next configuration:
+```
+sample = {'image': Data augmented crop image,
+ 'sample_idx': Image ID,
+ 'imgpath': Absolute path to raw image,
+ 'imgpath_local': Relative path to raw image,
+ 'ids_ldm': Landmarks ids,
+ 'bbox': Face bbox [x,y,w,h] (ref crop),
+ 'bbox_raw': Face bbox [x,y,w,h] (ref image),
+ 'landmarks': Augmented landmarks [[x1,y1], [x2,y2], ...] (ref crop)
+ 'visible': Visibilities [0,1, ...] (1 == Visible)
+ 'mask_ldm': Available landmarks anns [True, False, ...] <- len(ids_ldm)
+ 'headpose': Augmented POSIT headpose [yaw, pitch, roll]
+ }
+
+Extra features while debugging:
+sample = { ...
+ 'landmarks_ori' = Ground truth landmarks before augmentation (ref image)
+ 'visible_ori' = Ground truth visibilities before augmentation
+ 'mask_ldm_ori' = Ground truth mask before augmentation
+ 'headpose_ori' = Ground truth headpose before augmentation (if available)
+ }
+```
+
+## Visualizers
+The dataloader framework provides complementary visualizers to further understand the databases,
+datasets and their difficulties:
+
+* ```./data/visualize/inspect_dataset.py```
+Focuses on the database annotations and the data augmentations, which helps to
+understand the training, validation and test datasets.
+
+* ```./data/visualize/inspect_heatmaps.py```
+Extended version of the visualizer, focused on understanding the heatmap and boundary features used for training.
+
+* ```./data/models3D/visualization.py```
+Visualize the rigid facial 3D models used by SPIGA to project the initial coordinates of the GAT regressor.
diff --git a/SPIGA/spiga/data/visualize/__init__.py b/SPIGA/spiga/data/visualize/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/data/visualize/inspect_dataset.py b/SPIGA/spiga/data/visualize/inspect_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fac34324fa47f455e096644585a147400a19e53
--- /dev/null
+++ b/SPIGA/spiga/data/visualize/inspect_dataset.py
@@ -0,0 +1,181 @@
+import cv2
+import random
+import numpy as np
+
+import spiga.data.loaders.dl_config as dl_cfg
+import spiga.data.loaders.dataloader as dl
+import spiga.data.visualize.plotting as plot
+
+
def inspect_parser():
    """
    Build the command-line parser shared by the dataset inspectors and parse
    the arguments of the current process.
    :return: argparse namespace with database, anns, nopose, clean, shape and img.
    """
    import argparse

    description = ('Data augmentation and dataset visualization. '
                   'Press Q to quit,'
                   'N to visualize the next image'
                   ' and any other key to visualize the next default data.')
    databases = ['wflw', '300wpublic', '300wprivate', 'cofw68', 'merlrav']

    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('database', type=str, choices=databases, help='Database name')
    parser.add_argument('-a', '--anns', type=str, default='train',
                        help='Annotation type: test, train or valid')
    # store_false: the value is True unless the flag is given
    parser.add_argument('-np', '--nopose', action='store_false', default=True,
                        help='Avoid pose generation')
    parser.add_argument('-c', '--clean', action='store_true',
                        help='Process without data augmentation for train')
    parser.add_argument('--shape', nargs='+', type=int, default=[256, 256],
                        help='Image cropped shape (W,H)')
    parser.add_argument('--img', nargs='+', type=int, default=None,
                        help='Select specific image ids')
    return parser.parse_args()
+
+
class DatasetInspector:
    """
    Interactive OpenCV viewer for the samples produced by the alignment dataloader.

    A debug dataloader is built for the requested database and annotation split;
    landmarks (and optionally the headpose axes) are drawn on the cropped sample
    and on the original image.
    """

    def __init__(self, database, anns_type, data_aug=True, pose=True, image_shape=(256,256)):
        """
        :param database: Database name passed to dl_cfg.AlignConfig.
        :param anns_type: Annotation split passed to dl_cfg.AlignConfig.
        :param data_aug: If False, the augmentation list of the config is emptied.
        :param pose: Stored in the config as generate_pose.
        :param image_shape: Used for both image_size and ftmap_size of the config.
        """
        data_config = dl_cfg.AlignConfig(database, anns_type)
        data_config.image_size = image_shape
        data_config.ftmap_size = image_shape
        data_config.generate_pose = pose

        if not data_aug:
            data_config.aug_names = []

        self.data_config = data_config
        dataloader, dataset = dl.get_dataloader(1, data_config, debug=True)
        self.dataset = dataset
        self.dataloader = dataloader
        # Default drawing colors: (visible, not visible) landmarks and the three pose axes
        self.colors_dft = {'lnd': (plot.GREEN, plot.RED), 'pose': (plot.BLUE, plot.GREEN, plot.RED)}

    def show_dataset(self, ids_list=None):
        """
        Display the samples one by one. Press Q to stop, any other key to continue.
        :param ids_list: Dataset indices to show. None shows every sample
                         (shuffled when the config requests it).
        """
        if ids_list is None:
            ids = self.get_idx(shuffle=self.data_config.shuffle)
        else:
            ids = ids_list

        for img_id in ids:
            data_dict = self.dataset[img_id]
            crop_imgs, full_img = self.plot_features(data_dict)

            # Plot crop: the merged landmarks + pose image only exists with pose enabled
            crop = crop_imgs['merge'] if 'merge' in crop_imgs else crop_imgs['lnd']
            cv2.imshow('crop', crop)

            # Plot full
            cv2.imshow('image', full_img['lnd'])

            if cv2.waitKey() == ord('q'):
                break

    def plot_features(self, data_dict, colors=None):
        """
        Draw the annotations of one sample.
        :param data_dict: Sample dictionary produced by the dataset in debug mode.
        :param colors: Optional dict with 'lnd' and 'pose' color tuples.
        :return: (crop_imgs, full_imgs) dictionaries. Both hold a 'lnd' image and
                 crop_imgs also holds 'pose' and 'merge' when pose generation is on.
        """
        # Init variables
        crop_imgs = {}
        full_imgs = {}
        if colors is None:
            colors = self.colors_dft

        # Cropped image
        image = data_dict['image']
        landmarks = data_dict['landmarks']
        visible = data_dict['visible']
        if np.any(np.isnan(visible)):
            # NaN visibilities are treated as "no visibility annotation"
            visible = None
        mask = data_dict['mask_ldm']

        # Full image
        if 'image_ori' in data_dict:
            image_ori = data_dict['image_ori']
        else:
            image_ori = cv2.imread(data_dict['imgpath'])
        landmarks_ori = data_dict['landmarks_ori']
        visible_ori = data_dict['visible_ori']
        if np.any(np.isnan(visible_ori)):
            visible_ori = None
        mask_ori = data_dict['mask_ldm_ori']

        # Plot landmarks
        crop_imgs['lnd'] = self._plot_lnd(image, landmarks, visible, mask, colors=colors['lnd'])
        full_imgs['lnd'] = self._plot_lnd(image_ori, landmarks_ori, visible_ori, mask_ori, colors=colors['lnd'])

        if self.data_config.generate_pose:
            rot, trl, cam_matrix = self._extract_pose(data_dict)

            # Plot pose
            crop_imgs['pose'] = plot.draw_pose(image, rot, trl, cam_matrix, euler=True, colors=colors['pose'])

            # Plot merge features
            crop_imgs['merge'] = plot.draw_pose(crop_imgs['lnd'], rot, trl, cam_matrix, euler=True, colors=colors['pose'])

        return crop_imgs, full_imgs

    def get_idx(self, shuffle=False):
        """
        :param shuffle: Randomly permute the indices when True.
        :return: List with every index of the dataset.
        """
        ids = list(range(len(self.dataset)))
        if shuffle:
            random.shuffle(ids)
        return ids

    def reload_dataset(self, data_config=None):
        """
        Rebuild the debug dataloader and dataset.
        :param data_config: Config to load. Defaults to the stored one.
        """
        if data_config is None:
            data_config = self.data_config
        dataloader, dataset = dl.get_dataloader(1, data_config, debug=True)
        self.dataset = dataset
        self.dataloader = dataloader

    def _extract_pose(self, data_dict):
        """
        :return: (rot, trl, cam_matrix) taken from the 'pose' and 'cam_matrix'
                 entries of the sample. The first three pose values are the rotation.
        """
        # Rotation and translation matrix
        pose = data_dict['pose']
        rot = pose[:3]
        trl = pose[3:]

        # Camera matrix
        cam_matrix = data_dict['cam_matrix']

        # Without augmentations the annotated headpose, when present, replaces the rotation
        if 'headpose_ori' in data_dict and len(self.data_config.aug_names) == 0:
            print('Image headpose generated by ground truth data')
            rot = data_dict['headpose_ori']

        return rot, trl, cam_matrix

    def _plot_lnd(self, image, landmarks, visible, mask, max_shape_thr=720, colors=None):
        """
        Draw the landmarks and shrink the result when its largest side exceeds max_shape_thr.
        :param max_shape_thr: Maximum size in pixels of the largest image side.
        :return: Image with the landmarks drawn.
        """
        if colors is None:
            colors = self.colors_dft['lnd']

        # numpy image shape is (rows, cols[, channels]); the slice also accepts 2-D images
        height, width = image.shape[:2]
        max_shape = max(height, width)

        if max_shape <= max_shape_thr:
            return plot.draw_landmarks(image, landmarks, visible=visible, mask=mask, colors=colors)

        # Original image resize if needed: draw thicker first so the marks survive the shrink
        scale_factor = max_shape_thr / max_shape
        # cv2.resize takes the target size as (width, height)
        resize_shape = (int(width * scale_factor), int(height * scale_factor))
        image_out = plot.draw_landmarks(image, landmarks, visible=visible, mask=mask,
                                        thick_scale=1 / scale_factor, colors=colors)
        return cv2.resize(image_out, resize_shape)
+
+
if __name__ == '__main__':
    # Script entry point: parse the options and launch the interactive viewer
    args = inspect_parser()
    database = args.database
    anns_type = args.anns
    pose = args.nopose
    select_img = args.img
    data_aug = not args.clean

    if len(args.shape) != 2:
        raise ValueError('--shape requires two values: width and height. Ej: --shape 256 256')
    img_shape = tuple(args.shape)

    visualizer = DatasetInspector(database, anns_type, data_aug=data_aug, pose=pose, image_shape=img_shape)
    visualizer.show_dataset(ids_list=select_img)
+
diff --git a/SPIGA/spiga/data/visualize/inspect_heatmaps.py b/SPIGA/spiga/data/visualize/inspect_heatmaps.py
new file mode 100644
index 0000000000000000000000000000000000000000..63527ccfe2b4a57b1fce81d01da8e2ba51b0bac1
--- /dev/null
+++ b/SPIGA/spiga/data/visualize/inspect_heatmaps.py
@@ -0,0 +1,92 @@
+import cv2
+import numpy as np
+
+from spiga.data.visualize.inspect_dataset import DatasetInspector, inspect_parser
+
+
class HeatmapInspector(DatasetInspector):
    """Dataset inspector that additionally displays the 2D heatmaps and the boundary maps."""

    def __init__(self, database, anns_type, data_aug=True, image_shape=(256,256)):
        """
        :param database: Database name.
        :param anns_type: Annotation split.
        :param data_aug: If False the base class empties the augmentation list first.
        :param image_shape: Shape of the cropped image.
        """
        super().__init__(database, anns_type, data_aug=data_aug, pose=False, image_shape=image_shape)

        # Request the extra feature maps and rebuild the dataset with them
        cfg = self.data_config
        cfg.aug_names.append('heatmaps2D')
        cfg.heatmap2D_norm = False
        cfg.aug_names.append('boundaries')
        cfg.shuffle = False
        self.reload_dataset()

    def show_dataset(self, ids_list=None):
        """
        Show, for every sample, the landmark crop, the merged boundary map and
        one heatmap per landmark. Q quits, N jumps to the next sample and any
        other key moves to the next landmark.
        :param ids_list: Dataset indices to show. None shows every sample.
        """
        ids = ids_list if ids_list is not None else self.get_idx(shuffle=self.data_config.shuffle)
        quit_key, next_key = ord('q'), ord('n')

        for img_id in ids:
            sample = self.dataset[img_id]
            crop_imgs, _ = self.plot_features(sample)

            # Landmark crop
            cv2.imshow('crop', crop_imgs['lnd'])

            # All the landmark heatmaps merged in a single image
            merged_heats = self._plot_heatmaps2D(sample)

            # Boundaries merged in a single map
            cv2.imshow('boundary', np.max(sample['boundary'], axis=0))

            for lnd_idx in range(self.data_config.database.num_landmarks):
                # Merged view on the left, heatmap of the current landmark on the right
                single_heat = self._plot_heatmaps2D(sample, lnd_idx)
                maps = cv2.hconcat([merged_heats['heatmaps2D'], single_heat['heatmaps2D']])
                cv2.imshow('heatmaps', maps)

                key = cv2.waitKey()
                if key in (quit_key, next_key):
                    break

            if key == quit_key:
                break

    def _plot_heatmaps2D(self, data_dict, heatmap_id=None):
        """
        :param data_dict: Sample dictionary holding 'image' and 'heatmap2D'.
        :param heatmap_id: Index of the heatmap to draw. None merges all of them
                           with an element-wise maximum.
        :return: Dictionary with the blended image under the key 'heatmaps2D'.
        """
        all_maps = data_dict['heatmap2D']
        selected = np.max(all_maps, axis=0) if heatmap_id is None else all_maps[heatmap_id]
        return {'heatmaps2D': self._merge_imgmap(data_dict['image'], selected)}

    def _merge_imgmap(self, image, maps):
        """Blend the image (70%) with the JET color-mapped map (30%)."""
        colored = cv2.applyColorMap(np.uint8(255 * maps), cv2.COLORMAP_JET)
        return cv2.addWeighted(image, 0.7, colored, 0.3, 0)
+
+
if __name__ == '__main__':
    # Script entry point: parse the options and launch the heatmap viewer
    args = inspect_parser()
    database = args.database
    anns_type = args.anns
    select_img = args.img
    data_aug = not args.clean

    if len(args.shape) != 2:
        raise ValueError('--shape requires two values: width and height. Ej: --shape 256 256')
    img_shape = tuple(args.shape)

    visualizer = HeatmapInspector(database, anns_type, data_aug, image_shape=img_shape)
    visualizer.show_dataset(ids_list=select_img)
diff --git a/SPIGA/spiga/data/visualize/plotting.py b/SPIGA/spiga/data/visualize/plotting.py
new file mode 100644
index 0000000000000000000000000000000000000000..76661f33cf61807a957def360d07277bbd7db6e3
--- /dev/null
+++ b/SPIGA/spiga/data/visualize/plotting.py
@@ -0,0 +1,97 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import cv2
+
+import spiga.data.loaders.augmentors.utils as dlu
+
# Drawing colors as 3-channel tuples. The values match their names when read in
# blue-green-red channel order, the order used by the OpenCV drawing calls below.
BLUE = (255, 0, 0)
GREEN = (0, 255, 0)
RED = (0, 0, 255)
PURPLE = (128, 0, 128)
+
+
def draw_landmarks(image, landmarks, visible=None, mask=None, thick_scale=1, colors=(GREEN, RED)):
    """
    Draw the landmarks as filled circles on a copy of the image.
    :param image: Image array. A channel-first array (first dimension of size 3) is transposed.
    :param landmarks: Array of landmark coordinates, one row per landmark.
    :param visible: Per-landmark visibility flags. None means all visible.
    :param mask: Per-landmark flags selecting which landmarks are drawn. None means all.
    :param thick_scale: Multiplier of the base circle radius (2 pixels).
    :param colors: (visible color, not visible color).
    :return: New image with the landmarks drawn.
    """
    radius = int(2 * thick_scale + 0.5)
    num_landmarks = len(landmarks)

    valid = np.array(np.ones(num_landmarks) if mask is None else mask, dtype=bool)
    shown = np.array(np.ones(num_landmarks) if visible is None else visible, dtype=bool)

    # Drop the masked-out landmarks, then split the rest by visibility
    points = landmarks[valid]
    shown = shown[valid]
    points_vis = points[shown]
    points_notvis = points[np.logical_not(shown)]

    if image.shape[0] == 3:
        image = image.transpose(1, 2, 0)

    canvas = image.copy()
    for subset, color in ((points_vis, colors[0]), (points_notvis, colors[1])):
        canvas = _write_circles(canvas, subset, color=color, thick=radius)
    return canvas
+
+
def _write_circles(canvas, landmarks, color=RED, thick=2):
    """
    Draw one filled circle of radius `thick` per landmark.
    :return: The canvas returned by the last cv2.circle call (the input when there are no landmarks).
    """
    for point in landmarks:
        # +0.5 before the integer cast rounds positive coordinates to the nearest pixel
        pixel = np.array(point + 0.5, dtype=int)
        center = (pixel[0], pixel[1])
        canvas = cv2.circle(canvas, center, thick, color, -1)
    return canvas
+
+
def plot_landmarks_pil(image, landmarks, visible=None, mask=None):
    """
    Show the image with its landmarks in a matplotlib window.

    Visible landmarks are plotted in green and the rest in red.
    :param image: Image array with values in [0, 255]. A channel-first array
                  (first dimension of size 3) is transposed.
    :param landmarks: Array of landmark coordinates, one (x, y) row per landmark.
    :param visible: Per-landmark visibility flags. None means all visible.
    :param mask: Per-landmark flags selecting which landmarks are drawn. None means all.
    """
    # Initialize variables if needed
    if visible is None:
        visible = np.ones(len(landmarks))
    if mask is None:
        mask = np.ones(len(landmarks))

    mask = np.array(mask, dtype=bool)
    visible = np.array(visible, dtype=bool)

    # Clean and split landmarks. The visibilities must be filtered with the
    # same mask as the landmarks (as draw_landmarks does); otherwise the boolean
    # index no longer matches the number of remaining landmarks.
    landmarks = landmarks[mask]
    visible = visible[mask]
    ldm_vis = landmarks[visible]
    ldm_notvis = landmarks[np.logical_not(visible)]

    # Plot landmarks
    if image.shape[0] == 3:
        image = image.transpose(1, 2, 0)

    plt.imshow(image / 255)
    plt.scatter(ldm_vis[:, 0], ldm_vis[:, 1], s=10, marker='.', c='g')
    plt.scatter(ldm_notvis[:, 0], ldm_notvis[:, 1], s=10, marker='.', c='r')
    plt.show()
+
+
def draw_pose(img, rot, trl, K, euler=False, size=0.5, colors=(BLUE, GREEN, RED)):
    """
    Draw the three axes of a pose on a copy of the image.
    :param img: Image where the axes are drawn.
    :param rot: Rotation matrix, or euler angles when `euler` is True.
    :param trl: Translation passed to cv2.projectPoints.
    :param K: Camera matrix passed to cv2.projectPoints.
    :param euler: Convert `rot` with dlu.euler_to_rotation_matrix first.
    :param size: Length of each axis in model units.
    :param colors: Line colors; colors[0] goes to the (0, 0, -size) axis,
                   colors[1] to (0, -size, 0) and colors[2] to (size, 0, 0).
    :return: New image with the axes drawn.
    """
    if euler:
        rot = dlu.euler_to_rotation_matrix(rot)

    canvas = img.copy()
    rot_vec, _ = cv2.Rodrigues(rot)
    # Three axis end points followed by the origin
    axes_3d = np.float32([[size, 0, 0], [0, -size, 0], [0, 0, -size], [0, 0, 0]]).reshape(-1, 3)
    axes_2d, _ = cv2.projectPoints(axes_3d, rot_vec, trl, K, (0, 0, 0, 0))
    axes_2d = axes_2d.astype(int)

    for color_idx, axis_idx in enumerate((2, 1, 0)):
        origin = tuple(axes_2d[3].ravel())
        end_point = tuple(axes_2d[axis_idx].ravel())
        canvas = cv2.line(canvas, origin, end_point, colors[color_idx], 3)

    return canvas
+
+
def enhance_heatmap(heatmap):
    """
    Stretch a heatmap to the full [0, 255] range and color it with COLORMAP_BONE.
    :param heatmap: Numeric array with the heatmap values.
    :return: Color-mapped image produced by cv2.applyColorMap.
    """
    map_aux = heatmap - heatmap.min()
    peak = map_aux.max()
    # A constant heatmap has a zero range: skip the division (it would yield
    # NaN values) and keep the all-zero map.
    if peak > 0:
        map_aux = map_aux / peak
    map_img = cv2.applyColorMap((map_aux * 255).astype(np.uint8), cv2.COLORMAP_BONE)
    return map_img
diff --git a/SPIGA/spiga/demo/__init__.py b/SPIGA/spiga/demo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/demo/analyze/__init__.py b/SPIGA/spiga/demo/analyze/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/demo/analyze/analyzer.py b/SPIGA/spiga/demo/analyze/analyzer.py
new file mode 100644
index 0000000000000000000000000000000000000000..33b7514796661ea3b2dac6294aa32880dc4bd772
--- /dev/null
+++ b/SPIGA/spiga/demo/analyze/analyzer.py
@@ -0,0 +1,53 @@
+import copy
+
+# Demo libs
+import spiga.demo.analyze.extract.processor as pr
+
+
class VideoAnalyzer:
    """
    Runs a tracker and a feature processor over the frames of a video and
    keeps the list of objects currently tracked.
    """

    def __init__(self, tracker, processor=None):
        """
        :param tracker: Object exposing process_frame, plot_features and attributes.
        :param processor: Object with the same interface used to extract features.
                          None creates a new pr.EmptyProcessor for this analyzer
                          (a default instance in the signature would be built
                          once and shared by every analyzer).
        """
        self.tracker = tracker
        self.processor = pr.EmptyProcessor() if processor is None else processor
        self.tracked_obj = []

    def process_frame(self, image):
        """
        Track the objects in the frame and process them.
        :param image: Frame to analyze; a shallow copy is handed to the tracker and processor.
        :return: List of tracked objects after processing.
        """
        image = copy.copy(image)
        self.tracked_obj = self.tracker.process_frame(image, self.tracked_obj)
        if len(self.tracked_obj) > 0:
            self.tracked_obj = self.processor.process_frame(image, self.tracked_obj)
            self.tracked_obj = self._add_attributes()
        return self.tracked_obj

    def plot_features(self, image, plotter, show_attributes):
        """
        Let every tracked object draw its features.
        :return: Image returned by the last object (the input when nothing is tracked).
        """
        for obj in self.tracked_obj:
            image = obj.plot_features(image, plotter, show_attributes)
        return image

    def get_attributes(self, names):
        """
        Collect attributes from every tracked object.
        :param names: Attribute name or list of attribute names.
        :return: For a single name, a list with one value per tracked object;
                 for a list of names, a dictionary name -> list of values.
        """
        # Check input type
        single_name = isinstance(names, str)
        if single_name:
            names = [names]

        attributes = {
            name: [obj.get_attributes(name) for obj in self.tracked_obj]
            for name in names
        }

        if single_name:
            return attributes[names[0]]
        return attributes

    def _add_attributes(self):
        """Register the processor and tracker attributes/drawers on the objects that lack them."""
        for obj in self.tracked_obj:
            # Only new objects are updated, so nothing is registered twice
            if not obj.has_processor():
                obj.attributes += self.processor.attributes
                obj.attributes += self.tracker.attributes
                obj.drawers.append(self.processor.plot_features)
                obj.drawers.append(self.tracker.plot_features)
        return self.tracked_obj
diff --git a/SPIGA/spiga/demo/analyze/extract/__init__.py b/SPIGA/spiga/demo/analyze/extract/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/demo/analyze/extract/processor.py b/SPIGA/spiga/demo/analyze/extract/processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e62fd89cce5031670c7a63c98a8f27df0dc332c
--- /dev/null
+++ b/SPIGA/spiga/demo/analyze/extract/processor.py
@@ -0,0 +1,58 @@
+
+
class Processor:
    """Interface of the feature extractors applied to the tracked objects of a frame."""

    def __init__(self):
        # Names of the object attributes provided by this processor
        self.attributes = []

    def process_frame(self, frame, tracked_obj):
        """
        Extract features from the tracked objects. Must be implemented by subclasses.
        :param frame: OpenCV image.
        :param tracked_obj: List with the objects to be processed.
        """
        raise NotImplementedError

    def plot_features(self, image, features, plotter, show_attributes):
        """
        Draw the extracted features on the image. Must be implemented by subclasses.
        :param image: OpenCV image.
        :param features: Object features detected after processing the frame.
        :param plotter: Plotter interface.
        :param show_attributes: Selected object attributes to be displayed.
        """
        raise NotImplementedError
+
+
class EmptyProcessor(Processor):
    """Processor that extracts nothing: objects and images are returned untouched."""

    def process_frame(self, frame, tracked_obj):
        """Return the tracked objects unchanged."""
        return tracked_obj

    def plot_features(self, image, features, plotter, show_attributes):
        """Return the image unchanged."""
        return image
+
+
class ProcessorsGroup(Processor):
    """Processor that chains several processors, applied in insertion order."""

    def __init__(self):
        super().__init__()
        # Processors of the group, in the order they are applied
        self.group = []

    def process_frame(self, frame, tracked_obj):
        """Feed the objects through every processor; each one receives the previous output."""
        result = tracked_obj
        for member in self.group:
            result = member.process_frame(frame, result)
        return result

    def plot_features(self, image, features, plotter, show_attributes):
        """Let every processor of the group draw on the image, one after another."""
        canvas = image
        for member in self.group:
            canvas = member.plot_features(canvas, features, plotter, show_attributes)
        return canvas

    def add_processor(self, processor):
        """Append a processor and expose its attributes as attributes of the group."""
        self.group.append(processor)
        self.attributes.extend(processor.attributes)

    def get_number_of_processors(self):
        """:return: Number of processors in the group."""
        return len(self.group)
+
diff --git a/SPIGA/spiga/demo/analyze/extract/spiga_processor.py b/SPIGA/spiga/demo/analyze/extract/spiga_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e0ef38c20f807f310cd3d7b7bcc2f20f74f7149
--- /dev/null
+++ b/SPIGA/spiga/demo/analyze/extract/spiga_processor.py
@@ -0,0 +1,60 @@
+# SPIGA library
+import spiga.inference.config as model_cfg
+from spiga.inference.framework import SPIGAFramework
+
+# Demo modules
+import spiga.demo.analyze.extract.processor as pr
+
+
class SPIGAProcessor(pr.Processor):
    """Processor that runs the SPIGA framework on the bounding box of every tracked object."""

    def __init__(self,
                 dataset='wflw',
                 features=('lnd', 'pose'),
                 gpus=None):
        """
        :param dataset: Dataset name passed to model_cfg.ModelConfig.
        :param features: Features to expose: 'lnd' (landmarks) and/or 'pose' (headpose).
        :param gpus: GPU ids passed to SPIGAFramework. None means [0]; a fresh
                     list is built per instance instead of sharing a mutable default.
        """
        super().__init__()

        if gpus is None:
            gpus = [0]

        # Configure and load processor
        self.processor_cfg = model_cfg.ModelConfig(dataset)
        self.processor = SPIGAFramework(self.processor_cfg, gpus=gpus)

        # Define attributes
        if 'lnd' in features:
            self.attributes.append('landmarks')
            self.attributes.append('landmarks_ids')
        if 'pose' in features:
            self.attributes.append('headpose')

    def process_frame(self, frame, tracked_obj):
        """
        Run the inference on every tracked object and store the results on the objects.
        :param frame: OpenCV image.
        :param tracked_obj: List of objects with a bbox given as [x1, y1, x2, y2, ...].
        :return: The same list with landmarks/headpose updated.
        """
        # The framework receives the boxes as [x, y, width, height]
        bboxes = []
        for obj in tracked_obj:
            x1, y1, x2, y2 = obj.bbox[:4]
            bboxes.append([x1, y1, x2-x1, y2-y1])
        features = self.processor.inference(frame, bboxes)

        for obj_idx, obj in enumerate(tracked_obj):
            # Landmarks output
            if 'landmarks' in self.attributes:
                obj.landmarks = features['landmarks'][obj_idx]
                obj.landmarks_ids = self.processor_cfg.dataset.ldm_ids
            # Headpose output
            if 'headpose' in self.attributes:
                obj.headpose = features['headpose'][obj_idx]

        return tracked_obj

    def plot_features(self, image, features, plotter, show_attributes):
        """
        Draw the landmarks and headpose of one object when they are available and selected.
        :param image: OpenCV image.
        :param features: Tracked object holding bbox, landmarks and headpose.
        :param plotter: Plotter interface.
        :param show_attributes: Selected object attributes to be displayed.
        :return: Image with the features drawn.
        """
        if 'landmarks' in self.attributes and 'landmarks' in show_attributes:
            x1, y1, x2, y2 = features.bbox[:4]
            # Thickness grows with the box width (reference width of 200), at least 1
            thick = int(plotter.landmarks.thickness['lnd'] * (x2-x1)/200 + 0.5)
            if thick == 0:
                thick = 1
            image = plotter.landmarks.draw_landmarks(image, features.landmarks, thick=thick)

        if 'headpose' in self.attributes and 'headpose' in show_attributes:
            image = plotter.hpose.draw_headpose(image, features.bbox[:5],
                                                features.headpose[:3], features.headpose[3:], euler=True)
        return image
diff --git a/SPIGA/spiga/demo/analyze/features/__init__.py b/SPIGA/spiga/demo/analyze/features/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/demo/analyze/features/basic.py b/SPIGA/spiga/demo/analyze/features/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7d3dbae466c5aa3f33866b82cd51e470b7111b1
--- /dev/null
+++ b/SPIGA/spiga/demo/analyze/features/basic.py
@@ -0,0 +1,40 @@
+
class ObjectAnalyzed:
    """Base class of the tracked objects: keeps the registered attribute names and drawers."""

    def __init__(self):
        # Processor addons
        self.attributes = []   # Names of the attributes registered on the object
        self.drawers = []      # Callables invoked by plot_features

    def has_processor(self):
        """:return: True once at least one attribute name has been registered."""
        return len(self.attributes) > 0

    def plot_features(self, image, plotter, show_attributes):
        """
        Call every drawer as drawer(image, self, plotter, show_attributes).
        :return: Image returned by the last drawer (the input when there are none).
        """
        canvas = image
        for drawer in self.drawers:
            canvas = drawer(canvas, self, plotter, show_attributes)
        return canvas

    def get_attributes(self, names=None):
        """
        Read registered attributes of the object.
        :param names: Attribute name, list of names, or None for every registered name.
        :return: For a single name, its value ([] when it is not registered or not
                 set on the instance); otherwise a dictionary name -> value with
                 the names that are registered and set.
        """
        # Initialization by input type
        single_name = isinstance(names, str)
        if names is None:
            names = self.attributes
        elif single_name:
            names = [names]

        found = {}
        last_value = []
        for name in names:
            if name in self.attributes and name in vars(self):
                last_value = getattr(self, name)
                found[name] = last_value

        return last_value if single_name else found
\ No newline at end of file
diff --git a/SPIGA/spiga/demo/analyze/features/face.py b/SPIGA/spiga/demo/analyze/features/face.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf15b49f2609be3f4d037994fbec50083294ca4e
--- /dev/null
+++ b/SPIGA/spiga/demo/analyze/features/face.py
@@ -0,0 +1,20 @@
+import numpy as np
+
+# Demo libs
+from spiga.demo.analyze.features.basic import ObjectAnalyzed
+
+
class Face(ObjectAnalyzed):
    """Tracked face with its bounding box, key landmarks and identifier."""

    def __init__(self):
        super().__init__()
        # Placeholder values until a tracker fills them in
        self.face_id = -1
        self.bbox = np.zeros(5)
        self.key_landmarks = np.full((5, 2), -1.0)
        self.landmarks = None
        self.num_past_states = 5
        self.past_states = []
+
+
+
+
diff --git a/SPIGA/spiga/demo/analyze/track/__init__.py b/SPIGA/spiga/demo/analyze/track/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/demo/analyze/track/get_tracker.py b/SPIGA/spiga/demo/analyze/track/get_tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8495b47938dc557e01944b784f803e2b8f3db9e
--- /dev/null
+++ b/SPIGA/spiga/demo/analyze/track/get_tracker.py
@@ -0,0 +1,13 @@
+# Demo libs
+import spiga.demo.analyze.track.retinasort.zoo as zoo_rs
+
# Tracker zoos queried, in order, by get_tracker
zoos = [zoo_rs]


def get_tracker(model_name):
    """
    Build the tracker registered under the given name.
    :param model_name: Tracker name looked up in every zoo.
    :return: Tracker returned by the first zoo that knows the name.
    :raises NotImplementedError: When no zoo provides the name.
    """
    # Lazy lookup: zoos after the first match are not queried
    candidates = (zoo.get_tracker(model_name) for zoo in zoos)
    model = next((found for found in candidates if found is not None), None)
    if model is None:
        raise NotImplementedError('Tracker name not available')
    return model
diff --git a/SPIGA/spiga/demo/analyze/track/retinasort/__init__.py b/SPIGA/spiga/demo/analyze/track/retinasort/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/demo/analyze/track/retinasort/config.py b/SPIGA/spiga/demo/analyze/track/retinasort/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0c3ca78e73af40d4e4bd5500059986702c49654
--- /dev/null
+++ b/SPIGA/spiga/demo/analyze/track/retinasort/config.py
@@ -0,0 +1,81 @@
+
+
# Tracker configurations. RetinaSortTracker forwards the 'retina' section to
# retinaface.RetinaFaceDetector (model, extra_features, cfg_postreat) and the
# 'sort' section to sort_tracker.Sort (max_age, min_hits, iou_threshold).

# Default configuration: 'mobile0.25' detector model
cfg_retinasort = {

    'retina': {
        'model_name': 'mobile0.25',
        'extra_features': ['landmarks'],
        'postreat': {
            'resize': 1.,
            'score_thr': 0.75,
            'top_k': 5000,
            'nms_thr': 0.4,
            'keep_top_k': 50}
    },

    'sort': {
        'max_age': 1,
        'min_hits': 3,
        'iou_threshold': 0.3,
    }
}

# Same values as the default configuration with the 'resnet50' detector model
cfg_retinasort_res50 = {

    'retina': {
        'model_name': 'resnet50',
        'extra_features': ['landmarks'],
        'postreat': {
            'resize': 1.,
            'score_thr': 0.75,
            'top_k': 5000,
            'nms_thr': 0.4,
            'keep_top_k': 50}
    },

    'sort': {
        'max_age': 1,
        'min_hits': 3,
        'iou_threshold': 0.3,
    }
}

# 'resnet50' variant with higher score_thr (0.95), nms_thr (0.8) and max_age (90)
cfg_retinasort_cav3d = {

    'retina': {
        'model_name': 'resnet50',
        'extra_features': ['landmarks'],
        'postreat': {
            'resize': 1.,
            'score_thr': 0.95,
            'top_k': 5000,
            'nms_thr': 0.8,
            'keep_top_k': 50}
    },

    'sort': {
        'max_age': 90,
        'min_hits': 3,
        'iou_threshold': 0.3,
    }
}

# 'resnet50' variant with the default score_thr (0.75), nms_thr 0.8 and max_age 90
cfg_retinasort_av16 = {

    'retina': {
        'model_name': 'resnet50',
        'extra_features': ['landmarks'],
        'postreat': {
            'resize': 1.,
            'score_thr': 0.75,
            'top_k': 5000,
            'nms_thr': 0.8,
            'keep_top_k': 50}
    },

    'sort': {
        'max_age': 90,
        'min_hits': 3,
        'iou_threshold': 0.3,
    }
}
diff --git a/SPIGA/spiga/demo/analyze/track/retinasort/face_tracker.py b/SPIGA/spiga/demo/analyze/track/retinasort/face_tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdb2e4e1af749098477734e6f9a0112d9dabeaec
--- /dev/null
+++ b/SPIGA/spiga/demo/analyze/track/retinasort/face_tracker.py
@@ -0,0 +1,82 @@
+import numpy as np
+
+# Third party algorithms. Implementation maintained by SPIGA authors.
+import sort_tracker
+import retinaface
+
+# My libs
+import spiga.demo.analyze.track.retinasort.config as cfg
+import spiga.demo.analyze.track.tracker as tracker
+import spiga.demo.analyze.features.face as ft_face
+
+
class RetinaSortTracker(tracker.Tracker):
    """
    Face tracker that feeds the detections of a RetinaFace detector to a SORT
    associator and keeps one Face object per identifier.
    """

    def __init__(self, config=cfg.cfg_retinasort, device='cuda'):
        """
        :param config: Dictionary with the 'retina' and 'sort' sections (see the config module).
        :param device: Device passed to the detector.
        """
        super().__init__()

        self.detector = retinaface.RetinaFaceDetector(model=config['retina']['model_name'],
                                                      device=device,
                                                      extra_features=config['retina']['extra_features'],
                                                      cfg_postreat=config['retina']['postreat'])

        self.associator = sort_tracker.Sort(max_age=config['sort']['max_age'],
                                            min_hits=config['sort']['min_hits'],
                                            iou_threshold=config['sort']['iou_threshold'])
        # Class instantiated for every new identifier
        self.obj_type = ft_face.Face
        self.attributes += ['bbox', 'face_id', 'key_landmarks']

    def process_frame(self, image, tracked_obj):
        """
        Detect the faces of the frame and associate them with the objects already tracked.
        :param image: OpenCV image.
        :param tracked_obj: Objects tracked so far. Matched objects are removed from this list.
        :return: New list with the updated objects plus one new object per unseen identifier.
        """
        # tracked_obj = []
        features = self.detector.inference(image)
        bboxes = features['bbox']
        # The detection index travels through the associator inside column 4, so
        # each tracked box can be linked back to its row in `features`.
        bboxes = self._code_bbox_idx(bboxes)
        bboxes_id = self.associator.update(bboxes)
        bboxes_id, bbox_idx = self._decode_bbox_idx(bboxes_id)
        final_tracked_obj = []
        for idx, bbox in enumerate(bboxes_id):
            founded_flag = False
            for past_obj in tracked_obj:
                # The last column of each associated row is compared with the face id
                if past_obj.face_id == bbox[-1]:
                    past_obj.bbox = bbox[:5]
                    past_obj = self._update_extra_features(
                        past_obj, features, bbox_idx[idx])
                    final_tracked_obj.append(past_obj)
                    # Removing while iterating is safe here: the loop breaks right after
                    tracked_obj.remove(past_obj)
                    founded_flag = True
                    break

            if not founded_flag:
                new_obj = self.obj_type()
                new_obj.bbox = bbox[:5]
                new_obj.face_id = bbox[5]
                new_obj = self._update_extra_features(
                    new_obj, features, bbox_idx[idx])
                final_tracked_obj.append(new_obj)

        return final_tracked_obj

    def plot_features(self, image, features, plotter, show_attributes):
        """
        Draw the bounding box and/or the face identifier of one object.
        :param image: OpenCV image.
        :param features: Tracked object holding bbox and face_id.
        :param plotter: Plotter interface.
        :param show_attributes: Selected object attributes to be displayed.
        :return: Image with the features drawn.
        """
        if 'bbox' in show_attributes:
            image = plotter.bbox.draw_bbox(image, features.bbox)
        if 'face_id' in show_attributes:
            text_id = 'Face Id: %i' % features.face_id
            image = plotter.bbox.draw_bbox_text(image, features.bbox, text_id, offset=(
                0, -10), color=plotter.basic.colors['blue'])
            image = plotter.bbox.draw_bbox_line(image, features.bbox)
        return image

    def _code_bbox_idx(self, bboxes):
        """Add (row index - 0.001) to column 4 of every box."""
        # NOTE(review): assumes column 4 holds a score in (0.001, 1] so the integer
        # part of the sum equals the row index - confirm against the detector output.
        bboxes = np.array(bboxes)
        bboxes[:, 4] += (np.arange(len(bboxes)) - 0.001)
        return bboxes

    def _decode_bbox_idx(self, bboxes):
        """
        Undo _code_bbox_idx.
        :return: (bboxes, idx) where idx is the integer part of column 4 and
                 column 4 is restored to its fractional part plus 0.001.
        """
        bboxes = np.array(bboxes)
        idx = bboxes[:, 4].astype(int)
        bboxes[:, 4] = bboxes[:, 4] % 1 + 0.001
        return bboxes, idx

    def _update_extra_features(self, obj, features, idx):
        """Copy the detector landmarks of detection `idx` into the object."""
        obj.key_landmarks = features['landmarks'][idx]
        return obj
diff --git a/SPIGA/spiga/demo/analyze/track/retinasort/zoo.py b/SPIGA/spiga/demo/analyze/track/retinasort/zoo.py
new file mode 100644
index 0000000000000000000000000000000000000000..19a95c06d4ba7587fc1fe97f3f9f4a74be120ac9
--- /dev/null
+++ b/SPIGA/spiga/demo/analyze/track/retinasort/zoo.py
@@ -0,0 +1,22 @@
+# My libs
+import spiga.demo.analyze.track.retinasort.face_tracker as tr
+import spiga.demo.analyze.track.retinasort.config as cfg_tr
+import torch
+
+
def get_tracker(model_name):
    """
    Build the RetinaSort tracker registered under the given name.
    :param model_name: 'RetinaSort', 'RetinaSort_Res50', 'RetinaSort_cav3d' or 'RetinaSort_av16'.
    :return: Tracker instance, or None when the name is unknown.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Builders are lambdas so that only the requested tracker is created
    builders = {
        # MobileNet Backbone
        'RetinaSort': lambda: tr.RetinaSortTracker(device=device),
        # ResNet50 Backbone
        'RetinaSort_Res50': lambda: tr.RetinaSortTracker(cfg_tr.cfg_retinasort_res50, device=device),
        # Config CAV3D: https://ict.fbk.eu/units/speechtek/cav3d/
        'RetinaSort_cav3d': lambda: tr.RetinaSortTracker(cfg_tr.cfg_retinasort_cav3d, device=device),
        # Config AV16: https://ict.fbk.eu/units/speechtek/cav3d/
        'RetinaSort_av16': lambda: tr.RetinaSortTracker(cfg_tr.cfg_retinasort_av16, device=device),
    }

    builder = builders.get(model_name)
    if builder is None:
        return None
    return builder()
diff --git a/SPIGA/spiga/demo/analyze/track/tracker.py b/SPIGA/spiga/demo/analyze/track/tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..0199ca826820bb92511a242c350696690500ed73
--- /dev/null
+++ b/SPIGA/spiga/demo/analyze/track/tracker.py
@@ -0,0 +1,25 @@
+
class Tracker:
    """
    Interface of the object detectors/trackers applied to a video stream.
    """
    def __init__(self):
        # Names of the object attributes provided by this tracker
        self.attributes = []

    def process_frame(self, image, tracked_obj):
        """
        Detect and track objects in the input image. Must be implemented by subclasses.
        :param image: OpenCV image.
        :param tracked_obj: List with the objects found.
        """
        raise NotImplementedError

    def plot_features(self, image, features, plotter, show_attributes):
        """
        Visualize objects detected in the input image. Must be implemented by subclasses.
        :param image: OpenCV image.
        :param features: Object features detected after processing the frame.
        :param plotter: Plotter interface.
        :param show_attributes: Selected object attributes to be displayed.
        """
        raise NotImplementedError
diff --git a/SPIGA/spiga/demo/app.py b/SPIGA/spiga/demo/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..659ec537f986a11eac66bd2ec84955e7b8f8f70e
--- /dev/null
+++ b/SPIGA/spiga/demo/app.py
@@ -0,0 +1,107 @@
+import os
+import cv2
+import pkg_resources
+
+# My libs
+import spiga.demo.analyze.track.get_tracker as tr
+import spiga.demo.analyze.extract.spiga_processor as pr_spiga
+from spiga.demo.analyze.analyzer import VideoAnalyzer
+from spiga.demo.visualize.viewer import Viewer
+
+# Paths
+video_out_path_dft = pkg_resources.resource_filename('spiga', 'demo/outputs')
+if not os.path.exists(video_out_path_dft):
+ os.makedirs(video_out_path_dft)
+
+
+def main():
+    """
+    Command line entry point of the face video demo.
+    Parses the arguments, validates them and forwards everything to video_app().
+    :raises ValueError: If --shape does not have exactly two values, or if the
+    results would be neither shown nor saved.
+    """
+    import argparse
+    pars = argparse.ArgumentParser(description='Face App')
+    pars.add_argument('-i', '--input', type=str, default='0', help='Video input')
+    pars.add_argument('-d', '--dataset', type=str, default='wflw',
+                      choices=['wflw', '300wpublic', '300wprivate', 'merlrav'],
+                      help='SPIGA pretrained weights per dataset')
+    # Choices list every name accepted by the tracker zoo (get_tracker)
+    pars.add_argument('-t', '--tracker', type=str, default='RetinaSort',
+                      choices=['RetinaSort', 'RetinaSort_Res50', 'RetinaSort_cav3d', 'RetinaSort_av16'],
+                      help='Tracker name')
+    pars.add_argument('-sh', '--show', nargs='+', type=str, default=['fps', 'face_id', 'landmarks', 'headpose'],
+                      choices=['fps', 'bbox', 'face_id', 'landmarks', 'headpose'],
+                      help='Select the attributes of the face to be displayed ')
+    pars.add_argument('-s', '--save', action='store_true', help='Save record')
+    # NOTE: store_false means args.noview is True by default (window shown) and
+    # becomes False only when the flag is given, i.e. it holds the "visualize" value.
+    pars.add_argument('-nv', '--noview', action='store_false', help='Do not visualize the window')
+    pars.add_argument('--outpath', type=str, default=video_out_path_dft, help='Video output directory')
+    pars.add_argument('--fps', type=int, default=30, help='Frames per second')
+    pars.add_argument('--shape', nargs='+', type=int, help='Visualizer shape (W,H)')
+    args = pars.parse_args()
+
+    if args.shape:
+        if len(args.shape) != 2:
+            raise ValueError('--shape requires two values: width and height. Ej: --shape 256 256')
+        else:
+            video_shape = tuple(args.shape)
+    else:
+        video_shape = None
+
+    # args.noview is False only when the user asked for no window (see note above)
+    if not args.noview and not args.save:
+        raise ValueError('No results will be saved neither shown')
+
+    video_app(args.input, spiga_dataset=args.dataset, tracker=args.tracker, fps=args.fps,
+              save=args.save, output_path=args.outpath, video_shape=video_shape, visualize=args.noview, plot=args.show)
+
+
+def video_app(input_name, spiga_dataset=None, tracker=None, fps=30, save=False,
+              output_path=video_out_path_dft, video_shape=None, visualize=True, plot=()):
+    """
+    Track and analyze the faces of a webcam stream or video file, showing and/or recording the result.
+    :param input_name: Webcam ID (integer or integer-like string) or path to a video file.
+    :param spiga_dataset: Name of the dataset whose SPIGA pretrained weights are loaded.
+    :param tracker: Tracker name, resolved through the tracker zoo (get_tracker).
+    :param fps: Frames per second of the viewer and of the recorded video.
+    :param save: Record the processed frames into output_path.
+    :param output_path: Directory where the record is stored.
+    :param video_shape: Optional (width, height) of the viewer. Dft: size reported by the capture.
+    :param visualize: Show the results in a window. Forced to True for webcams.
+    :param plot: Attributes of the face to be displayed.
+    :raises ValueError: If the input cannot be opened or the tracker name is unknown.
+    """
+
+    # Load video: an integer-like input is a webcam ID, anything else is handled as a path
+    try:
+        webcam_id = int(input_name)
+    except (TypeError, ValueError):
+        webcam_id = None
+
+    if webcam_id is not None:
+        capture = cv2.VideoCapture(webcam_id)
+        video_name = None
+        if not visualize:
+            print('WARNING: Webcam must be visualized in order to close the app')
+            visualize = True
+    else:
+        try:
+            capture = cv2.VideoCapture(input_name)
+            # File name without directory nor extension (whatever the extension length is)
+            video_name = os.path.splitext(os.path.basename(input_name))[0]
+        except Exception as err:
+            raise ValueError('Input video path %s not valid' % input_name) from err
+
+    # cv2.VideoCapture does not raise on a wrong source, it just stays closed
+    if not capture.isOpened():
+        capture.release()
+        raise ValueError('Input video path %s not valid' % input_name)
+
+    viewer = None
+    try:
+        # Initialize viewer
+        if video_shape is not None:
+            vid_w, vid_h = video_shape
+        else:
+            # The capture reports its properties as floats; the viewer works with pixel sizes
+            vid_w = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
+            vid_h = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        viewer = Viewer('face_app', width=vid_w, height=vid_h, fps=fps)
+        if save:
+            viewer.record_video(output_path, video_name)
+
+        # Initialize face tracker
+        faces_tracker = tr.get_tracker(tracker)
+        if faces_tracker is None:
+            # get_tracker returns None for names that are not in the zoo
+            raise ValueError('Tracker %s not available' % tracker)
+        faces_tracker.detector.set_input_shape(capture.get(cv2.CAP_PROP_FRAME_HEIGHT),
+                                               capture.get(cv2.CAP_PROP_FRAME_WIDTH))
+        # Initialize processors
+        processor = pr_spiga.SPIGAProcessor(dataset=spiga_dataset)
+        # Initialize Analyzer
+        faces_analyzer = VideoAnalyzer(faces_tracker, processor=processor)
+
+        if visualize:
+            viewer.start_view()
+        while capture.isOpened():
+            ret, frame = capture.read()
+            if not ret:
+                # End of the stream or read failure
+                break
+            # Process frame
+            faces_analyzer.process_frame(frame)
+            # Show results; the viewer returns True when the user asks to quit
+            key = viewer.process_image(frame, drawers=[faces_analyzer], show_attributes=plot)
+            if key:
+                break
+    finally:
+        # Free the capture, the video writer and the window even if the processing failed
+        capture.release()
+        if viewer is not None:
+            viewer.close()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/SPIGA/spiga/demo/readme.md b/SPIGA/spiga/demo/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..17368cfd467486f6979e9e67255b642956e654d3
--- /dev/null
+++ b/SPIGA/spiga/demo/readme.md
@@ -0,0 +1,62 @@
+## Face Video-Analyzer Framework
+This demo application provides a general framework for tracking, detecting and extracting features of human faces in images or videos.
+Currently, the following tools are integrated in a video-demo available at ```./spiga/demo/app.py```:
+
+* Tracking:
+ * FaceTracker : RetinaFace + SORT.
+ * Detectors:
+ * RetinaFace (bbox + 5 landmarks).
+ * Backbone Mobilenet 0.25.
+ * Backbone Resnet50.
+ * Associators:
+ * SORT: Frame by frame
+* Extractors:
+ * SPIGA architecture
+ * Features: Landmarks and headpose.
+ * Datasets: 300W Public, 300W Private, WFLW, MERLRAV.
+* Viewers:
+ * Landmarks, headpose and bbox (score + face_id).
+
+### Demo Application
+
+```
+python ./spiga/demo/app.py \
+ [--input] \ # Webcam ID or Video Path. Dft: Webcam '0'.
+ [--dataset] \ # SPIGA pretrained weights per dataset. Dft: 'wflw'.
+ [--tracker] \ # Tracker name. Dft: 'RetinaSort'.
+ [--show] \ # Select the attributes of the face to be displayed. Dft: ['fps', 'face_id', 'landmarks', 'headpose']
+ [--save] \ # Save record.
+ [--noview] \ # Do not visualize window.
+ [--outpath] \ # Recorded output directory. Dft: './spiga/demo/outputs'
+ [--fps] \ # Frames per second.
+ [--shape] \ # Visualizer shape (W,H).
+```
+
+### Code Structure
+The demo framework has been organised according to the following structure:
+
+```
+./spiga/demo/
+|   app.py                          # Video-demo
+|
+├───analyze
+│   |   analyzer.py                 # Generic video/image analyzer compositor
+│   │
+│   ├───features                    # Object/Faces classes
+│   │
+│   ├───track                       # Task heads
+│   |   |   tracker.py              # Tracker class
+│   |   |   get_tracker.py          # Get model tracker from zoo by name
+│   |   └─── retinasort             # RetinaFace + SORT tracker (tracker + zoo files)
+|   |
+|   └───extract
+|       |   processor.py            # Processor classes
+|       │   spiga_processor.py      # SPIGA wrapper
+|
+├───visualize
+|   |   viewer.py                   # Viewer manager
+|   |   plotter.py                  # Groups the available feature drawers
+|   └─── layouts                    # Landmarks, bbox, headpose drawers
+|
+└───utils                           # Video converters
+```
\ No newline at end of file
diff --git a/SPIGA/spiga/demo/utils/__init__.py b/SPIGA/spiga/demo/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/demo/utils/frames2video.py b/SPIGA/spiga/demo/utils/frames2video.py
new file mode 100644
index 0000000000000000000000000000000000000000..eaba367ff17ba2ed0ae09ecd387a98c106d4e06b
--- /dev/null
+++ b/SPIGA/spiga/demo/utils/frames2video.py
@@ -0,0 +1,62 @@
+import os
+import cv2
+
+
+def main():
+    """
+    Command line entry point of the frames to video converter.
+    :raises ValueError: If --shape does not have exactly two values.
+    """
+    import argparse
+    pars = argparse.ArgumentParser(description='Frames to video converter')
+    pars.add_argument('frames', type=str, help='Frames directory')
+    pars.add_argument('vidname', type=str, help='Output video name')
+    # The value is forwarded as the directory where the video is written
+    pars.add_argument('-o', '--outpath', type=str, default=None, help='Video output directory')
+    pars.add_argument('--fps', type=int, default=30, help='Frames per second')
+    pars.add_argument('--shape', nargs='+', type=int, help='Visualizer shape (W,H)')
+    args = pars.parse_args()
+
+    if args.shape:
+        if len(args.shape) != 2:
+            raise ValueError('--shape requires two values: width and height. Ej: --shape 256 256')
+        else:
+            video_shape = tuple(args.shape)
+    else:
+        video_shape = None
+
+    frames2video(args.frames, args.vidname, video_path=args.outpath, video_shape=video_shape, fps=args.fps)
+
+
+def frames2video(frames_path, video_name, video_path=None, video_shape=None, fps=30):
+    """
+    Join the images of a directory, in sorted file name order, into an MP4 video.
+    :param frames_path: Directory with the frames.
+    :param video_name: Output video name, without extension.
+    :param video_path: Output directory. Dft: 'vid_out' folder inside frames_path.
+    :param video_shape: Optional (width, height). Dft: size of the first readable frame.
+    :param fps: Frames per second of the output video.
+    """
+
+    image_exts = ('jpg', 'jpeg', 'png', 'tif', 'tiff', 'eps', 'bmp', 'gif')
+    frames_names = sorted(os.listdir(frames_path))
+
+    if video_path is None:
+        video_path = os.path.join(frames_path, 'vid_out')
+    os.makedirs(video_path, exist_ok=True)
+
+    video_file = os.path.join(video_path, video_name + '.mp4')
+    fourcc = cv2.VideoWriter_fourcc(*'MP4V')
+    # Without an explicit shape the writer is created lazily from the first readable frame
+    video_writer = None
+    if video_shape is not None:
+        vid_w, vid_h = video_shape
+        video_writer = cv2.VideoWriter(video_file, fourcc, fps, (vid_w, vid_h))
+
+    try:
+        for frame_name in frames_names:
+            # Case-insensitive check, so files such as 'IMG_001.JPG' are not discarded
+            if frame_name.split('.')[-1].lower() not in image_exts:
+                print('File %s format doesnt match with an image ' % frame_name)
+                continue
+
+            frame_file = os.path.join(frames_path, frame_name)
+            frame = cv2.imread(frame_file)
+            if frame is None:
+                # cv2.imread returns None instead of raising when it cannot decode the file
+                print('File %s could not be read as an image ' % frame_name)
+                continue
+
+            if video_writer is None:
+                vid_h, vid_w = frame.shape[:2]
+                video_writer = cv2.VideoWriter(video_file, fourcc, fps, (vid_w, vid_h))
+
+            if frame.shape[:2] != (vid_h, vid_w):
+                frame = cv2.resize(frame, (vid_w, vid_h))
+            video_writer.write(frame)
+    finally:
+        # The writer does not exist when no frame could be read
+        if video_writer is not None:
+            video_writer.release()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/SPIGA/spiga/demo/visualize/__init__.py b/SPIGA/spiga/demo/visualize/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/demo/visualize/layouts/__init__.py b/SPIGA/spiga/demo/visualize/layouts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/demo/visualize/layouts/plot_basics.py b/SPIGA/spiga/demo/visualize/layouts/plot_basics.py
new file mode 100644
index 0000000000000000000000000000000000000000..051f7c074e73e770067ae44be8ca4b3b35d91434
--- /dev/null
+++ b/SPIGA/spiga/demo/visualize/layouts/plot_basics.py
@@ -0,0 +1,40 @@
+import numpy as np
+import cv2
+
+
+class BasicLayout:
+    """
+    Base drawer: shared color palette, default thickness table and circle drawing.
+    Each instance owns a private copy of the thickness table.
+    """
+
+    # Variables
+    # Color tuples handed to OpenCV as-is (note 'red' is (0, 0, 255))
+    colors = {'green': (0, 255, 0),
+              'red': (0, 0, 255),
+              'blue': (255, 0, 0),
+              'purple': (128, 0, 128),
+              'white': (255, 255, 255),
+              'black': (0, 0, 0)}
+
+    # Class-level defaults; subclasses register their own keys here
+    thickness_dft = {'circle': 2}
+
+    def __init__(self):
+        # Copy, not alias: instance updates must not leak into the shared defaults
+        self.thickness = dict(self.thickness_dft)
+
+    def draw_circles(self, canvas, coord_list, color=colors['red'], thick=None):
+        """
+        Draw a filled circle on every coordinate.
+        :param canvas: OpenCV image to draw on.
+        :param coord_list: Iterable of (x, y) numeric arrays.
+        :param color: Circle color.
+        :param thick: Circle radius in pixels. Dft: self.thickness['circle'].
+        :return: The canvas with the circles drawn.
+        """
+        if thick is None:
+            thick = self.thickness['circle']
+
+        for xy in coord_list:
+            # Round to the nearest pixel
+            xy = np.array(xy + 0.5, dtype=int)
+            canvas = cv2.circle(canvas, (xy[0], xy[1]), thick, color, -1)
+        return canvas
+
+    def update_thickness(self, thick_dict):
+        """Overwrite the thickness of this instance with the given {name: value} pairs."""
+        self.thickness.update(thick_dict)
+
+    def reset_thickness(self):
+        """Restore the thickness of this instance to the class defaults."""
+        self.thickness = dict(self.thickness_dft)
+
+    def update_thick_byratio(self, ratio_dict):
+        """Set each given thickness to its class default scaled by a ratio (rounded)."""
+        for key, ratio in ratio_dict.items():
+            self.thickness[key] = int(self.thickness_dft[key] * ratio + 0.5)
+
+
diff --git a/SPIGA/spiga/demo/visualize/layouts/plot_bbox.py b/SPIGA/spiga/demo/visualize/layouts/plot_bbox.py
new file mode 100644
index 0000000000000000000000000000000000000000..96911b41fe4f678ae41d5a7eac1484700952b99e
--- /dev/null
+++ b/SPIGA/spiga/demo/visualize/layouts/plot_bbox.py
@@ -0,0 +1,53 @@
+import cv2
+
+# Demo libs
+from spiga.demo.visualize.layouts.plot_basics import BasicLayout
+
+
+class BboxLayout(BasicLayout):
+    """
+    Bounding box drawer. Boxes are indexable as (x1, y1, x2, y2, score).
+    """
+
+    # Register the default bbox thickness in the shared defaults table
+    BasicLayout.thickness_dft['bbox'] = 2
+
+    def __init__(self):
+        super().__init__()
+
+    def draw_bbox(self, canvas, bbox, score_thr=0, show_score=True, thick=None, color=BasicLayout.colors['blue']):
+        """
+        Draw the full rectangle of a bounding box when its score is over the threshold.
+        :param canvas: OpenCV image to draw on.
+        :param bbox: Bounding box (x1, y1, x2, y2, score).
+        :param score_thr: Boxes with a score less than or equal to it are skipped.
+        :param show_score: Write the score next to the top-left corner.
+        :param thick: Line thickness. Dft: self.thickness['bbox'].
+        :param color: Line and text color.
+        :return: The canvas.
+        """
+
+        if thick is None:
+            thick = self.thickness['bbox']
+
+        if bbox[4] > score_thr:
+            text = "{:.4f}".format(bbox[4])
+            b = list(map(int, bbox))
+            cv2.rectangle(canvas, (b[0], b[1]), (b[2], b[3]), color, thick)
+            if show_score:
+                self.draw_bbox_text(canvas, b, text, offset=(0, 12), color=color)
+        return canvas
+
+    def draw_bbox_line(self, canvas, bbox, score_thr=0, show_score=True, thick=None, color=BasicLayout.colors['blue']):
+        """
+        Draw only a corner mark (15 px down, 100 px right) at the top-left of the bounding box.
+        Parameters as in draw_bbox.
+        :return: The canvas.
+        """
+
+        if thick is None:
+            thick = self.thickness['bbox']
+
+        if bbox[4] > score_thr:
+            text = "{:.4f}".format(bbox[4])
+            b = list(map(int, bbox))
+            cv2.line(canvas, (b[0], b[1]), (b[0], b[1] + 15), color, thick)
+            cv2.line(canvas, (b[0], b[1]), (b[0] + 100, b[1]), color, thick)
+            if show_score:
+                self.draw_bbox_text(canvas, b, text, offset=(0, 12), color=color)
+        return canvas
+
+    def draw_bbox_text(self, canvas, bbox, text, offset=(0, 0), color=BasicLayout.colors['white']):
+        """
+        Write a text at the top-left corner of the bounding box, shifted by offset (dx, dy).
+        :return: The canvas.
+        """
+        b = list(map(int, bbox))
+        cx = b[0] + offset[0]
+        cy = b[1] + offset[1]
+        cv2.putText(canvas, text, (cx, cy), cv2.FONT_HERSHEY_DUPLEX, 0.5, color)
+        return canvas
+
+    def draw_bboxes(self, canvas, dets, score_thr=0, show_score=True, thick=None, colors=(BasicLayout.colors['blue'],)):
+        """
+        Draw a list of bounding boxes, cycling through the given colors.
+        :param dets: Iterable of bounding boxes (x1, y1, x2, y2, score).
+        :param colors: Sequence of colors; box i uses colors[i % len(colors)].
+        :return: The canvas.
+        """
+        # The default is a one-element tuple of colors: without the trailing comma it would be
+        # the color tuple itself and each box would receive a single channel value.
+        num_colors = len(colors)
+        for idx, bbox in enumerate(dets):
+            color = colors[idx % num_colors]
+            canvas = self.draw_bbox(canvas, bbox, score_thr=score_thr, show_score=show_score, thick=thick, color=color)
+        return canvas
diff --git a/SPIGA/spiga/demo/visualize/layouts/plot_headpose.py b/SPIGA/spiga/demo/visualize/layouts/plot_headpose.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed0a39858dc7a813c103dd51829feac79e53560f
--- /dev/null
+++ b/SPIGA/spiga/demo/visualize/layouts/plot_headpose.py
@@ -0,0 +1,70 @@
+import numpy as np
+import cv2
+
+# Demo libs
+from spiga.demo.visualize.layouts.plot_basics import BasicLayout
+
+
+class HeadposeLayout(BasicLayout):
+    """
+    Head pose drawer: projects three 3D axes over the image with a camera built from the bbox.
+    """
+
+    # Register the default axis thickness in the shared defaults table
+    BasicLayout.thickness_dft['hpose'] = 2
+
+    def __init__(self):
+        super().__init__()
+        self.hpose_axe_length = 2
+        self.focal_ratio = 1
+
+    def draw_headpose(self, canvas, bbox, rot, trl, euler=False, len_axe=None, thick=None,
+                      colors=(BasicLayout.colors['blue'], BasicLayout.colors['green'], BasicLayout.colors['red'])):
+        """
+        Draw the head pose axes.
+        :param canvas: OpenCV image to draw on.
+        :param bbox: Face bounding box; its first four values (x1, y1, x2, y2) define the camera.
+        :param rot: Rotation accepted by cv2.Rodrigues, or euler angles in degrees if euler=True.
+        :param trl: Translation vector.
+        :param euler: Interpret rot as euler angles.
+        :param len_axe: Length of the axes. Dft: self.hpose_axe_length.
+        :param thick: Line thickness. Dft: self.thickness['hpose'].
+        :param colors: Three colors; colors[0] is used for the -Z axis, colors[1] for -Y, colors[2] for +X.
+        :return: The canvas.
+        """
+
+        translation = np.float32(trl)
+        rotation = np.float32(rot)
+        cam_matrix = self._camera_matrix(bbox)
+
+        # Fall back to the instance defaults
+        len_axe = self.hpose_axe_length if len_axe is None else len_axe
+        thick = self.thickness['hpose'] if thick is None else thick
+
+        if euler:
+            rotation = self._euler_to_rotation_matrix(rotation)
+
+        rot_vec, _ = cv2.Rodrigues(rotation)
+        # Axis tips (rows 0-2) and origin (row 3)
+        axes_3d = np.float32([[len_axe, 0, 0], [0, -len_axe, 0], [0, 0, -len_axe], [0, 0, 0]]).reshape(-1, 3)
+        axes_2d, _ = cv2.projectPoints(axes_3d, rot_vec, translation, cam_matrix, (0, 0, 0, 0))
+
+        # Same drawing order as always: tip 2 first, tip 0 last
+        for tip_idx, color_idx in ((2, 0), (1, 1), (0, 2)):
+            origin = tuple(axes_2d[3].ravel().astype(int))
+            tip = tuple(axes_2d[tip_idx].ravel().astype(int))
+            canvas = cv2.line(canvas, origin, tip, colors[color_idx], thick)
+        return canvas
+
+    @staticmethod
+    def _euler_to_rotation_matrix(headpose):
+        """Convert the three headpose angles (degrees) into a 3x3 rotation matrix."""
+        # http://euclideanspace.com/maths/geometry/rotations/conversions/eulerToMatrix/index.htm
+        # Change coordinates system
+        angles = np.array([-(headpose[0] - 90), -headpose[1], -(headpose[2] + 90)])
+        # Convert to radians
+        rad = angles * (np.pi / 180.0)
+        cy, sy = np.cos(rad[0]), np.sin(rad[0])
+        cp, sp = np.cos(rad[1]), np.sin(rad[1])
+        cr, sr = np.cos(rad[2]), np.sin(rad[2])
+        yaw = np.array([[cy, 0.0, sy], [0.0, 1.0, 0.0], [-sy, 0.0, cy]])
+        pitch = np.array([[cp, -sp, 0.0], [sp, cp, 0.0], [0.0, 0.0, 1.0]])
+        roll = np.array([[1.0, 0.0, 0.0], [0.0, cr, -sr], [0.0, sr, cr]])
+        yaw_pitch = np.matmul(yaw, pitch)
+        return np.matmul(yaw_pitch, roll)
+
+    def _camera_matrix(self, bbox):
+        """Build a 3x3 camera matrix centered on the bbox, with focal lengths scaled from its size."""
+        x1, y1, x2, y2 = bbox[:4]
+        width = x2-x1
+        height = y2-y1
+        center_x = x1 + (width * 0.5)
+        center_y = y1 + (height * 0.5)
+
+        return np.array([[width * self.focal_ratio, 0, center_x],
+                         [0, height * self.focal_ratio, center_y],
+                         [0, 0, 1]], dtype=np.float32)
\ No newline at end of file
diff --git a/SPIGA/spiga/demo/visualize/layouts/plot_landmarks.py b/SPIGA/spiga/demo/visualize/layouts/plot_landmarks.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d91432a024ba4071166b0d43fb9f2f020640641
--- /dev/null
+++ b/SPIGA/spiga/demo/visualize/layouts/plot_landmarks.py
@@ -0,0 +1,53 @@
+import numpy as np
+
+# Demo libs
+from spiga.demo.visualize.layouts.plot_basics import BasicLayout
+
+
+class LandmarkLayout(BasicLayout):
+    """
+    Landmark drawer: one circle per landmark, colored by its visibility flag.
+    """
+
+    # Register the default landmark thickness in the shared defaults table
+    BasicLayout.thickness_dft['lnd'] = 3
+
+    def __init__(self):
+        super().__init__()
+
+    def draw_landmarks(self, image, landmarks, visible=None, mask=None,
+                       thick=None, colors=(BasicLayout.colors['green'], BasicLayout.colors['red'])):
+        """
+        Draw the landmarks over the image.
+        :param image: Image to draw on. A channel-first image (shape[0] == 3) is transposed first.
+        :param landmarks: Landmark coordinates (list, tuple or array).
+        :param visible: Per-landmark visibility flags. Dft: all visible.
+        :param mask: Per-landmark flags, landmarks with a false flag are not drawn. Dft: all drawn.
+        :param thick: Circle radius. Dft: self.thickness['lnd'].
+        :param colors: (visible color, not visible color).
+        :return: The image with the landmarks drawn.
+        """
+
+        # Fill the optional inputs
+        num_lnd = len(landmarks)
+        visible = np.ones(num_lnd) if visible is None else visible
+        mask = np.ones(num_lnd) if mask is None else mask
+        thick = self.thickness['lnd'] if thick is None else thick
+
+        # Python sequences to arrays
+        sequences = (list, tuple)
+        landmarks = np.array(landmarks) if isinstance(landmarks, sequences) else landmarks
+        visible = np.array(visible) if isinstance(visible, sequences) else visible
+        mask = np.array(mask) if isinstance(mask, sequences) else mask
+
+        # Clean and split landmarks
+        lnd_shown, lnd_hidden = self._split_lnd_by_vis(landmarks, visible, mask)
+
+        # PIL images to OpenCV
+        if image.shape[0] == 3:
+            image = image.transpose(1, 2, 0)
+
+        # Plot landmarks: visible ones first, then the not visible ones
+        color_vis, color_notvis = colors[0], colors[1]
+        canvas = self.draw_circles(image, lnd_shown, color=color_vis, thick=thick)
+        return self.draw_circles(canvas, lnd_hidden, color=color_notvis, thick=thick)
+
+    @staticmethod
+    def _split_lnd_by_vis(landmarks, visible, mask):
+        """Drop the masked-out landmarks and split the rest into (visible, not visible)."""
+        keep = np.array(mask, dtype=bool)
+        is_visible = np.array(visible, dtype=bool)[keep]
+        kept_lnd = landmarks[keep]
+        return kept_lnd[is_visible], kept_lnd[np.logical_not(is_visible)]
diff --git a/SPIGA/spiga/demo/visualize/plotter.py b/SPIGA/spiga/demo/visualize/plotter.py
new file mode 100644
index 0000000000000000000000000000000000000000..e960ba8f4f1804495added79fd5cb3a2131d2cbf
--- /dev/null
+++ b/SPIGA/spiga/demo/visualize/plotter.py
@@ -0,0 +1,14 @@
+# Demo libs
+import spiga.demo.visualize.layouts.plot_basics as pl_basic
+import spiga.demo.visualize.layouts.plot_bbox as pl_bbox
+import spiga.demo.visualize.layouts.plot_landmarks as pl_lnd
+import spiga.demo.visualize.layouts.plot_headpose as pl_hpose
+
+
+class Plotter:
+ # Holds one instance of every layout drawer so callers reach them all through a single object.
+
+ def __init__(self):
+ # Each drawer is built with its no-argument constructor
+ self.basic = pl_basic.BasicLayout()
+ self.bbox = pl_bbox.BboxLayout()
+ self.landmarks = pl_lnd.LandmarkLayout()
+ self.hpose = pl_hpose.HeadposeLayout()
diff --git a/SPIGA/spiga/demo/visualize/viewer.py b/SPIGA/spiga/demo/visualize/viewer.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa96e70a2d08ff13946ecedcd50342e3ae2cfca5
--- /dev/null
+++ b/SPIGA/spiga/demo/visualize/viewer.py
@@ -0,0 +1,163 @@
+import os
+import cv2
+import copy
+import time
+import numpy as np
+
+# Demo libs
+import spiga.demo.visualize.plotter as plt
+
+
+class Viewer:
+    """
+    Canvas manager: draws the features over each frame, shows it in an OpenCV window
+    and/or records it into a video file, pacing the display to the requested FPS.
+    """
+
+    def __init__(self, window_title, width=None, height=None, fps=30):
+        """
+        Initialization of the viewer canvas using width and height in pixels
+        :param window_title: The string with the window title to display.
+        :param width: The given width in pixels of the window canvas.
+        :param height: The given height in pixels of the window canvas.
+        :param fps: Frames per second
+        """
+        # Visualizer parameters
+        self.canvas = None
+        self.width = width
+        self.height = height
+        self.window_title = window_title
+        self.visualize = False
+
+        # Time variables
+        self.fps = fps
+        self.fps_inference = 0
+        self.fps_mean = 0
+        # Circular buffer with the inference FPS of the last self.fps frames
+        self.fps_lifo = np.zeros(self.fps)
+        self.timer = time.time()
+        self.frame_cnt = -1
+
+        # Video/Image writer
+        self.write = False
+        self.video_name = window_title  # Initial name
+        self.video_path = None
+        self.video_writer = None
+
+        # Plots
+        self.plotter = plt.Plotter()
+        self.fps_draw_params = {'text_size': 0.75,
+                                'text_thick': 2,
+                                'coord': (10, 50),
+                                'font': cv2.FONT_HERSHEY_SIMPLEX,
+                                'color': (255, 255, 255)}
+
+    def start_view(self):
+        """Open (or reopen) the display window and enable the visualization."""
+        self._kill_window()
+        cv2.namedWindow(self.window_title)
+        self.visualize = True
+
+    def record_video(self, video_path, video_name=None):
+        """
+        Enable the recording of every processed frame into '<video_path>/<video_name>.mp4'.
+        The viewer width and height must be defined, as they set the video size.
+        :param video_path: Output directory (created if missing).
+        :param video_name: Output file name without extension. Dft: current video name.
+        """
+        self.write = True
+        if video_name is not None:
+            self.video_name = video_name
+        self.video_path = video_path
+        os.makedirs(video_path, exist_ok=True)
+
+        file_name = os.path.join(self.video_path, self.video_name + '.mp4')
+        self.video_writer = cv2.VideoWriter(file_name, cv2.VideoWriter_fourcc(*'MP4V'),
+                                            self.fps, (int(self.width), int(self.height)))
+
+    def save_canvas(self, file_path=None):
+        """
+        Save the current canvas as '<file_path>/<video_name>_<frame_cnt>.jpg'.
+        :param file_path: Output directory. Dft: the recording directory.
+        :raises ValueError: If no directory is given and no recording directory is set.
+        """
+        if file_path is None:
+            if self.video_path is None:
+                raise ValueError('Path not defined neither video_path is available')
+            else:
+                file_path = self.video_path
+
+        # Relative file name: an absolute second component would make os.path.join drop file_path
+        file_name = os.path.join(file_path, '%s_%i.jpg' % (self.video_name, self.frame_cnt))
+        cv2.imwrite(file_name, self.canvas)
+
+    def reset_params(self, width, height, window_title, fps=30):
+        """Close the window and update canvas size, title and FPS of the viewer."""
+        self.width = width
+        self.height = height
+        self._kill_window()
+        # The video name follows the title only while it has not been customized
+        if self.video_name == self.window_title:
+            self.video_name = window_title
+        self.window_title = window_title
+        self.fps = fps
+
+    def close(self):
+        """Release the video writer (if recording) and destroy the window."""
+        if self.write and self.video_writer is not None:
+            self.video_writer.release()
+        self._kill_window()
+
+    def process_image(self, input_img, drawers=(), show_attributes=('fps',)):
+        """
+        Draw, show and record one frame.
+        :param input_img: OpenCV image. It is copied, the input is not modified.
+        :param drawers: Objects with a plot_features(image, plotter, show_attributes) method.
+        :param show_attributes: Attributes to be displayed ('fps' is handled by the viewer).
+        :return: True if the user pressed 'q' on the window, False otherwise.
+        """
+
+        # Variables
+        image = copy.copy(input_img)
+
+        # Convert gray scale image to color if needed (with or without channel axis)
+        if image.ndim == 2 or image.shape[2] == 1:
+            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
+        img_h, img_w = image.shape[:2]
+
+        # Draw features on image
+        image = self._draw_over_canvas(image, drawers, show_attributes)
+
+        # An undefined canvas size adopts the size of the image
+        if self.width is None:
+            self.width = img_w
+        if self.height is None:
+            self.height = img_h
+
+        # Resize image if needed to canvas shape (cv2.resize requires integer sizes)
+        if img_w != self.width or img_h != self.height:
+            image = cv2.resize(image, (int(self.width), int(self.height)))
+
+        # Update canvas
+        self.canvas = image
+        # Visualize FPS
+        if 'fps' in show_attributes:
+            self._plot_fps()
+
+        # Write the resulting frame
+        if self.write:
+            self.video_writer.write(self.canvas)
+
+        # Timing loop variables
+        loop_time = self._update_timers()
+        break_flag = False
+        # Visualization
+        if self.visualize:
+            cv2.imshow(self.window_title, self.canvas)
+            # Wait the remaining time of the frame period (at least 1 ms, 0 would block forever)
+            sleep_time = int(1000 * (1 / self.fps - loop_time))
+            if sleep_time <= 0:
+                sleep_time = 1
+            if cv2.waitKey(sleep_time) & 0xFF == ord('q'):
+                break_flag = True
+
+        self.timer = time.time()
+        return break_flag
+
+    def _plot_fps(self):
+        """Write the mean inference FPS over the canvas."""
+        # Plot algorithm time
+        params = self.fps_draw_params
+        cv2.putText(self.canvas, ('FPS: %.2f' % self.fps_mean), params['coord'], params['font'], params['text_size'],
+                    params['color'], params['text_thick'], cv2.LINE_AA)
+
+    def _draw_over_canvas(self, image, drawers, show_attributes):
+        """Let every drawer plot its features over the image, in order."""
+        for drawer in drawers:
+            image = drawer.plot_features(image, self.plotter, show_attributes)
+        return image
+
+    def _kill_window(self):
+        """Disable the visualization and destroy the window if it exists."""
+        self.visualize = False
+        try:
+            cv2.destroyWindow(self.window_title)
+        except cv2.error:
+            # The window was never created or is already closed
+            pass
+
+    def _update_timers(self):
+        """
+        Update the frame counter and the FPS statistics.
+        :return: Seconds elapsed since the previous frame was finished.
+        """
+        self.frame_cnt += 1
+        loop_time = time.time() - self.timer
+        # A coarse clock may report a zero interval: clamp it to avoid a ZeroDivisionError
+        self.fps_inference = 1 / max(loop_time, 1e-6)
+        lifo_idx = self.frame_cnt % self.fps
+        self.fps_lifo[lifo_idx] = self.fps_inference
+        # The displayed mean is refreshed once per buffer cycle
+        if lifo_idx == 0:
+            self.fps_mean = np.mean(self.fps_lifo)
+        return loop_time
+
+
+
+
diff --git a/SPIGA/spiga/eval/__init__.py b/SPIGA/spiga/eval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/eval/benchmark/__init__.py b/SPIGA/spiga/eval/benchmark/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/eval/benchmark/evaluator.py b/SPIGA/spiga/eval/benchmark/evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..85ed64a7cc6cd3be4b4ed3fad5f1aa9ce1613c1d
--- /dev/null
+++ b/SPIGA/spiga/eval/benchmark/evaluator.py
@@ -0,0 +1,108 @@
+import json
+import pkg_resources
+from collections import OrderedDict
+
+# Paths
+data_path = pkg_resources.resource_filename('spiga', 'data/annotations')
+
+def main():
+    """Command line entry point: compute the metrics of one or several prediction files."""
+
+    import argparse
+    parser = argparse.ArgumentParser(description='Benchmark alignments evaluator')
+    parser.add_argument('pred_file', nargs='+', type=str,
+                        help='Absolute path to the prediction json file (Multi file)')
+    parser.add_argument('--eval', nargs='+', type=str, default=['lnd'],
+                        choices=['lnd', 'pose'], help='Evaluation modes')
+    parser.add_argument('-s', '--save', action='store_true', help='Save results')
+    options = parser.parse_args()
+
+    # Every prediction file is evaluated on its own
+    for prediction_path in options.pred_file:
+        get_evaluator(prediction_path, options.eval, options.save).metrics()
+
+
+class Evaluator:
+    """
+    Runs a set of metric heads over a prediction file and its matching annotations.
+    The prediction file name must end in '<database>_<data_type>.<ext>': both fields are
+    read from the name to locate the annotations file.
+    """
+
+    def __init__(self, data_file, evals=(), save=True, process_err=True):
+        """
+        :param data_file: Path to the predictions json file.
+        :param evals: Metric objects exposing name, compute_error, get_pimg_err and metrics.
+        :param save: Write the metrics to a text file next to data_file.
+        :param process_err: Compute the errors right after loading the files.
+        """
+        import os
+
+        # Inputs
+        self.data_file = data_file
+        self.evals = evals
+        self.save = save
+
+        # Paths
+        data_name = os.path.basename(data_file)
+        # Directory with trailing separator; empty for a bare file name (current directory)
+        self.data_dir = os.path.join(os.path.dirname(data_file), '')
+
+        # Information from name
+        data_name = data_name.split('.')[0]
+        data_name = data_name.split('_')
+        self.data_type = data_name[-1]
+        self.database = data_name[-2]
+
+        # Load predictions and annotations
+        anns_file = os.path.join(data_path, self.database, self.data_type + '.json')
+        self.anns = self.load_files(anns_file)
+        self.pred = self.load_files(data_file)
+
+        # Compute errors
+        self.error = OrderedDict()
+        self.error_pimg = OrderedDict()
+        self.metrics_log = OrderedDict()
+        if process_err:
+            self.compute_error(self.anns, self.pred)
+
+    def compute_error(self, anns, pred, select_ids=None):
+        """
+        Compute the error of every metric head.
+        :return: Dict with the error of each head, indexed by head name.
+        """
+        database_ref = [self.database, self.data_type]
+        for metric in self.evals:
+            self.error[metric.name] = metric.compute_error(anns, pred, database_ref, select_ids)
+            self.error_pimg = metric.get_pimg_err(self.error_pimg)
+        return self.error
+
+    def metrics(self):
+        """
+        Compute the metrics of every head and, if save is enabled, write them to
+        'metrics_<database>_<data_type>.txt' in the directory of the prediction file.
+        :return: Dict with the metrics of each head, indexed by head name.
+        """
+        import os
+
+        for metric in self.evals:
+            self.metrics_log[metric.name] = metric.metrics()
+
+        if self.save:
+            # os.path.join keeps the file relative when data_dir is empty
+            file_name = os.path.join(self.data_dir, 'metrics_%s_%s.txt' % (self.database, self.data_type))
+            with open(file_name, 'w') as file:
+                file.write(str(self))
+
+        return self.metrics_log
+
+    def load_files(self, input_file):
+        """Load and return the content of a json file."""
+        with open(input_file) as jsonfile:
+            data = json.load(jsonfile)
+        return data
+
+    def _dict2text(self, name, dictionary, num_tab=1):
+        """Render a (nested) dictionary as an indented 'name { key: value ... }' text block."""
+        prev_tabs = '\t'*num_tab
+        text = '%s {\n' % name
+        for k, v in dictionary.items():
+            # OrderedDict is a dict subclass, a single check covers both
+            if isinstance(v, dict):
+                text += '{}{}'.format(prev_tabs, self._dict2text(k, v, num_tab=num_tab+1))
+            else:
+                text += '{}{}: {}\n'.format(prev_tabs, k, v)
+        text += (prev_tabs + '}\n')
+        return text
+
+    def __str__(self):
+        state_dict = self.metrics_log
+        text = self._dict2text('Metrics', state_dict)
+        return text
+
+
+def get_evaluator(pred_file, evaluate=('lnd', 'pose'), save=False, process_err=True):
+    """
+    Build an Evaluator with the requested metric heads.
+    :param pred_file: Path to the predictions json file.
+    :param evaluate: Evaluation modes: 'lnd' (landmarks) and/or 'pose' (headpose).
+    :param save: Forwarded to Evaluator.
+    :param process_err: Forwarded to Evaluator.
+    :return: The Evaluator instance.
+    """
+    metric_heads = []
+    # Each metrics module is imported only when its mode is requested
+    if "lnd" in evaluate:
+        from spiga.eval.benchmark.metrics.landmarks import MetricsLandmarks
+        metric_heads.append(MetricsLandmarks())
+    if "pose" in evaluate:
+        from spiga.eval.benchmark.metrics.pose import MetricsHeadpose
+        metric_heads.append(MetricsHeadpose())
+
+    return Evaluator(pred_file, evals=metric_heads, save=save, process_err=process_err)
+
+
+if __name__ == '__main__':
+ main()
+
diff --git a/SPIGA/spiga/eval/benchmark/metrics/__init__.py b/SPIGA/spiga/eval/benchmark/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/eval/benchmark/metrics/landmarks.py b/SPIGA/spiga/eval/benchmark/metrics/landmarks.py
new file mode 100644
index 0000000000000000000000000000000000000000..6201394d4108ff26f20b018aa781de5df846f671
--- /dev/null
+++ b/SPIGA/spiga/eval/benchmark/metrics/landmarks.py
@@ -0,0 +1,236 @@
+import os
+import numpy as np
+import json
+from collections import OrderedDict
+from scipy.integrate import simps
+
+from spiga.data.loaders.dl_config import db_anns_path
+from spiga.eval.benchmark.metrics.metrics import Metrics
+
+
+class MetricsLandmarks(Metrics):
+    """
+    Landmark metrics: normalized error per image and per landmark, plus NME,
+    percentile errors, AUC and failure rate under a threshold.
+    """
+
+    def __init__(self, name='landmarks'):
+        super().__init__(name)
+
+        self.db_info = None
+        self.nme_norm = "corners"
+        self.nme_thr = 8
+        self.percentile = [90, 95, 99]
+        # Cumulative plot axis length
+        self.bins = 10000
+
+    def compute_error(self, data_anns, data_pred, database, select_ids=None):
+        """
+        Compute the normalized error (in %) of every landmark.
+        :param data_anns: List of annotations per image (dicts with 'ids' and 'landmarks').
+        :param data_pred: List of predictions per image, aligned by position with data_anns.
+        :param database: [database name, data type].
+        :param select_ids: Optional landmark ids to restrict the evaluation to.
+        :return: Error dict with 'nme_per_img', 'ne_per_img', 'ne_per_ldm' and 'cumulative_nme'.
+        """
+
+        # Initialize global logs and variables of Computer Error function
+        self.init_ce(data_anns, data_pred, database)
+        # NOTE: this step asks for the configuration through the console (input())
+        self._update_lnd_param()
+
+        # Order data and compute nme
+        self.error['nme_per_img'] = []
+        self.error['ne_per_img'] = OrderedDict()
+        self.error['ne_per_ldm'] = OrderedDict()
+        for img_id, anns in enumerate(data_anns):
+            # Init variables per img
+            pred = data_pred[img_id]
+
+            # Get select ids to compute
+            if select_ids is None:
+                selected_ldm = anns['ids']
+            else:
+                selected_ldm = list(set(select_ids) & set(anns['ids']))
+
+            norm = self._get_img_norm(anns)
+            for ldm_id in selected_ldm:
+                # Compute Normalize Error
+                anns_ldm = self._get_lnd_from_id(anns, ldm_id)
+                pred_ldm = self._get_lnd_from_id(pred, ldm_id)
+                ne = self._dist_l2(anns_ldm, pred_ldm)/norm * 100
+                self.error['ne_per_img'].setdefault(img_id, []).append(ne)
+                self.error['ne_per_ldm'].setdefault(ldm_id, []).append(ne)
+
+            # NME per image
+            if self.database in ['merlrav']:
+                # LUVLI at MERLRAV divide by 68 despite the annotated landmarks in the image.
+                self.error['nme_per_img'].append(np.sum(self.error['ne_per_img'][img_id])/68)
+            else:
+                self.error['nme_per_img'].append(np.mean(self.error['ne_per_img'][img_id]))
+
+        # Cumulative NME
+        self.error['cumulative_nme'] = self._cumulative_error(self.error['nme_per_img'], bins=self.bins)
+
+        return self.error
+
+    def metrics(self):
+        """
+        Compute and print the metrics for the full dataset, for each test subset and per landmark.
+        :return: Metrics log dict.
+        """
+
+        # Initialize global logs and variables of Metrics function
+        self.init_metrics()
+
+        # Basic metrics (NME/NMPE/AUC/FR) for full dataset
+        nme, nmpe, auc, fr, _, _ = self._basic_metrics()
+
+        print('NME: %.3f' % nme)
+        self.metrics_log['nme'] = nme
+        for percent_id, percentile in enumerate(self.percentile):
+            print('NME_P%i: %.3f' % (percentile, nmpe[percent_id]))
+            self.metrics_log['nme_p%i' % percentile] = nmpe[percent_id]
+        self.metrics_log['nme_thr'] = self.nme_thr
+        self.metrics_log['nme_norm'] = self.nme_norm
+        print('AUC_%i: %.3f' % (self.nme_thr, auc))
+        self.metrics_log['auc'] = auc
+        print('FR_%i: %.3f' % (self.nme_thr, fr))
+        self.metrics_log['fr'] = fr
+
+        # Subset basic metrics
+        subsets = self.db_info['test_subsets']
+        if self.data_type == 'test' and len(subsets) > 0:
+            self.metrics_log['subset'] = OrderedDict()
+            for subset, img_filter in subsets.items():
+                self.metrics_log['subset'][subset] = OrderedDict()
+                nme, nmpe, auc, fr, _, _ = self._basic_metrics(img_select=img_filter)
+                print('> Landmarks subset: %s' % subset.upper())
+                print('NME: %.3f' % nme)
+                self.metrics_log['subset'][subset]['nme'] = nme
+                for percent_id, percentile in enumerate(self.percentile):
+                    print('NME_P%i: %.3f' % (percentile, nmpe[percent_id]))
+                    self.metrics_log['subset'][subset]['nme_p%i' % percentile] = nmpe[percent_id]
+                print('AUC_%i: %.3f' % (self.nme_thr, auc))
+                self.metrics_log['subset'][subset]['auc'] = auc
+                print('FR_%i: %.3f' % (self.nme_thr, fr))
+                self.metrics_log['subset'][subset]['fr'] = fr
+
+        # NME/NPE per landmark
+        self.metrics_log['nme_per_ldm'] = OrderedDict()
+        for percentile in self.percentile:
+            self.metrics_log['npe%i_per_ldm' % percentile] = OrderedDict()
+        for k, v in self.error['ne_per_ldm'].items():
+            self.metrics_log['nme_per_ldm'][k] = np.mean(v)
+            for percentile in self.percentile:
+                self.metrics_log['npe%i_per_ldm' % percentile][k] = np.percentile(v, percentile)
+
+        return self.metrics_log
+
+    def get_pimg_err(self, data_dict=None, img_select=None):
+        """
+        Get the NME per image, optionally restricted to some image indexes.
+        :param data_dict: Optional dict where the errors are stored under '<name>/nme'.
+        :param img_select: Optional list of image indexes.
+        :return: data_dict updated, or the list of errors if no dict is given.
+        """
+        data = self.error['nme_per_img']
+        if img_select is not None:
+            data = [data[img_id] for img_id in img_select]
+        name_dict = self.name + '/nme'
+        if data_dict is not None:
+            data_dict[name_dict] = data
+        else:
+            data_dict = data
+        return data_dict
+
+    def _update_lnd_param(self):
+        """
+        Load the database specifics (db_info.json) and set the normalization and threshold,
+        letting the user change the defaults through the console.
+        :raises ValueError: If the db_info file of the database does not exist.
+        """
+        db_info_file = db_anns_path.format(database=self.database, file_name='db_info')
+        if not os.path.exists(db_info_file):
+            raise ValueError('Database %s specifics not defined. Missing db_info.json' % self.database)
+
+        with open(db_info_file) as jsonfile:
+            self.db_info = json.load(jsonfile)
+
+        # The first entry of the 'norm' dict is the default configuration
+        norm_dict = self.db_info['norm']
+        nme_norm, nme_thr = next(iter(norm_dict.items()))
+        print('Default landmarks configuration: \n %s: %i' % (nme_norm, nme_thr))
+        answer = input("Change default config? (N/Y) >>> ")
+        if answer.lower() in ['yes', 'y']:
+            answer = input("Normalization options: %s >>> " % str(list(norm_dict.keys())))
+            if answer in norm_dict:
+                nme_norm = answer
+                nme_thr = norm_dict[nme_norm]
+            else:
+                print("Option %s not available keep in default one: %s" % (answer, nme_norm))
+            answer = input("Change threshold ->%s:%i ? (N/Y) >>> " % (nme_norm, nme_thr))
+            if answer.lower() in ['yes', 'y']:
+                answer = input('NME threshold: >>> ')
+                nme_thr = float(answer)
+            else:
+                print("Keeping default threshold: %i" % nme_thr)
+
+        self.nme_norm = nme_norm
+        self.nme_thr = nme_thr
+
+    def _dist_l2(self, pointA, pointB):
+        """Euclidean distance between two points given as numpy arrays."""
+        return float(((pointA - pointB) ** 2).sum() ** 0.5)
+
+    def _get_lnd_from_id(self, anns, ids):
+        """Return the coordinates (numpy array) of the landmark with the given id."""
+        idx = anns['ids'].index(ids)
+        ref = np.array(anns['landmarks'][idx])
+        return ref
+
+    def _get_img_norm(self, anns):
+        """
+        Compute the normalization distance of an image according to self.nme_norm.
+        The bbox is used as (x, y, width, height): 'height' reads bbox[3] and 'bbox'
+        takes the square root of bbox[2] * bbox[3].
+        :raises ValueError: If the normalization is not implemented.
+        """
+        if self.nme_norm == 'pupils':
+            print('WARNING: Pupils norm only implemented for 68 landmark configuration')
+            left_eye = [7, 138, 139, 8, 141, 142]
+            right_eye = [11, 144, 145, 12, 147, 148]
+            refA = np.zeros(2)
+            refB = np.zeros(2)
+            # Pupils are approximated by the mean of the landmarks around each eye
+            for left_id, right_id in zip(left_eye, right_eye):
+                refA += self._get_lnd_from_id(anns, left_id)
+                refB += self._get_lnd_from_id(anns, right_id)
+            refA = refA/len(left_eye)  # Left
+            refB = refB/len(right_eye)  # Right
+        elif self.nme_norm == 'corners':
+            refA = self._get_lnd_from_id(anns, 12)  # Left
+            refB = self._get_lnd_from_id(anns, 7)  # Right
+        elif self.nme_norm == 'diagonal':
+            # Arrays are required: with plain lists '+' concatenates instead of adding the size
+            # to the corner, and the distance computation fails afterwards.
+            refA = np.array(anns['bbox'][0:2])
+            refB = refA + np.array(anns['bbox'][2:4])
+        elif self.nme_norm == 'height':
+            return anns['bbox'][3]
+        elif self.nme_norm == 'lnd_bbox':
+            lnd = np.array(anns['landmarks'])
+            lnd_max = np.max(lnd, axis=0)
+            lnd_min = np.min(lnd, axis=0)
+            lnd_wh = lnd_max - lnd_min
+            return (lnd_wh[0]*lnd_wh[1])**0.5
+        elif self.nme_norm == 'bbox':
+            return (anns['bbox'][2] * anns['bbox'][3]) ** 0.5
+        else:
+            raise ValueError('Normalization %s not implemented' % self.nme_norm)
+
+        return self._dist_l2(refA, refB)
+
+    def _cumulative_error(self, error, bins=10000):
+        """
+        Cumulative distribution of the errors, cut at the NME threshold.
+        :return: [cumulative fraction of images, left edge of each histogram bin].
+        """
+        num_imgs, base = np.histogram(error, bins=bins)
+        cumulative = [x / float(len(error)) for x in np.cumsum(num_imgs)]
+        # np.histogram returns bins + 1 edges; keep the left edge of each bin
+        base = base[:bins]
+        cumulative, base = self._filter_cumulative(cumulative, base)
+        return [cumulative, base]
+
+    def _filter_cumulative(self, cumulative, base):
+        """Keep only the part of the curve whose error is under the NME threshold."""
+        base = [x for x in base if (x < self.nme_thr)]
+        cumulative = cumulative[:len(base)]
+        return cumulative, base
+
+    def _basic_metrics(self, img_select=None):
+        """
+        Compute NME, percentile errors, AUC and FR over all the images or a selection.
+        :param img_select: Optional list of image indexes.
+        :return: nme, nmpe (list, one per percentile), auc, fr, cumulative, base.
+        """
+        data = self.error['nme_per_img']
+        if img_select is not None:
+            data = [data[img_id] for img_id in img_select]
+            [cumulative, base] = self._cumulative_error(data, bins=self.bins)
+        else:
+            [cumulative, base] = self.error['cumulative_nme']
+
+        # Normalize Mean Error across img
+        nme = np.mean(data)
+        # Normalize Mean Percentile Error across img
+        nmpe = [np.percentile(data, percentile) for percentile in self.percentile]
+
+        # Area Under Curve and Failure Rate
+        auc, fr = self._auc_fr_metrics(cumulative, base)
+
+        return nme, nmpe, auc, fr, cumulative, base
+
+    def _auc_fr_metrics(self, cumulative, base):
+        """
+        Area under the cumulative curve and failure rate at the NME threshold, both in %.
+        An empty curve (no image under the threshold) gives AUC 0 and FR 100.
+        """
+        if not base:
+            auc = 0.
+            fr = 100.
+        else:
+            auc = (simps(cumulative, x=base) / self.nme_thr) * 100.0
+            # Every image is already under the last bin: the curve stays at 1 up to the threshold
+            if base[-1] < self.nme_thr and cumulative[-1] == 1:
+                auc += ((self.nme_thr - base[-1]) / self.nme_thr) * 100
+            fr = (1 - cumulative[-1]) * 100.0
+        return auc, fr
diff --git a/SPIGA/spiga/eval/benchmark/metrics/metrics.py b/SPIGA/spiga/eval/benchmark/metrics/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4f690ae75ce39e44e3219a898c3acee6928a4ab
--- /dev/null
+++ b/SPIGA/spiga/eval/benchmark/metrics/metrics.py
@@ -0,0 +1,42 @@
+from collections import OrderedDict
+
+
+class Metrics:
+
+ def __init__(self, name='metrics'):
+
+ # Data dicts
+ self.error = OrderedDict()
+ self.metrics_log = OrderedDict()
+ self.name = name
+ self.database = None
+ self.data_type = None
+
+ def compute_error(self, data_anns, data_pred, database, select_ids=None):
+ self.init_ce(data_anns, data_pred, database)
+ raise ValueError('Computer error has to be implemented by inheritance')
+
+ def init_ce(self, data_anns, data_pred, database):
+ # Update database info
+ [self.database, self.data_type] = database
+ # Logs and checks
+ print('Computing %s error...' % self.name)
+ if len(data_anns) == 0:
+ raise ValueError('Annotations miss for computing error in %s' % self.name)
+ if len(data_pred) == 0:
+ raise ValueError('Predictions miss for computing error in %s' % self.name)
+ elif len(data_pred) != len(data_anns):
+ raise Warning('Prediction vs annotations length mismatch')
+
+ def metrics(self):
+ self.init_metrics()
+ raise ValueError('Metrics has to be implemented by inheritance')
+
+ def init_metrics(self):
+ # Logs and checks
+ print('> Metrics %s:' % self.name)
+ if len(self.error) == 0:
+ raise ValueError('Error must be compute first in %s' % self.name)
+
+ def get_pimg_err(self, data_dict):
+ return data_dict
diff --git a/SPIGA/spiga/eval/benchmark/metrics/pose.py b/SPIGA/spiga/eval/benchmark/metrics/pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e591f00f71f26466d66545a5229663f116193eb
--- /dev/null
+++ b/SPIGA/spiga/eval/benchmark/metrics/pose.py
@@ -0,0 +1,159 @@
+import numpy as np
+from sklearn.metrics import confusion_matrix
+
+from spiga.eval.benchmark.metrics.metrics import Metrics
+
+
+class MetricsHeadpose(Metrics):
+
+ def __init__(self, name='headpose'):
+ super().__init__(name)
+
+ # Angles
+ self.angles = ['yaw', 'pitch', 'roll']
+ # Confusion matrix intervals
+ self.pose_labels = [-90, -75, -60, -45, -30, -15, 0, 15, 30, 45, 60, 75, 90]
+ # Percentile reference angles
+ self.error_labels = [2.5, 5, 10, 15, 30]
+ # Cumulative plot axis length
+ self.bins = 1000
+
+ def compute_error(self, data_anns, data_pred, database, select_ids=None):
+
+ # Initialize global logs and variables of Computer Error function
+ self.init_ce(data_anns, data_pred, database)
+
+ # Generate annotations if needed
+ if data_anns[0]['headpose'] is None:
+ print('Database anns generated by posit...')
+ data_anns = self._posit_anns()
+ print('Posit generation done...')
+
+ # Dictionary variables
+ self.error['data_pred'] = []
+ self.error['data_anns'] = []
+ self.error['data_pred_trl'] = []
+ self.error['data_anns_trl'] = []
+ self.error['mae_ypr'] = []
+ self.error['mae_mean'] = []
+
+ # Order data
+ for img_id, img_anns in enumerate(data_anns):
+ pose_anns = img_anns['headpose'][0:3]
+ self.error['data_anns'].append(pose_anns)
+ pose_pred = data_pred[img_id]['headpose'][0:3]
+ self.error['data_pred'].append(pose_pred)
+
+ # Compute MAE error
+ anns_array = np.array(self.error['data_anns'])
+ pred_array = np.array(self.error['data_pred'])
+ mae_ypr = np.abs((anns_array-pred_array))
+ self.error['mae_ypr'] = mae_ypr.tolist()
+ self.error['mae_mean'] = np.mean(mae_ypr, axis=-1).tolist()
+
+ # Quantize labeled data
+ label_anns = self._nearest_label(anns_array)
+ label_pred = self._nearest_label(pred_array)
+ self.error['label_anns'] = label_anns
+ self.error['label_pred'] = label_pred
+
+ for angle_id, angle in enumerate(self.angles):
+ # Confusion matrix
+ self.error['cm_%s' % angle] = confusion_matrix(label_anns[:, angle_id], label_pred[:, angle_id])
+ # Cumulative error
+ self.error['cumulative_%s' % angle] = self._cumulative_error(mae_ypr[:, angle_id], bins=self.bins)
+
+ return self.error
+
+ def metrics(self):
+
+ # Initialize global logs and variables of Metrics function
+ self.init_metrics()
+
+ # Mean Absolute Error
+ mae_ypr = np.array(self.error['mae_ypr'])
+ mae_ypr_mean = np.mean(mae_ypr, axis=0)
+ self.metrics_log['mae_ypr'] = mae_ypr_mean.tolist()
+ self.metrics_log['mae_mean'] = np.mean(mae_ypr_mean)
+ print('MAE [yaw, pitch, roll]: [%.3f, %.3f, %.3f]' % (mae_ypr_mean[0], mae_ypr_mean[1], mae_ypr_mean[2]))
+ print('MAE mean: %.3f' % self.metrics_log['mae_mean'])
+
+ # Per angle measurements
+ self.metrics_log['acc_label'] = []
+ self.metrics_log['acc_adj_label'] = []
+
+ for angle_id, angle in enumerate(self.angles):
+
+ # Accuracy per label
+ cm = self.error['cm_%s' % angle]
+ diagonal = np.diagonal(cm, offset=0).sum()
+ acc_main = diagonal / cm.sum().astype('float')
+ self.metrics_log['acc_label'].append(acc_main)
+
+ # Permissive accuracy
+ diagonal_adj = diagonal.sum() + np.diagonal(cm, offset=-1).sum() + np.diagonal(cm, offset=1).sum()
+ acc_adj = diagonal_adj / cm.sum().astype('float')
+ self.metrics_log['acc_adj_label'].append(acc_adj)
+
+ # Percentile of relevant angles
+ self.metrics_log['sr_%s' % angle] = {}
+ for angle_num in self.error_labels:
+ if max(mae_ypr[:, angle_id]) > angle_num:
+ [cumulative, base] = self.error['cumulative_%s' % angle]
+ perc = [cumulative[x[0] - 1] for x in enumerate(base) if x[1] > angle_num][0]
+ else:
+ perc = 1.
+
+ self.metrics_log['sr_%s' % angle][angle_num] = perc
+
+ print('Accuracy [yaw, pitch, roll]: ', self.metrics_log['acc_label'])
+ print('Accuracy [yaw, pitch, roll] (adjacency as TP): ', self.metrics_log['acc_adj_label'])
+ for angle in self.angles:
+ print('Success Rate %s: ' % angle, self.metrics_log['sr_%s' % angle])
+
+ return self.metrics_log
+
+ def get_pimg_err(self, data_dict, img_select=None):
+ mae_mean = self.error['mae_mean']
+ mae_ypr = self.error['mae_ypr']
+ if img_select is not None:
+ mae_mean = [mae_mean[img_id] for img_id in img_select]
+ mae_ypr = [mae_ypr[img_id] for img_id in img_select]
+ name_dict = self.name + '/%s'
+ data_dict[name_dict % 'mae'] = mae_mean
+ mae_ypr = np.array(mae_ypr)
+ data_dict[name_dict % 'mae_yaw'] = mae_ypr[:, 0].tolist()
+ data_dict[name_dict % 'mae_pitch'] = mae_ypr[:, 1].tolist()
+ data_dict[name_dict % 'mae_roll'] = mae_ypr[:, 2].tolist()
+ return data_dict
+
+ def _posit_anns(self):
+
+ import spiga.data.loaders.dl_config as dl_config
+ import spiga.data.loaders.dataloader as dl
+
+ # Load configuration
+ data_config = dl_config.AlignConfig(self.database, self.data_type)
+ data_config.image_size = (256, 256)
+ data_config.generate_pose = True
+ data_config.aug_names = []
+ data_config.shuffle = False
+ dataloader, _ = dl.get_dataloader(1, data_config, debug=True)
+
+ data_anns = []
+ for num_batch, batch_dict in enumerate(dataloader):
+ pose = batch_dict['pose'].numpy()
+ data_anns.append({'headpose': pose[0].tolist()})
+ return data_anns
+
+ def _nearest_label(self, data):
+ data_tile = data[:, :, np.newaxis]
+ data_tile = np.tile(data_tile, len(self.pose_labels))
+ diff_tile = np.abs(data_tile - self.pose_labels)
+ label_idx = diff_tile.argmin(axis=-1)
+ return label_idx
+
+ def _cumulative_error(self, error, bins=1000):
+ num_imgs, base = np.histogram(error, bins=bins)
+ cumulative = [x / float(len(error)) for x in np.cumsum(num_imgs)]
+ return [cumulative[:bins], base[:bins]]
diff --git a/SPIGA/spiga/eval/benchmark/readme.md b/SPIGA/spiga/eval/benchmark/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..d4cb7c4d266816ef48e7ecec2e1da07f4d6f4c68
--- /dev/null
+++ b/SPIGA/spiga/eval/benchmark/readme.md
@@ -0,0 +1,24 @@
+# SPIGA: Benchmark
+The benchmark evaluator can be found at ```./eval/benchmark/evaluator.py``` and it allows you
+to extract an extended report of metrics for each dataset. For further details,
+check the parser and complete the interactive terminal procedure to specify the evaluation
+characteristics.
+
+In order to use the benchmark evaluation, the prediction file must follow the same data structure
+and file extension as the ground truth annotations available in ```./data/annotations/```.
+The data structure consists of a list of dictionaries, where each one represents an image sample,
+similar to the previous dataloader configuration:
+
+```
+sample = {"imgpath": Relative image path,
+ "bbox": Bounding box [x,y,w,h] (ref image),
+ "headpose": Euler angles [yaw, pitch, roll],
+ "ids": Landmarks database ids,
+ "landmarks": Landmarks (ref image),
+ "visible": Visibilities [0,1, ...] (1 == Visible)
+ }
+```
+
+Finally, it is worth mentioning that the benchmark can be easily extended to other tasks by
+inheriting the class structure available in ```./eval/benchmark/metrics/metrics.py``` and
+developing a new task file like the available ones: landmarks and headpose.
diff --git a/SPIGA/spiga/eval/results/300wprivate/metrics_300wprivate_test.txt b/SPIGA/spiga/eval/results/300wprivate/metrics_300wprivate_test.txt
new file mode 100644
index 0000000000000000000000000000000000000000..132752a14125c287f1d0a664084ce4943a4b945c
--- /dev/null
+++ b/SPIGA/spiga/eval/results/300wprivate/metrics_300wprivate_test.txt
@@ -0,0 +1,337 @@
+Metrics {
+ landmarks {
+ nme: 2.030598939445185
+ nme_p90: 2.7880932616458525
+ nme_p95: 3.0777544749713726
+ nme_p99: 3.834184107657207
+ nme_thr: 7
+ nme_norm: lnd_bbox
+ auc: 71.00849804168183
+ fr: 0.16666666666667052
+ subset {
+ indoor {
+ nme: 2.034564559904317
+ nme_p90: 2.7259855150065606
+ nme_p95: 3.0092119632411545
+ nme_p99: 3.713329759683323
+ auc: 70.96381253891056
+ fr: 0.33333333333332993
+ }
+ outdoor {
+ nme: 2.0266333189860526
+ nme_p90: 2.823638152094534
+ nme_p95: 3.217040770610454
+ nme_p99: 3.834184107657207
+ auc: 71.05339689254487
+ fr: 0.0
+ }
+ }
+ nme_per_ldm {
+ 101: 3.6806610803182505
+ 102: 3.3036602516363742
+ 103: 3.339433875945013
+ 104: 3.454702516496273
+ 105: 3.611916041455306
+ 106: 3.719966247876225
+ 107: 3.548076775082098
+ 108: 3.2518252416098865
+ 24: 3.096658896153719
+ 110: 3.331046674179755
+ 111: 3.699712378619379
+ 112: 3.7843353250872225
+ 113: 3.6034097769479234
+ 114: 3.398488809686842
+ 115: 3.3141217781436625
+ 116: 3.280323260898595
+ 117: 3.544361895970207
+ 1: 2.757092591472853
+ 119: 2.242685122483485
+ 2: 2.0915854137040806
+ 121: 2.0586213715102817
+ 3: 2.123266891475812
+ 4: 2.078486972366517
+ 124: 1.796939685624812
+ 5: 1.9188959395054488
+ 126: 2.1738624874122228
+ 6: 2.624636921552657
+ 128: 1.5520879573321433
+ 129: 1.2549002865469516
+ 130: 1.17669611982962
+ 17: 1.4660272904294709
+ 16: 1.4505750980043473
+ 133: 1.029206611825485
+ 134: 1.1578763764399964
+ 135: 1.4523188926457233
+ 18: 1.8149727805659837
+ 7: 1.388113293070523
+ 138: 0.9782280739323941
+ 139: 1.4519658473793906
+ 8: 1.9434955210694351
+ 141: 1.579927226333795
+ 142: 0.9795537257185432
+ 11: 1.573585552070739
+ 144: 1.009652358478732
+ 145: 1.2587210616231412
+ 12: 1.5793154047393527
+ 147: 1.1712298544072581
+ 148: 1.0158092454679306
+ 20: 1.3654116942883525
+ 150: 1.3991990474588623
+ 151: 1.2462856085922873
+ 22: 1.0382222181925866
+ 153: 1.4298055400286076
+ 154: 1.6178077717355814
+ 21: 1.9146205126976596
+ 156: 1.7453498742022029
+ 157: 1.6858832552899667
+ 23: 1.396568360761986
+ 159: 1.5292370812850333
+ 160: 1.5946255121363364
+ 161: 1.3332901947302789
+ 162: 1.1785447751601195
+ 163: 1.0187335950212426
+ 164: 1.3344807664948581
+ 165: 1.736265138539922
+ 166: 1.638976611962414
+ 167: 1.3432143685181699
+ 168: 1.4211731540502521
+ }
+ npe90_per_ldm {
+ 101: 7.792348947260428
+ 102: 7.084284783662423
+ 103: 7.00267244505699
+ 104: 7.11370360965358
+ 105: 7.169040113179254
+ 106: 7.2420085022388765
+ 107: 7.0092655347920685
+ 108: 6.653925085500781
+ 24: 6.312629764258624
+ 110: 6.454926282662611
+ 111: 7.414894684684534
+ 112: 7.8126541675676835
+ 113: 7.415456624957532
+ 114: 7.24509404695461
+ 115: 6.912281722736787
+ 116: 6.672501696370728
+ 117: 7.217817998736465
+ 1: 5.487776518419356
+ 119: 4.454286064030085
+ 2: 4.0134627911922
+ 121: 4.088715377295195
+ 3: 4.211052099955585
+ 4: 4.037415662963983
+ 124: 3.4469165471729477
+ 5: 3.9864537671903637
+ 126: 4.4751537271539075
+ 6: 5.474239719648658
+ 128: 2.834816765905374
+ 129: 2.2342495378788207
+ 130: 2.1759546821990923
+ 17: 2.5418198758079984
+ 16: 2.576833155505584
+ 133: 1.9794875969490342
+ 134: 2.098944179442805
+ 135: 2.407073180794855
+ 18: 2.9144263159065007
+ 7: 2.7047079934870015
+ 138: 1.8557917169522138
+ 139: 2.457276944950026
+ 8: 3.2924827521662205
+ 141: 2.704159715833272
+ 142: 1.8698923398185585
+ 11: 2.683163668302068
+ 144: 1.8304971688134921
+ 145: 2.184512515479075
+ 12: 2.912935303982953
+ 147: 2.074820991655545
+ 148: 1.8803985833925125
+ 20: 2.4937956106514334
+ 150: 2.591990835431721
+ 151: 2.2625583999418772
+ 22: 1.9780265158922163
+ 153: 2.6224550571489877
+ 154: 3.2408057855185093
+ 21: 3.707365699795642
+ 156: 3.1072768395751913
+ 157: 3.1679371463471115
+ 23: 2.510236952841339
+ 159: 2.8718913632783964
+ 160: 2.881251348014151
+ 161: 2.524050295380916
+ 162: 2.1275892848246936
+ 163: 1.884137823110829
+ 164: 2.432245754693854
+ 165: 3.337998651440683
+ 166: 2.9696817088203815
+ 167: 2.316776125494591
+ 168: 2.568779298951829
+ }
+ npe95_per_ldm {
+ 101: 10.075572675677341
+ 102: 8.818179377359794
+ 103: 8.782133525632883
+ 104: 8.643303594630462
+ 105: 8.606411141514197
+ 106: 8.641525601763867
+ 107: 8.045215924441003
+ 108: 8.435963345382614
+ 24: 8.98978588560972
+ 110: 9.209410203200584
+ 111: 9.690509154913633
+ 112: 9.528589159993023
+ 113: 9.516280184903218
+ 114: 8.806410992841911
+ 115: 8.349939939616558
+ 116: 8.306406512963349
+ 117: 8.948843864453119
+ 1: 6.607728795206713
+ 119: 5.328839886207676
+ 2: 4.719557597011315
+ 121: 4.803093906639568
+ 3: 5.066642833569373
+ 4: 5.106882789323109
+ 124: 4.405464649556331
+ 5: 4.792967300757355
+ 126: 5.312659675767809
+ 6: 6.55777924210191
+ 128: 3.309968146826916
+ 129: 2.6825883040009937
+ 130: 2.4823859073064334
+ 17: 3.107821014772654
+ 16: 3.104094496813284
+ 133: 2.2987283138903787
+ 134: 2.3965594095703406
+ 135: 2.778131597510065
+ 18: 3.417492367908234
+ 7: 3.620758651495707
+ 138: 2.5771490826818635
+ 139: 2.983064731141363
+ 8: 3.718558278089841
+ 141: 3.1060236607485288
+ 142: 2.382002241110809
+ 11: 3.1854035291469387
+ 144: 2.2942253082346777
+ 145: 2.7631832911247534
+ 12: 3.4466035286261287
+ 147: 2.6992500039970317
+ 148: 2.3485294437550133
+ 20: 3.216585381779816
+ 150: 3.0851452590040425
+ 151: 3.0592010658988906
+ 22: 2.4266778242787073
+ 153: 3.38165515276341
+ 154: 3.7832660988534497
+ 21: 4.581001942866445
+ 156: 3.8923444760487502
+ 157: 3.7494216498115063
+ 23: 3.3595885805903856
+ 159: 3.6741936975782754
+ 160: 3.5264543871919876
+ 161: 3.1694861030365487
+ 162: 2.515374327000464
+ 163: 2.2866964953427025
+ 164: 3.1082736052648565
+ 165: 4.031899925586421
+ 166: 3.578448574876227
+ 167: 2.875831262261589
+ 168: 3.232139784668087
+ }
+ npe99_per_ldm {
+ 101: 13.801902428443963
+ 102: 12.87712171046809
+ 103: 12.402656213930339
+ 104: 11.1057085602639
+ 105: 11.353922785698465
+ 106: 11.647547037817693
+ 107: 12.83429625061474
+ 108: 13.307141731654763
+ 24: 14.618925151278201
+ 110: 15.261628339432663
+ 111: 14.14823585987817
+ 112: 12.335533428069908
+ 113: 11.802116552825849
+ 114: 12.510726633232077
+ 115: 12.69435364188939
+ 116: 12.89529194227151
+ 117: 14.120910351994045
+ 1: 8.638311101530586
+ 119: 7.535269476509163
+ 2: 6.559870102583517
+ 121: 6.630420572209065
+ 3: 8.290823571978367
+ 4: 7.6847786232321535
+ 124: 6.822107522918532
+ 5: 6.939187568926209
+ 126: 7.608413202225105
+ 6: 8.340906801308865
+ 128: 5.072392719133257
+ 129: 3.5758878270546597
+ 130: 3.172918188296416
+ 17: 4.218440531352108
+ 16: 3.931176580980738
+ 133: 2.919954597731364
+ 134: 3.236737576939871
+ 135: 3.415454133426117
+ 18: 4.643995661949497
+ 7: 5.772164810850686
+ 138: 5.238776860999826
+ 139: 4.208475684527024
+ 8: 4.7641538901480445
+ 141: 4.113162874534946
+ 142: 4.625693800831237
+ 11: 4.451128264439006
+ 144: 4.168386107123375
+ 145: 5.1101945338282695
+ 12: 5.31672303499037
+ 147: 4.523658333756621
+ 148: 3.5122981836488223
+ 20: 7.080964010381011
+ 150: 6.377625659811748
+ 151: 4.928590462504693
+ 22: 4.701579190184112
+ 153: 4.727943873063984
+ 154: 5.465818611585125
+ 21: 6.148406910087535
+ 156: 6.786793649953492
+ 157: 6.186824707078167
+ 23: 6.837078051253573
+ 159: 7.605776962432469
+ 160: 7.235180906998915
+ 161: 7.982744434560323
+ 162: 4.656982439604284
+ 163: 4.270483732934878
+ 164: 4.6446884214498345
+ 165: 6.0823676390136034
+ 166: 6.9716680872367105
+ 167: 7.432470055536447
+ 168: 7.3663196613198
+ }
+ }
+ headpose {
+ mae_ypr: [1.6090495251118835, 2.046698079387603, 0.955561359343095]
+ mae_mean: 1.537102987947527
+ acc_label: [0.9183333333333333, 0.87, 0.9366666666666666]
+ acc_adj_label: [1.0, 1.0, 1.0]
+ sr_yaw {
+ 2.5: 0.7983333333333333
+ 5: 0.97
+ 10: 0.9983333333333333
+ 15: 1.0
+ 30: 1.0
+ }
+ sr_pitch {
+ 2.5: 0.6883333333333334
+ 5: 0.9416666666666667
+ 10: 0.9966666666666667
+ 15: 1.0
+ 30: 1.0
+ }
+ sr_roll {
+ 2.5: 0.9416666666666667
+ 5: 0.9983333333333333
+ 10: 1.0
+ 15: 1.0
+ 30: 1.0
+ }
+ }
+ }
diff --git a/SPIGA/spiga/eval/results/300wprivate/results_300wprivate_test.json b/SPIGA/spiga/eval/results/300wprivate/results_300wprivate_test.json
new file mode 100644
index 0000000000000000000000000000000000000000..99ab6e8efbddc025cefb3daef201088a9269eb15
--- /dev/null
+++ b/SPIGA/spiga/eval/results/300wprivate/results_300wprivate_test.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c50536962725dd3cde6169f48fea7447314e17d1564ecf1921827704b6893af
+size 2223550
diff --git a/SPIGA/spiga/eval/results/300wpublic/metrics_300wpublic_test.txt b/SPIGA/spiga/eval/results/300wpublic/metrics_300wpublic_test.txt
new file mode 100644
index 0000000000000000000000000000000000000000..49e1f7cd1309fe4bd100d1b3f8aee063b1973e95
--- /dev/null
+++ b/SPIGA/spiga/eval/results/300wpublic/metrics_300wpublic_test.txt
@@ -0,0 +1,353 @@
+Metrics {
+ landmarks {
+ nme: 2.993687267871256
+ nme_p90: 4.666037430932269
+ nme_p95: 5.435113114011185
+ nme_p99: 7.320552484553374
+ nme_thr: 8
+ nme_norm: corners
+ auc: 62.72750851023099
+ fr: 0.7256894049346929
+ subset {
+ helen {
+ nme: 2.5194569810704968
+ nme_p90: 3.588651040403328
+ nme_p95: 3.9771381372856305
+ nme_p99: 5.038565693708802
+ auc: 68.51294946769845
+ fr: 0.0
+ }
+ lfpw {
+ nme: 2.6867036799108392
+ nme_p90: 3.8500450455085744
+ nme_p95: 4.232240763191166
+ nme_p99: 5.2672893244075185
+ auc: 66.42133313963284
+ fr: 0.0
+ }
+ common {
+ nme: 2.5870801950420437
+ nme_p90: 3.710404666983561
+ nme_p95: 4.083736280045437
+ nme_p99: 5.216315947831376
+ auc: 67.6676731570753
+ fr: 0.0
+ }
+ ibug {
+ nme: 4.662282218592618
+ nme_p90: 6.628041366327217
+ nme_p95: 7.390973031124523
+ nme_p99: 10.093415564142603
+ auc: 42.44787101795513
+ fr: 3.703703703703709
+ }
+ }
+ nme_per_ldm {
+ 101: 6.160148291662601
+ 102: 5.609969055309763
+ 103: 5.639963193216567
+ 104: 5.700234717877294
+ 105: 5.62936644920689
+ 106: 5.574676189493879
+ 107: 5.349393371824218
+ 108: 4.702542530122117
+ 24: 4.168177272711497
+ 110: 4.485507899065954
+ 111: 5.208036708182383
+ 112: 5.493830104918947
+ 113: 5.586260490385025
+ 114: 5.714415950920614
+ 115: 5.442384369263486
+ 116: 5.4538915292008125
+ 117: 6.178730799959288
+ 1: 4.165885001962012
+ 119: 3.1921307867109445
+ 2: 2.6405652171792697
+ 121: 2.7692371208173987
+ 3: 3.771312716404229
+ 4: 3.915696771337146
+ 124: 3.0583414383575067
+ 5: 2.7836732387071392
+ 126: 2.9600464424059627
+ 6: 3.9325330705818566
+ 128: 2.266086276562485
+ 129: 1.879556319441715
+ 130: 1.9706376235225898
+ 17: 2.4439140053778123
+ 16: 2.648478293025234
+ 133: 2.1738938082518318
+ 134: 1.9817528088827188
+ 135: 1.7966616924282088
+ 18: 2.1889880619377
+ 7: 1.6798354870145646
+ 138: 1.4525365591905652
+ 139: 1.5444551246243863
+ 8: 1.6622930186366096
+ 141: 1.529130733663348
+ 142: 1.461878641713216
+ 11: 1.6372097638236516
+ 144: 1.3940225437140497
+ 145: 1.5293915903135549
+ 12: 1.6939793028036296
+ 147: 1.4966723651722735
+ 148: 1.441845674425855
+ 20: 1.81709289789314
+ 150: 1.9693359999114588
+ 151: 1.884337269187178
+ 22: 1.4862224166931166
+ 153: 1.9633109788224843
+ 154: 1.9772138173842613
+ 21: 1.9430947452720937
+ 156: 2.5919972321865066
+ 157: 2.561137981239902
+ 23: 2.1950781821923164
+ 159: 2.4901805745220082
+ 160: 2.550118126201779
+ 161: 1.9205944941018516
+ 162: 1.7900476486146397
+ 163: 1.3976363193641992
+ 164: 1.9106126973833983
+ 165: 2.0763873452059816
+ 166: 2.1224747018434096
+ 167: 1.693262532344654
+ 168: 2.070427832570231
+ }
+ npe90_per_ldm {
+ 101: 13.122683532466992
+ 102: 11.837902215558064
+ 103: 11.75616577370391
+ 104: 11.875509248652866
+ 105: 12.212341286473402
+ 106: 11.149962035433145
+ 107: 10.613398355229833
+ 108: 9.389892873095325
+ 24: 8.591199066081636
+ 110: 9.421511135997605
+ 111: 11.193086402244425
+ 112: 11.707000001316612
+ 113: 11.349684731339558
+ 114: 11.811597643003372
+ 115: 10.607833816884204
+ 116: 11.243376754752148
+ 117: 12.971478848872138
+ 1: 9.23874231532177
+ 119: 6.7372032763499545
+ 2: 5.33974784249774
+ 121: 5.732096724504876
+ 3: 7.5062452089485365
+ 4: 7.456439231223284
+ 124: 5.8237217172512095
+ 5: 5.37182931231504
+ 126: 6.094696584923242
+ 6: 8.17012872398533
+ 128: 4.198385133649584
+ 129: 3.3444779030854996
+ 130: 3.5638768938015297
+ 17: 4.429228974895804
+ 16: 4.869116170096254
+ 133: 4.157919854236839
+ 134: 4.108282493143575
+ 135: 3.522263446262345
+ 18: 4.002610414280894
+ 7: 3.2908041808066506
+ 138: 2.7722779484559354
+ 139: 3.1732425721021564
+ 8: 3.466774533899654
+ 141: 2.978701593114757
+ 142: 2.7587386420612185
+ 11: 3.2356256090137006
+ 144: 2.69672824715176
+ 145: 2.8984317775614516
+ 12: 3.3750413338539142
+ 147: 2.6896432032687887
+ 148: 2.488745347966042
+ 20: 3.4949560168183806
+ 150: 3.615117113066581
+ 151: 3.5894529408226603
+ 22: 2.909838203155513
+ 153: 3.7097208120818275
+ 154: 3.919346282690197
+ 21: 4.196338523026944
+ 156: 4.950309881678925
+ 157: 4.765821211342502
+ 23: 4.117335074736717
+ 159: 4.621291671397001
+ 160: 4.736419766006695
+ 161: 3.818460262745092
+ 162: 3.385764999285364
+ 163: 2.573588612350815
+ 164: 3.785005840288626
+ 165: 4.022827504716443
+ 166: 4.00963166181946
+ 167: 3.105493633989696
+ 168: 3.9076308361907075
+ }
+ npe95_per_ldm {
+ 101: 16.640887862533262
+ 102: 15.73249785912902
+ 103: 14.94766424497651
+ 104: 15.524963676616089
+ 105: 15.873365705127112
+ 106: 16.417111507079138
+ 107: 14.567853910225598
+ 108: 12.510205643500568
+ 24: 12.143881169760876
+ 110: 12.089065625344906
+ 111: 14.366730260926783
+ 112: 14.740499149458255
+ 113: 15.459087674693276
+ 114: 14.599722757007777
+ 115: 15.320290937083609
+ 116: 14.708246193314626
+ 117: 16.418629645261973
+ 1: 11.937667692240536
+ 119: 8.52572385838488
+ 2: 6.659827131118409
+ 121: 7.070926140929972
+ 3: 9.409775342546114
+ 4: 9.185034717032613
+ 124: 7.160591484765116
+ 5: 6.465838615220728
+ 126: 7.2916218490559945
+ 6: 10.085439984678748
+ 128: 4.945152308750049
+ 129: 4.240340551642683
+ 130: 4.235830937943613
+ 17: 5.07803313070256
+ 16: 6.301010594644189
+ 133: 5.3003686467661835
+ 134: 5.39773858718444
+ 135: 4.775021937005514
+ 18: 5.647719520272081
+ 7: 4.274205042229546
+ 138: 3.594392271949664
+ 139: 4.089779282378731
+ 8: 4.741065197374902
+ 141: 3.9364873945583154
+ 142: 3.3173099248264215
+ 11: 4.0045732629352635
+ 144: 3.158413099972695
+ 145: 3.6099848378771133
+ 12: 4.872231016970969
+ 147: 3.6122138787323426
+ 148: 3.1785153515663076
+ 20: 4.629010728226188
+ 150: 4.8588032103886345
+ 151: 4.64642981328001
+ 22: 3.6423999181932727
+ 153: 4.706447394770584
+ 154: 4.916454492667414
+ 21: 5.4326415927145755
+ 156: 6.8671568682767665
+ 157: 6.620902515780232
+ 23: 5.608019225929798
+ 159: 6.244705316060662
+ 160: 6.2049492392703645
+ 161: 4.621479019005777
+ 162: 4.386075192468076
+ 163: 3.4769696652222586
+ 164: 4.58677707682409
+ 165: 5.5239135136002755
+ 166: 5.487065873148972
+ 167: 4.446659252100746
+ 168: 5.294582035532603
+ }
+ npe99_per_ldm {
+ 101: 23.9472105829856
+ 102: 21.712780306667675
+ 103: 21.313324271514396
+ 104: 21.368750734861273
+ 105: 22.232228723373485
+ 106: 25.08953794210005
+ 107: 24.944096885619654
+ 108: 22.96040182762928
+ 24: 22.894704116944336
+ 110: 22.614380914459304
+ 111: 22.256763074721803
+ 112: 21.25918241444815
+ 113: 20.286955186210022
+ 114: 25.694990732800143
+ 115: 23.465421874951634
+ 116: 23.351073650613607
+ 117: 24.978511767730428
+ 1: 16.4324939088724
+ 119: 13.422956195988334
+ 2: 9.618581533279238
+ 121: 10.118878457527511
+ 3: 12.82211476714369
+ 4: 11.67716301954819
+ 124: 10.797328353981431
+ 5: 11.235362955909709
+ 126: 12.247003662455013
+ 6: 15.773704754299274
+ 128: 6.383363322637416
+ 129: 6.321260974342025
+ 130: 6.300611559944998
+ 17: 8.164659256640022
+ 16: 9.51053078772321
+ 133: 8.273369391152267
+ 134: 8.076100532657165
+ 135: 7.688675439981096
+ 18: 9.79951835561599
+ 7: 7.033074751907975
+ 138: 6.423074083674037
+ 139: 6.9446604878218565
+ 8: 7.574534227892147
+ 141: 6.5307976262292335
+ 142: 6.341758670051268
+ 11: 6.5760383486423315
+ 144: 5.286412075946667
+ 145: 7.240882514689094
+ 12: 7.739950501196232
+ 147: 7.44356729074967
+ 148: 5.468215521283997
+ 20: 8.286609944508777
+ 150: 8.069308096570612
+ 151: 7.411738191325378
+ 22: 6.142237156302458
+ 153: 7.701905468720823
+ 154: 8.52986712064459
+ 21: 9.300708464380659
+ 156: 13.14337152729244
+ 157: 11.376582767015504
+ 23: 11.795090745823842
+ 159: 11.10872358521376
+ 160: 11.855785481682954
+ 161: 7.618868444259899
+ 162: 6.610868401789532
+ 163: 6.220955897344144
+ 164: 8.126455622376907
+ 165: 9.37200137441875
+ 166: 10.480683232603582
+ 167: 8.139103679313033
+ 168: 10.707529142440695
+ }
+ }
+ headpose {
+ mae_ypr: [1.4116569256059106, 1.7011980952121855, 0.7716754190283938]
+ mae_mean: 1.29484347994883
+ acc_label: [0.8911465892597968, 0.8824383164005806, 0.9608127721335269]
+ acc_adj_label: [1.0, 1.0, 1.0]
+ sr_yaw {
+ 2.5: 0.8374455732946299
+ 5: 0.9796806966618288
+ 10: 0.9970972423802612
+ 15: 0.9985486211901307
+ 30: 1.0
+ }
+ sr_pitch {
+ 2.5: 0.7677793904208998
+ 5: 0.9753265602322206
+ 10: 1.0
+ 15: 1.0
+ 30: 1.0
+ }
+ sr_roll {
+ 2.5: 0.9564586357039188
+ 5: 1.0
+ 10: 1.0
+ 15: 1.0
+ 30: 1.0
+ }
+ }
+ }
diff --git a/SPIGA/spiga/eval/results/300wpublic/results_300wpublic_test.json b/SPIGA/spiga/eval/results/300wpublic/results_300wpublic_test.json
new file mode 100644
index 0000000000000000000000000000000000000000..52ec5e25f13fb246df02054310dba804eb4b6452
--- /dev/null
+++ b/SPIGA/spiga/eval/results/300wpublic/results_300wpublic_test.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f0aa3529879d1f265f2dedc8cb3057e715fdf56b37c3380f7b622c0720f9bfdc
+size 2556728
diff --git a/SPIGA/spiga/eval/results/cofw68/metrics_cofw68_test.txt b/SPIGA/spiga/eval/results/cofw68/metrics_cofw68_test.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6b4037dcbd6c0867dc7159fd305f46d455003b72
--- /dev/null
+++ b/SPIGA/spiga/eval/results/cofw68/metrics_cofw68_test.txt
@@ -0,0 +1,319 @@
+Metrics {
+ landmarks {
+ nme: 2.517048299761656
+ nme_p90: 3.439260118578336
+ nme_p95: 4.065583073314784
+ nme_p99: 5.5576659368466075
+ nme_thr: 7
+ nme_norm: lnd_bbox
+ auc: 64.04980282900947
+ fr: 0.0
+ nme_per_ldm {
+ 101: 4.386618760988867
+ 102: 4.061629839664334
+ 103: 3.983745459766538
+ 104: 4.059943179619045
+ 105: 3.901502523065625
+ 106: 3.6464161308997243
+ 107: 3.3236750419605943
+ 108: 2.88736723178498
+ 24: 2.7423694672096457
+ 110: 3.0635795053213557
+ 111: 3.571088090891605
+ 112: 3.9983500297313452
+ 113: 4.0943322771194595
+ 114: 4.255800757059301
+ 115: 4.2912918761153485
+ 116: 4.266023792408103
+ 117: 4.530578189494549
+ 1: 3.158660231370845
+ 119: 3.4125315663662614
+ 2: 3.352789113391728
+ 121: 3.417363298999408
+ 3: 3.668232431297081
+ 4: 3.91431465693854
+ 124: 3.840416726980644
+ 5: 3.6046913869421213
+ 126: 3.387491193444403
+ 6: 3.0296136285985975
+ 128: 1.8998212393466836
+ 129: 1.516826726048447
+ 130: 1.5100915342372372
+ 17: 1.6919453984466715
+ 16: 1.421215739597502
+ 133: 1.5697432942669043
+ 134: 1.4987569995929837
+ 135: 1.366172570997096
+ 18: 1.2916655495623037
+ 7: 1.7581540811039842
+ 138: 1.5428715462019915
+ 139: 1.3632059460397556
+ 8: 1.3302291252462273
+ 141: 1.287550535022652
+ 142: 1.4467690355725018
+ 11: 1.2907119544949173
+ 144: 1.5873275912342366
+ 145: 1.9198143064375461
+ 12: 2.2533648088916247
+ 147: 1.747890660026784
+ 148: 1.436164845643807
+ 20: 1.991609432994987
+ 150: 1.8109854987255802
+ 151: 1.6953683686244805
+ 22: 1.5060260323923602
+ 153: 1.708202891919297
+ 154: 1.6963691562752732
+ 21: 1.87731093602303
+ 156: 2.142094773425542
+ 157: 2.240978526071491
+ 23: 2.110667241896683
+ 159: 2.2746052761540887
+ 160: 2.2265284572934942
+ 161: 2.1973066159601427
+ 162: 1.7926476622155938
+ 163: 1.5326108170595516
+ 164: 1.729851588257375
+ 165: 2.121233580025467
+ 166: 1.9985535852655705
+ 167: 1.845910843504295
+ 168: 2.0797132242663703
+ }
+ npe90_per_ldm {
+ 101: 8.018897558697356
+ 102: 7.97010628672223
+ 103: 7.566745648570134
+ 104: 8.095347755811265
+ 105: 7.756765075016777
+ 106: 7.039911855937849
+ 107: 6.288952271483205
+ 108: 5.773784683107754
+ 24: 5.945580939663232
+ 110: 5.829576040208693
+ 111: 7.112875146109584
+ 112: 7.940662690831719
+ 113: 8.250074675668332
+ 114: 8.07748482264541
+ 115: 8.233109448421368
+ 116: 8.242831765389505
+ 117: 8.687203940264425
+ 1: 5.86847210127666
+ 119: 5.243952556782381
+ 2: 5.137097829449687
+ 121: 5.2659079859150815
+ 3: 5.857022822554752
+ 4: 6.151300083812946
+ 124: 5.789014437973446
+ 5: 5.372824967591501
+ 126: 5.475210902866383
+ 6: 5.86178580377768
+ 128: 3.2289389333211234
+ 129: 2.622261396524905
+ 130: 2.7304992712176763
+ 17: 3.1572454384592046
+ 16: 2.452719009880311
+ 133: 2.627498132517079
+ 134: 2.7312065493277164
+ 135: 2.2973900896540504
+ 18: 2.0929480199891537
+ 7: 3.65244010654807
+ 138: 3.1776907692766656
+ 139: 2.8232017769140567
+ 8: 2.450986276686096
+ 141: 2.491367195541456
+ 142: 2.914612560836842
+ 11: 2.510987040195006
+ 144: 2.9695834676241772
+ 145: 3.4200280229584106
+ 12: 3.949345153264541
+ 147: 3.1944286922995526
+ 148: 2.487523608598698
+ 20: 3.975484769060534
+ 150: 3.385725078332356
+ 151: 3.285296446152331
+ 22: 2.7130471412453243
+ 153: 2.992534087755148
+ 154: 2.9411914409794786
+ 21: 3.4939899717823275
+ 156: 4.021007684523546
+ 157: 4.190985454904701
+ 23: 4.130310029590604
+ 159: 4.314667308966008
+ 160: 4.272504657006288
+ 161: 4.685153554561627
+ 162: 3.203495263181351
+ 163: 2.747330827383441
+ 164: 3.0265766339171187
+ 165: 4.500497180240332
+ 166: 3.8905157693437964
+ 167: 3.891106312239455
+ 168: 3.8933988592268465
+ }
+ npe95_per_ldm {
+ 101: 10.417038072470968
+ 102: 9.991198289523819
+ 103: 9.379720137555848
+ 104: 9.799656008084403
+ 105: 9.22207737018875
+ 106: 8.829078350035305
+ 107: 7.992799269180475
+ 108: 7.4082148935339465
+ 24: 7.413781438652343
+ 110: 7.37361147503038
+ 111: 8.709348785745767
+ 112: 9.312892440791485
+ 113: 9.77808155435385
+ 114: 10.325717515744035
+ 115: 10.033803637246558
+ 116: 10.004145195487656
+ 117: 10.821755638172457
+ 1: 6.765196576889283
+ 119: 5.947330615300408
+ 2: 5.862662969499837
+ 121: 6.152025258743905
+ 3: 7.1670021067469625
+ 4: 7.454245442904308
+ 124: 6.756637702306873
+ 5: 6.1635352411726165
+ 126: 6.519119512853371
+ 6: 7.1940984593611095
+ 128: 3.6571416433436674
+ 129: 3.056401499324455
+ 130: 3.3808439264483634
+ 17: 4.061928034301918
+ 16: 2.9407680765002353
+ 133: 3.0653183161807847
+ 134: 3.180963980499353
+ 135: 2.901656649057088
+ 18: 2.840245418802563
+ 7: 4.597797894095366
+ 138: 4.3035882561965675
+ 139: 3.7014036924353197
+ 8: 3.5013324792600873
+ 141: 3.2526250820240743
+ 142: 3.9174824517972575
+ 11: 3.63472057362023
+ 144: 4.191612176955886
+ 145: 4.520173157846678
+ 12: 4.912688917799541
+ 147: 4.356777295067304
+ 148: 3.8592166644229673
+ 20: 5.541101603402795
+ 150: 4.834433689727549
+ 151: 4.3500326868388095
+ 22: 3.813502938654715
+ 153: 4.389537008701319
+ 154: 4.776232703613533
+ 21: 4.980501036793917
+ 156: 5.277853337613837
+ 157: 5.740403588832144
+ 23: 5.724463408255806
+ 159: 5.932445207932544
+ 160: 5.8883880784816025
+ 161: 6.852697999733049
+ 162: 4.380159191875369
+ 163: 3.5470116495641206
+ 164: 3.992186775614564
+ 165: 5.563741112172766
+ 166: 5.302063622441587
+ 167: 5.434068653301774
+ 168: 5.857697902700627
+ }
+ npe99_per_ldm {
+ 101: 15.088755845432772
+ 102: 13.073819615334575
+ 103: 13.477893246441145
+ 104: 12.545880504907462
+ 105: 12.097862148538754
+ 106: 12.268080481835584
+ 107: 11.25379241268216
+ 108: 10.828929543085886
+ 24: 11.717651103815243
+ 110: 11.699755871901786
+ 111: 12.184678792172349
+ 112: 13.34509981517936
+ 113: 11.83952843526531
+ 114: 12.853085381739312
+ 115: 13.994438403474529
+ 116: 13.081667763696956
+ 117: 14.42657011217481
+ 1: 8.916371368807438
+ 119: 7.997584620353168
+ 2: 7.610065164179458
+ 121: 7.6362815228963035
+ 3: 8.797749915381567
+ 4: 9.865833274647276
+ 124: 7.910805681261553
+ 5: 8.048606667984895
+ 126: 8.619571570257547
+ 6: 9.183003733406897
+ 128: 4.865828948220972
+ 129: 4.287177166390985
+ 130: 5.352060488383854
+ 17: 6.229529147494782
+ 16: 4.921147034354624
+ 133: 5.357211627886322
+ 134: 5.565350192932446
+ 135: 5.716824998156575
+ 18: 5.408026129580792
+ 7: 6.784161819068971
+ 138: 6.674934894826747
+ 139: 5.596098482770631
+ 8: 4.810708425032517
+ 141: 5.29448127755881
+ 142: 5.587101373823092
+ 11: 6.561695360938204
+ 144: 6.190677292735184
+ 145: 7.575832280733528
+ 12: 7.970272010646224
+ 147: 7.102125471669491
+ 148: 6.117951030854053
+ 20: 9.733188811612486
+ 150: 8.108968351347654
+ 151: 8.053888878839329
+ 22: 8.501392906622689
+ 153: 7.842154498433768
+ 154: 9.348911008680913
+ 21: 10.794637315024282
+ 156: 11.4410194218526
+ 157: 10.844758689065458
+ 23: 10.508751739629794
+ 159: 10.397173521760807
+ 160: 9.493416503817661
+ 161: 11.033949636665236
+ 162: 8.776212120745663
+ 163: 9.80338512675157
+ 164: 9.456915239371774
+ 165: 11.727340045598552
+ 166: 11.91779218298747
+ 167: 11.234236463238254
+ 168: 11.068899102792699
+ }
+ }
+ headpose {
+ mae_ypr: [2.083062363064235, 2.475056062512911, 0.9064911102818425]
+ mae_mean: 1.8215365119529963
+ acc_label: [0.8875739644970414, 0.8520710059171598, 0.9388560157790927]
+ acc_adj_label: [1.0, 1.0, 1.0]
+ sr_yaw {
+ 2.5: 0.6923076923076923
+ 5: 0.9368836291913215
+ 10: 0.9960552268244576
+ 15: 1.0
+ 30: 1.0
+ }
+ sr_pitch {
+ 2.5: 0.5798816568047337
+ 5: 0.893491124260355
+ 10: 0.9940828402366864
+ 15: 1.0
+ 30: 1.0
+ }
+ sr_roll {
+ 2.5: 0.9428007889546351
+ 5: 0.9980276134122288
+ 10: 1.0
+ 15: 1.0
+ 30: 1.0
+ }
+ }
+ }
diff --git a/SPIGA/spiga/eval/results/cofw68/results_cofw68_test.json b/SPIGA/spiga/eval/results/cofw68/results_cofw68_test.json
new file mode 100644
index 0000000000000000000000000000000000000000..79eaaa13cbfce5b7b224305fd4757fbe98263ac6
--- /dev/null
+++ b/SPIGA/spiga/eval/results/cofw68/results_cofw68_test.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:03acfc2ede335aac354c9fe2c9bd008beb9d6d596210c74b4231a95b6687dd9b
+size 1898116
diff --git a/SPIGA/spiga/eval/results/merlrav/metrics_merlrav_test.txt b/SPIGA/spiga/eval/results/merlrav/metrics_merlrav_test.txt
new file mode 100644
index 0000000000000000000000000000000000000000..582c650e219abb322077e3608d40f8dd8bc7b533
--- /dev/null
+++ b/SPIGA/spiga/eval/results/merlrav/metrics_merlrav_test.txt
@@ -0,0 +1,377 @@
+Metrics {
+ landmarks {
+ nme: 1.5086804931543236
+ nme_p90: 2.1635501319152364
+ nme_p95: 2.466109589313714
+ nme_p99: 3.455642125258591
+ nme_thr: 7
+ nme_norm: lnd_bbox
+ auc: 78.47397661438411
+ fr: 0.051746442432087925
+ subset {
+ frontal {
+ nme: 1.61560462312357
+ nme_p90: 2.2463951142028145
+ nme_p95: 2.572226598305855
+ nme_p99: 3.6210860696886833
+ auc: 76.9597525877967
+ fr: 0.09111617312073106
+ }
+ half_profile {
+ nme: 1.6828487909060967
+ nme_p90: 2.272702728357507
+ nme_p95: 2.5471145536327953
+ nme_p99: 3.3987905213552203
+ auc: 75.9657914896377
+ fr: 0.0
+ }
+ half_profile_left {
+ nme: 1.6593883465198958
+ nme_p90: 2.3284338844722074
+ nme_p95: 2.5693632263530124
+ nme_p99: 3.278697689229699
+ auc: 76.30098571769051
+ fr: 0.0
+ }
+ half_profile_right {
+ nme: 1.7054658859842966
+ nme_p90: 2.220904448144333
+ nme_p95: 2.491175396697724
+ nme_p99: 3.5122625056995673
+ auc: 75.64173351430159
+ fr: 0.0
+ }
+ profile {
+ nme: 1.1912121935927615
+ nme_p90: 1.7318457133250305
+ nme_p95: 2.04549650792864
+ nme_p99: 2.878427460092048
+ auc: 82.98987188287711
+ fr: 0.0
+ }
+ profile_left {
+ nme: 1.1429927541204028
+ nme_p90: 1.6430975490859832
+ nme_p95: 1.9478367509011818
+ nme_p99: 2.63876055903246
+ auc: 83.67867425299369
+ fr: 0.0
+ }
+ profile_right {
+ nme: 1.2370558577626498
+ nme_p90: 1.810307762848865
+ nme_p95: 2.1411266828074176
+ nme_p99: 2.9389777318895325
+ auc: 82.332208276089
+ fr: 0.0
+ }
+ }
+ nme_per_ldm {
+ 101: 3.3589768431762175
+ 102: 2.7531280870845873
+ 103: 2.859701002904713
+ 104: 2.8634457682779235
+ 105: 2.877658841330188
+ 106: 2.720071048885179
+ 107: 2.4629924994495793
+ 108: 2.2467294961158877
+ 24: 2.1288254355152607
+ 110: 2.2724448767452325
+ 111: 2.540627823930061
+ 112: 2.8284471937158453
+ 113: 3.012202614117622
+ 114: 3.0464652090359206
+ 115: 3.0443508364827943
+ 116: 2.9206874592197822
+ 117: 3.5246519565195786
+ 1: 2.4989484892381864
+ 119: 2.094527573540818
+ 2: 1.9975862837642682
+ 121: 1.9667532281958748
+ 3: 2.1557279059194916
+ 4: 2.036801626056737
+ 124: 1.825762147391309
+ 5: 1.9798960712376632
+ 126: 2.205651215578999
+ 6: 2.6494703945724276
+ 128: 1.4815483716437434
+ 129: 1.231528474400865
+ 130: 1.247063433631788
+ 17: 1.4741665492704437
+ 16: 1.204987012197412
+ 133: 1.0170302588663973
+ 134: 1.0350961889009112
+ 135: 1.0078531087065747
+ 18: 1.229081438697912
+ 7: 1.2964041628692098
+ 138: 1.1803380166768218
+ 139: 1.1907173974524987
+ 8: 1.1245143408558638
+ 141: 1.1016743161735143
+ 142: 1.1321081964660915
+ 11: 1.0727780654424217
+ 144: 1.1618244167936957
+ 145: 1.2171458580396803
+ 12: 1.3729999278923841
+ 147: 1.1778922447451972
+ 148: 1.1254606806695924
+ 20: 1.0895438485495015
+ 150: 1.1152916781328186
+ 151: 1.186922081665385
+ 22: 1.0107909205470527
+ 153: 1.1651654397855509
+ 154: 1.1413200775970356
+ 21: 1.1505719973688346
+ 156: 1.2323070958056401
+ 157: 1.2951471399634233
+ 23: 1.3496263569400522
+ 159: 1.3166097535695698
+ 160: 1.185578952219307
+ 161: 1.1289970069187725
+ 162: 1.1741022347060504
+ 163: 1.0222564924321926
+ 164: 1.182969390549842
+ 165: 1.1677089625147783
+ 166: 1.4000000311093093
+ 167: 1.1875541874658213
+ 168: 1.3548409427314658
+ }
+ npe90_per_ldm {
+ 101: 6.841303462812936
+ 102: 5.871967061156423
+ 103: 6.435313555001839
+ 104: 6.0863414215710545
+ 105: 5.856386610328998
+ 106: 5.3515838953111325
+ 107: 4.936905109240317
+ 108: 4.35646958881554
+ 24: 4.154501717080872
+ 110: 4.366467634101041
+ 111: 5.007177442427851
+ 112: 5.75662407531498
+ 113: 6.147574364558974
+ 114: 6.5819160311808504
+ 115: 7.242887674595011
+ 116: 6.139449989740605
+ 117: 6.913785566410088
+ 1: 5.264989646117678
+ 119: 4.281934194047823
+ 2: 4.078521868372654
+ 121: 3.9081540494658076
+ 3: 4.243355566911822
+ 4: 4.016214823124536
+ 124: 3.688705902703632
+ 5: 3.9564376406091695
+ 126: 4.667716897467507
+ 6: 5.730550741940217
+ 128: 2.671781950666222
+ 129: 2.191045192348473
+ 130: 2.238379783884833
+ 17: 2.608140372894774
+ 16: 2.1748623755375593
+ 133: 1.7634276968625462
+ 134: 1.8924906144995726
+ 135: 1.7874400893528062
+ 18: 2.2361018447957703
+ 7: 2.5174159955382107
+ 138: 2.2777400322113186
+ 139: 2.230048482303937
+ 8: 2.191010919936544
+ 141: 2.097821997361514
+ 142: 2.139694713995162
+ 11: 2.1575823638269345
+ 144: 2.1801001764063463
+ 145: 2.3478125097855127
+ 12: 2.6692055655031304
+ 147: 2.2042056800179175
+ 148: 2.0837791212171304
+ 20: 2.0762878643553657
+ 150: 2.069576194472468
+ 151: 2.2168224383982555
+ 22: 1.8675655632850923
+ 153: 2.2593226852473003
+ 154: 2.169092157346306
+ 21: 2.2776575357980002
+ 156: 2.313307796442557
+ 157: 2.4325564157550295
+ 23: 2.466093276107024
+ 159: 2.457834012145091
+ 160: 2.2375426349501075
+ 161: 2.2227197578057774
+ 162: 2.2256901798461644
+ 163: 1.9237306114380308
+ 164: 2.2285056022065644
+ 165: 2.2872963189685924
+ 166: 2.6474074126240037
+ 167: 2.266259572666768
+ 168: 2.6563960731826053
+ }
+ npe95_per_ldm {
+ 101: 8.441117987108047
+ 102: 7.809974197367657
+ 103: 8.914018643798276
+ 104: 7.985589833529736
+ 105: 7.657610856797052
+ 106: 6.901529644151474
+ 107: 6.003484363104082
+ 108: 5.359714843390148
+ 24: 5.237271616781721
+ 110: 5.35817329527914
+ 111: 6.093353283047422
+ 112: 7.178776732921367
+ 113: 7.864851413344902
+ 114: 8.435585935405099
+ 115: 9.31943075203189
+ 116: 8.338565642906014
+ 117: 8.844323862195974
+ 1: 6.920610608658486
+ 119: 5.579137925444407
+ 2: 5.118546782224291
+ 121: 4.9331439703683095
+ 3: 5.3920267113612
+ 4: 4.991568250091611
+ 124: 4.567737138187528
+ 5: 5.007922132090623
+ 126: 5.8629287499286376
+ 6: 7.279068497806011
+ 128: 3.277869272288929
+ 129: 2.6265009714322414
+ 130: 2.598266674561741
+ 17: 3.1108016117707487
+ 16: 2.6423645953005237
+ 133: 2.1061442879934087
+ 134: 2.3322800165604987
+ 135: 2.1330354379938354
+ 18: 2.7965749304352037
+ 7: 3.28428298493539
+ 138: 2.9011942090088265
+ 139: 2.8616896234789064
+ 8: 2.965395206958521
+ 141: 2.7671326920439228
+ 142: 2.720548244959449
+ 11: 2.7847408257771025
+ 144: 2.8207150460715664
+ 145: 3.0690144016397145
+ 12: 3.5641601474934723
+ 147: 2.9816894630361896
+ 148: 2.7772131964014806
+ 20: 2.9545877208209697
+ 150: 2.676521986087967
+ 151: 2.8397352668704365
+ 22: 2.452843817529208
+ 153: 2.786238880172479
+ 154: 2.6948539822082376
+ 21: 3.2014971657831786
+ 156: 2.99331035258674
+ 157: 3.099565507520246
+ 23: 3.1536920132264137
+ 159: 3.035724672323389
+ 160: 2.8590697596373205
+ 161: 3.18328939939261
+ 162: 2.745295751539974
+ 163: 2.491584206354821
+ 164: 2.7848191019354513
+ 165: 3.25229943328409
+ 166: 3.384820152838584
+ 167: 2.972183782174847
+ 168: 3.393574192195915
+ }
+ npe99_per_ldm {
+ 101: 13.306254348151949
+ 102: 12.981243433347958
+ 103: 14.194854214648371
+ 104: 12.773848328599131
+ 105: 11.310276344499538
+ 106: 10.033189122433079
+ 107: 8.737335443356818
+ 108: 8.689656343619724
+ 24: 8.813631497738173
+ 110: 8.98053893671926
+ 111: 9.200895008592042
+ 112: 10.574307218972542
+ 113: 11.80779942179725
+ 114: 12.831817823579
+ 115: 14.90723684715645
+ 116: 13.403174325402333
+ 117: 14.32637544856898
+ 1: 11.107329717515297
+ 119: 8.819292038261858
+ 2: 7.998693376996145
+ 121: 7.6630438328081265
+ 3: 8.180601672453921
+ 4: 7.1172156178272115
+ 124: 6.883517943873057
+ 5: 7.469622468883537
+ 126: 9.260935533221096
+ 6: 11.394333320505359
+ 128: 4.985104181306873
+ 129: 3.6610539607732746
+ 130: 3.577669553046159
+ 17: 4.282318668694788
+ 16: 4.4367360979513535
+ 133: 3.4077506728617175
+ 134: 3.5073978351923065
+ 135: 3.072085422485997
+ 18: 3.8561367097286876
+ 7: 5.243996543352382
+ 138: 4.736214805102126
+ 139: 4.658932237258348
+ 8: 5.035769413543924
+ 141: 4.859906333306005
+ 142: 4.687698979363885
+ 11: 4.754515314723082
+ 144: 4.7954279914279745
+ 145: 5.177008093991888
+ 12: 6.376708713699464
+ 147: 5.357222555071042
+ 148: 4.726237540189613
+ 20: 5.785803524452584
+ 150: 4.7338985598610925
+ 151: 4.626492612315765
+ 22: 4.163245408353465
+ 153: 4.496380103516062
+ 154: 4.341338388123191
+ 21: 6.061105262023478
+ 156: 5.334943775605292
+ 157: 5.152214838495613
+ 23: 5.874569691242533
+ 159: 5.791913011910687
+ 160: 5.062634991173862
+ 161: 5.835847336978609
+ 162: 4.5226670212359945
+ 163: 4.242683346393679
+ 164: 4.30806562074895
+ 165: 6.014068598402919
+ 166: 5.471812890679048
+ 167: 5.600333148651882
+ 168: 5.405042102742384
+ }
+ }
+ headpose {
+ mae_ypr: [3.2345458128780242, 2.2386364658267817, 1.7062846932804483]
+ mae_mean: 2.393155657328418
+ acc_label: [0.8087968952134541, 0.8543337645536869, 0.8892626131953428]
+ acc_adj_label: [0.9886157826649418, 0.9992238033635188, 0.9979301423027167]
+ sr_yaw {
+ 2.5: 0.6530401034928849
+ 5: 0.8369987063389392
+ 10: 0.9340232858990944
+ 15: 0.9707632600258732
+ 30: 0.9937904269081501
+ }
+ sr_pitch {
+ 2.5: 0.6763260025873221
+ 5: 0.9099611901681759
+ 10: 0.9912031047865459
+ 15: 0.9979301423027167
+ 30: 0.9997412677878396
+ }
+ sr_roll {
+ 2.5: 0.8243208279430789
+ 5: 0.9291073738680465
+ 10: 0.9800776196636481
+ 15: 0.9943078913324709
+ 30: 0.9987063389391979
+ }
+ }
+ }
diff --git a/SPIGA/spiga/eval/results/merlrav/results_merlrav_test.json b/SPIGA/spiga/eval/results/merlrav/results_merlrav_test.json
new file mode 100644
index 0000000000000000000000000000000000000000..e34420cc35f14668037cad7777bfdb5198213e14
--- /dev/null
+++ b/SPIGA/spiga/eval/results/merlrav/results_merlrav_test.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:490f620f568d5cced54f61cf9b0a449c7e414c0f90f659290935d39811cbe1ee
+size 14445509
diff --git a/SPIGA/spiga/eval/results/wflw/metrics_wflw_test.txt b/SPIGA/spiga/eval/results/wflw/metrics_wflw_test.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e6438fcfcc347c37561e100f1d0530679468f224
--- /dev/null
+++ b/SPIGA/spiga/eval/results/wflw/metrics_wflw_test.txt
@@ -0,0 +1,489 @@
+Metrics {
+ landmarks {
+ nme: 4.060449569787332
+ nme_p90: 6.765928760955893
+ nme_p95: 8.197544910122737
+ nme_p99: 13.090868491915499
+ nme_thr: 10
+ nme_norm: corners
+ auc: 60.55545281848933
+ fr: 2.080000000000004
+ subset {
+ pose {
+ nme: 7.1406786522793455
+ nme_p90: 10.683591623734115
+ nme_p95: 13.331614480116997
+ nme_p99: 26.88728473439624
+ auc: 35.30522609772822
+ fr: 11.65644171779141
+ }
+ expression {
+ nme: 4.457062543676477
+ nme_p90: 7.023696522424334
+ nme_p95: 8.14882559132701
+ nme_p99: 22.37794722994855
+ auc: 57.97977206967515
+ fr: 2.2292993630573243
+ }
+ illumination {
+ nme: 4.004404270660319
+ nme_p90: 6.5263439186718575
+ nme_p95: 7.9198072110541515
+ nme_p99: 11.091641927095278
+ auc: 61.30881713357479
+ fr: 1.5759312320916874
+ }
+ makeup {
+ nme: 3.809466022861384
+ nme_p90: 6.319715813961875
+ nme_p95: 8.290862588901753
+ nme_p99: 11.568122626675834
+ auc: 62.227636273476094
+ fr: 1.4563106796116498
+ }
+ occlusion {
+ nme: 4.951892714023863
+ nme_p90: 8.090266937174373
+ nme_p95: 9.927929968392291
+ nme_p99: 16.439363635459827
+ auc: 53.30944329221704
+ fr: 4.483695652173914
+ }
+ blur {
+ nme: 4.6502632721001484
+ nme_p90: 7.309129270660999
+ nme_p95: 8.694682218622555
+ nme_p99: 14.421681667881499
+ auc: 55.30994528611606
+ fr: 2.1992238033635148
+ }
+ }
+ nme_per_ldm {
+ 100: 7.499428096385048
+ 101: 7.051152856495479
+ 102: 6.689364840366089
+ 103: 6.405394902267039
+ 104: 6.184828312751848
+ 105: 6.01289637200206
+ 106: 5.825792191666405
+ 107: 5.703581143874681
+ 108: 5.558218880650258
+ 109: 5.456541574926521
+ 110: 5.356011590771132
+ 111: 5.266801213177881
+ 112: 5.1700746207919135
+ 113: 5.11578920243509
+ 114: 5.106880510341512
+ 115: 5.225113359917539
+ 24: 5.347111701530893
+ 117: 5.193400965915828
+ 118: 5.105516141121586
+ 119: 5.1186167002866485
+ 120: 5.154032577508115
+ 121: 5.275157240183345
+ 122: 5.389493979946993
+ 123: 5.4938752374179955
+ 124: 5.584092334651998
+ 125: 5.705867605198732
+ 126: 5.799339095880129
+ 127: 5.954134417791037
+ 128: 6.176961828664217
+ 129: 6.455263479287812
+ 130: 6.768471302761961
+ 131: 7.125915842598694
+ 132: 7.5666080900616945
+ 1: 4.560765729125988
+ 134: 3.7993917303919567
+ 2: 3.4016672801102077
+ 136: 3.345044810053455
+ 3: 3.710503962280294
+ 138: 3.6149551366473753
+ 139: 3.2410971507433786
+ 140: 3.291757973260253
+ 141: 3.734225865973348
+ 4: 3.72032794127556
+ 143: 3.30955002587817
+ 5: 3.4392003178872335
+ 145: 3.8880208167570074
+ 6: 4.752078564996328
+ 147: 3.738086948109648
+ 148: 3.3201253179463905
+ 149: 3.2295749512632232
+ 150: 3.647357521767312
+ 151: 2.547342291744557
+ 152: 2.283238387886172
+ 153: 2.591598229160512
+ 17: 3.0354855482365775
+ 16: 3.2137738380189536
+ 156: 2.518324235151659
+ 157: 2.47653726483443
+ 158: 2.5353874064282507
+ 18: 3.2251412484538893
+ 7: 3.1441893915300385
+ 161: 2.4942204053480195
+ 9: 2.6590371025659896
+ 163: 2.3580909255854596
+ 8: 2.8036185690133753
+ 165: 2.3321001401925967
+ 10: 2.591633267462054
+ 167: 2.484735627041701
+ 11: 2.7305098161561356
+ 169: 2.4274079065187864
+ 13: 2.7842155581430643
+ 171: 2.52221107229687
+ 12: 3.0657845511698607
+ 173: 2.494593526480273
+ 14: 2.6932726730687575
+ 175: 2.3591528479046104
+ 20: 4.112721320016711
+ 177: 3.275914306275302
+ 178: 3.42575723113226
+ 22: 3.0927109768973287
+ 180: 3.432414088055732
+ 181: 3.2103207081848937
+ 21: 4.432567252126953
+ 183: 3.7233446858330366
+ 184: 3.539979963579465
+ 23: 3.6190759220532405
+ 186: 3.4824407313700387
+ 187: 3.6566723766201465
+ 188: 3.472835226709668
+ 189: 2.726895196789255
+ 190: 2.930735583182004
+ 191: 2.820444141794683
+ 192: 3.7859629623364324
+ 193: 3.1790111029427828
+ 194: 3.3227123023323206
+ 195: 3.0754071648176375
+ 196: 2.2687775736645497
+ 197: 2.3802949379562897
+ }
+ npe90_per_ldm {
+ 100: 14.852445597799203
+ 101: 14.18579862958194
+ 102: 13.399793537733599
+ 103: 12.718468904804915
+ 104: 12.169341275908627
+ 105: 12.01705017248965
+ 106: 11.552579287899306
+ 107: 11.519360109572814
+ 108: 11.069713815819815
+ 109: 10.974822012933272
+ 110: 10.969423983927014
+ 111: 10.643665463349626
+ 112: 10.508696789450717
+ 113: 10.509353382470481
+ 114: 10.426225048917066
+ 115: 10.77132409296981
+ 24: 10.836704047501668
+ 117: 10.227534559111756
+ 118: 10.212154240626676
+ 119: 10.343173662305139
+ 120: 10.501815588300182
+ 121: 10.60403011069595
+ 122: 10.885667824449156
+ 123: 11.223511777755174
+ 124: 11.426450532559125
+ 125: 11.731786380223648
+ 126: 11.87994642089011
+ 127: 12.347922569924979
+ 128: 12.510717574506758
+ 129: 13.203110126923216
+ 130: 13.935800451738846
+ 131: 14.65581284528847
+ 132: 15.684327289240702
+ 1: 9.145904257144307
+ 134: 7.391833366795161
+ 2: 6.557342516262404
+ 136: 6.326639388017908
+ 3: 7.123186160526489
+ 138: 6.737679958469918
+ 139: 6.122406459193663
+ 140: 6.32844967785714
+ 141: 7.423740406828487
+ 4: 6.980971721786448
+ 143: 6.45515651585541
+ 5: 6.862006338881268
+ 145: 7.679578611775045
+ 6: 9.747986111715123
+ 147: 7.334343333255734
+ 148: 6.579826532395581
+ 149: 6.304752067636323
+ 150: 6.821268539263581
+ 151: 4.755299967018147
+ 152: 4.1313479515649725
+ 153: 4.739782037447192
+ 17: 5.5336191688254415
+ 16: 6.437315379147384
+ 156: 4.838950427255811
+ 157: 5.138543963702374
+ 158: 4.905267413061354
+ 18: 6.107948731858525
+ 7: 6.182595210311924
+ 161: 4.856248707531849
+ 9: 5.256125568483426
+ 163: 4.51061528950398
+ 8: 5.462974340475309
+ 165: 4.517984859757126
+ 10: 4.880297646197902
+ 167: 4.851486229126095
+ 11: 5.315306219280171
+ 169: 4.753220539425407
+ 13: 5.374371835997992
+ 171: 4.852785312871789
+ 12: 6.137253720508925
+ 173: 4.902626446768939
+ 14: 5.062119144586851
+ 175: 4.477753808871465
+ 20: 8.291468584753794
+ 177: 6.865894383817322
+ 178: 6.6278523121906945
+ 22: 6.391876354854704
+ 180: 6.737756386246774
+ 181: 6.633701192824108
+ 21: 8.828649325728422
+ 183: 7.428439146636114
+ 184: 7.215668900651159
+ 23: 7.2506582758423885
+ 186: 6.971699097467
+ 187: 7.16787193303953
+ 188: 7.134659063702044
+ 189: 5.6504003691320195
+ 190: 6.08694986471154
+ 191: 5.698565063478801
+ 192: 7.807921223112673
+ 193: 6.27796544518223
+ 194: 6.6953497695329105
+ 195: 5.963853719652496
+ 196: 4.42588286715022
+ 197: 4.758802842487979
+ }
+ npe95_per_ldm {
+ 100: 18.8274259243322
+ 101: 17.91893820103512
+ 102: 16.790440459339383
+ 103: 16.36177051582291
+ 104: 15.72781641631692
+ 105: 15.041278647748332
+ 106: 14.763833719969805
+ 107: 15.025060221978926
+ 108: 14.730682520252282
+ 109: 14.643754841675316
+ 110: 14.300136361811015
+ 111: 14.654414320540273
+ 112: 14.593583232262878
+ 113: 14.47409072542243
+ 114: 14.618637715471197
+ 115: 14.530093031225332
+ 24: 14.995606363223226
+ 117: 14.448948156208589
+ 118: 13.789274304854752
+ 119: 13.832120269244403
+ 120: 13.709518365737859
+ 121: 14.129329077634093
+ 122: 14.353242447422213
+ 123: 15.16466545358333
+ 124: 15.164929443920798
+ 125: 15.113599392011784
+ 126: 15.118919334522698
+ 127: 15.493269395291685
+ 128: 16.38552396477743
+ 129: 17.243501461747186
+ 130: 17.7726200230872
+ 131: 18.733413340844084
+ 132: 19.85817362564012
+ 1: 12.282671457428942
+ 134: 10.190266338206994
+ 2: 8.627462529688064
+ 136: 8.353884573498535
+ 3: 9.149231954274304
+ 138: 8.862645399076076
+ 139: 8.214180187653598
+ 140: 8.588138726132641
+ 141: 10.198459485868906
+ 4: 9.193189072829572
+ 143: 8.508633703895459
+ 5: 9.27580510928055
+ 145: 10.66391353468443
+ 6: 13.150535457623667
+ 147: 10.295798054152893
+ 148: 8.727375003658778
+ 149: 8.139860884835961
+ 150: 9.137884658727723
+ 151: 6.1993757689099445
+ 152: 5.379008416563386
+ 153: 5.999217405852915
+ 17: 7.471187849893246
+ 16: 8.412519228071288
+ 156: 6.511320139284012
+ 157: 6.495503052291176
+ 158: 6.4691245203276235
+ 18: 8.046436081223744
+ 7: 8.53023993361135
+ 161: 6.5695540715861345
+ 9: 7.06066679340818
+ 163: 6.126827317718087
+ 8: 7.391868149254668
+ 165: 5.850762331953799
+ 10: 6.392737721827344
+ 167: 6.701751274702547
+ 11: 7.108232079933884
+ 169: 6.498765197817774
+ 13: 7.377924918159498
+ 171: 6.8661972583787225
+ 12: 8.579455534855914
+ 173: 6.604260017556539
+ 14: 6.895344286814114
+ 175: 6.07934878085839
+ 20: 12.237940047506514
+ 177: 10.415865290819106
+ 178: 8.636449653407603
+ 22: 8.548684956545552
+ 180: 9.195403546434065
+ 181: 9.307120855626186
+ 21: 14.40623429205884
+ 183: 11.136075208306705
+ 184: 10.380179074464987
+ 23: 9.900986672111294
+ 186: 9.80179820389519
+ 187: 11.26701817845163
+ 188: 9.714741299910031
+ 189: 7.997210438319446
+ 190: 8.018485293706336
+ 191: 8.152387130926519
+ 192: 10.572139256377373
+ 193: 8.829320776371096
+ 194: 9.353915311261236
+ 195: 8.770069844546882
+ 196: 6.380411318934716
+ 197: 6.925252704835688
+ }
+ npe99_per_ldm {
+ 100: 31.796734331757474
+ 101: 30.156823307914355
+ 102: 28.085999280928775
+ 103: 27.772409859859803
+ 104: 27.063494458024618
+ 105: 26.818097394520635
+ 106: 25.30721172181605
+ 107: 26.91752125832992
+ 108: 26.856715591131337
+ 109: 25.068096485672058
+ 110: 25.487803402381296
+ 111: 25.22568444548434
+ 112: 25.97592571841177
+ 113: 25.719255333824044
+ 114: 26.344100968854626
+ 115: 27.565684593896133
+ 24: 25.958704292040675
+ 117: 25.027316620457235
+ 118: 24.153183750150713
+ 119: 23.454696148177824
+ 120: 23.650815452820808
+ 121: 22.974733432539477
+ 122: 22.7787999634705
+ 123: 22.95454955539594
+ 124: 23.536490612996054
+ 125: 23.949973181222603
+ 126: 24.156669818564882
+ 127: 24.307258463774165
+ 128: 26.230362982841182
+ 129: 27.279179727007598
+ 130: 27.658094700898353
+ 131: 29.498087172293605
+ 132: 31.031273144224297
+ 1: 21.815528592810452
+ 134: 18.552025618639288
+ 2: 16.112806045145106
+ 136: 16.26474105956125
+ 3: 16.03098434863021
+ 138: 15.331273499198128
+ 139: 14.571579406397491
+ 140: 15.17885461687226
+ 141: 18.84639764294039
+ 4: 15.14680909708014
+ 143: 15.959472764350828
+ 5: 16.782205693066487
+ 145: 19.544102261096246
+ 6: 22.703607587916377
+ 147: 18.72768083165344
+ 148: 16.949308129734572
+ 149: 15.403991541464318
+ 150: 16.53716758233348
+ 151: 10.225916118065296
+ 152: 9.87580178788986
+ 153: 11.078951829537761
+ 17: 11.432785302080758
+ 16: 13.77854390122661
+ 156: 11.423297506459893
+ 157: 11.649539675840872
+ 158: 10.298080549891448
+ 18: 12.980837586206968
+ 7: 15.380520840864325
+ 161: 13.675989650355017
+ 9: 13.996495073301666
+ 163: 12.2974182120915
+ 8: 12.70906602790459
+ 165: 11.986475872557232
+ 10: 12.564713827155238
+ 167: 13.186734914311163
+ 11: 12.050553315262324
+ 169: 11.60835366208721
+ 13: 13.926666792959946
+ 171: 13.430274209070587
+ 12: 16.136339266834494
+ 173: 14.083254962467871
+ 14: 13.786860675398353
+ 175: 11.820493814517507
+ 20: 31.456220720812258
+ 177: 17.145310081719785
+ 178: 14.383365993773618
+ 22: 14.266997881870614
+ 180: 14.767809590251009
+ 181: 16.5761236909575
+ 21: 34.75300701148773
+ 183: 21.53482636836683
+ 184: 19.21124669677774
+ 23: 17.370155603691355
+ 186: 18.523203073422827
+ 187: 22.25607277080702
+ 188: 22.550397440214464
+ 189: 14.875857373605129
+ 190: 14.056370833704333
+ 191: 14.59218011999597
+ 192: 24.578201317085025
+ 193: 17.409549388507017
+ 194: 17.03090444462318
+ 195: 16.907666081457222
+ 196: 13.586781795190145
+ 197: 13.688773958708495
+ }
+ }
+ headpose {
+ mae_ypr: [1.8035515527812418, 1.8859887869300525, 0.969774915452454]
+ mae_mean: 1.5531050850545827
+ acc_label: [0.888, 0.8716, 0.9472]
+ acc_adj_label: [0.9988, 1.0, 1.0]
+ sr_yaw {
+ 2.5: 0.7692
+ 5: 0.95
+ 10: 0.9952
+ 15: 0.998
+ 30: 1.0
+ }
+ sr_pitch {
+ 2.5: 0.7332
+ 5: 0.9504
+ 10: 0.996
+ 15: 0.9992
+ 30: 1.0
+ }
+ sr_roll {
+ 2.5: 0.932
+ 5: 0.9876
+ 10: 0.9996
+ 15: 1.0
+ 30: 1.0
+ }
+ }
+ }
diff --git a/SPIGA/spiga/eval/results/wflw/results_wflw_test.json b/SPIGA/spiga/eval/results/wflw/results_wflw_test.json
new file mode 100644
index 0000000000000000000000000000000000000000..720fbd0403efe72e152181962cf2e35eb80fc1b0
--- /dev/null
+++ b/SPIGA/spiga/eval/results/wflw/results_wflw_test.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:943cb588521b73c0c8ddab4cf297a6b7653006a93568808c63389611a731f74c
+size 13109839
diff --git a/SPIGA/spiga/eval/results_gen.py b/SPIGA/spiga/eval/results_gen.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c20995eaf1d174f729fae168cbc509c1732e2ff
--- /dev/null
+++ b/SPIGA/spiga/eval/results_gen.py
@@ -0,0 +1,87 @@
+import pkg_resources
+import json
+import copy
+import torch
+
+import spiga.data.loaders.dl_config as dl_cfg
+import spiga.data.loaders.dataloader as dl
+import spiga.inference.pretreatment as pretreat
+from spiga.inference.framework import SPIGAFramework
+from spiga.inference.config import ModelConfig
+
+
+def main():
+    """Parse the command line, build the SPIGA framework and generate its results.
+
+    The positional ``database`` argument selects the model configuration and the
+    data to evaluate, ``--anns`` the annotation split and ``--gpus`` the GPU id
+    handed to the framework.
+    """
+    import argparse
+
+    parser = argparse.ArgumentParser(description='Experiment results generator')
+    parser.add_argument(
+        'database',
+        type=str,
+        choices=['wflw', '300wpublic', '300wprivate', "merlrav", "cofw68"],
+        help='Database name',
+    )
+    parser.add_argument('-a', '--anns', type=str, default='test',
+                        help='Annotations type: test, valid or train')
+    parser.add_argument('--gpus', type=int, default=0, help='GPU Id')
+    cli = parser.parse_args()
+
+    # Model framework configured for the requested database; the single GPU id
+    # is wrapped in a list because the framework takes a list of ids.
+    framework = SPIGAFramework(ModelConfig(cli.database), gpus=[cli.gpus])
+
+    # Results are produced with autograd disabled.
+    evaluator = Tester(framework, cli.database, anns_type=cli.anns)
+    with torch.no_grad():
+        evaluator.generate_results()
+
+
+class Tester:
+    """Run a SPIGA framework over one dataset split and save the predictions as JSON."""
+
+    def __init__(self, model_framework, database, anns_type='test'):
+        """
+        @param model_framework: Framework object providing model_cfg, select_inputs(),
+                                net_forward() and postreatment().
+        @param database: Database name forwarded to the dataloader configuration.
+        @param anns_type: Annotation split, passed to the dataloader configuration as 'mode'.
+        """
+        # Parameters
+        self.anns_type = anns_type
+        self.database = database
+
+        # Model initialization
+        self.model_framework = model_framework
+
+        # Dataloader: empty augmentation list, no shuffling, and the crop geometry
+        # copied from the model configuration.
+        eval_cfg = dl_cfg.AlignConfig(self.database, mode=self.anns_type)
+        self.dl_eval = eval_cfg
+        eval_cfg.aug_names = []
+        eval_cfg.shuffle = False
+        net_cfg = self.model_framework.model_cfg
+        eval_cfg.target_dist = net_cfg.target_dist
+        eval_cfg.image_size = net_cfg.image_size
+        eval_cfg.ftmap_size = net_cfg.ftmap_size
+
+        # generate_results() reads element 0 of every batch field, so batches hold one sample.
+        self.batch_size = 1
+        self.test_data, _ = dl.get_dataloader(
+            self.batch_size,
+            self.dl_eval,
+            pretreat=pretreat.NormalizeAndPermute(),
+            debug=True,
+        )
+
+        # Results: template of one output record and the destination file.
+        self.data_struc = {
+            'imgpath': str,
+            'bbox': None,
+            'headpose': None,
+            'ids': None,
+            'landmarks': None,
+            'visible': None,
+        }
+        self.result_path = pkg_resources.resource_filename('spiga', 'eval/results')
+        self.result_file = '/results_%s_%s.json' % (self.database, self.anns_type)
+        self.file_out = self.result_path + self.result_file
+
+    def generate_results(self):
+        """Run the model on every batch and write all records to ``self.file_out``."""
+        framework = self.model_framework
+        records = []
+        for step, batch in enumerate(self.test_data):
+            print('Step: ', step)
+            net_inputs = framework.select_inputs(batch)
+            raw_outputs = framework.net_forward(net_inputs)
+            # Postprocessing
+            outputs = framework.postreatment(raw_outputs, batch['bbox'], batch['bbox_raw'])
+
+            # Fill a fresh copy of the record template with this sample's values.
+            record = copy.deepcopy(self.data_struc)
+            record.update({
+                'imgpath': batch['imgpath_local'][0],
+                'bbox': batch['bbox_raw'][0].numpy().tolist(),
+                'visible': batch['visible'][0].numpy().tolist(),
+                'ids': self.dl_eval.database.ldm_ids,
+                'landmarks': outputs['landmarks'][0],
+                'headpose': outputs['headpose'][0],
+            })
+            records.append(record)
+
+        # Save outputs
+        with open(self.file_out, 'w') as outfile:
+            json.dump(records, outfile)
+
+
+# Script entry point: only run the results generator when executed directly.
+if __name__ == '__main__':
+ main()
diff --git a/SPIGA/spiga/inference/__init__.py b/SPIGA/spiga/inference/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/inference/config.py b/SPIGA/spiga/inference/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6b881852fa1feed3ffe14aab9eb23cf89b750ed
--- /dev/null
+++ b/SPIGA/spiga/inference/config.py
@@ -0,0 +1,58 @@
+from collections import OrderedDict
+
+from spiga.data.loaders.dl_config import DatabaseStruct
+
+# Download location of the pretrained weights, keyed by dataset name.
+MODELS_URL = {
+    'wflw': 'https://drive.google.com/uc?export=download&confirm=yes&id=1h0qA5ysKorpeDNRXe9oYkVcVe8UYyzP7',
+    '300wpublic': 'https://drive.google.com/uc?export=download&confirm=yes&id=1YrbScfMzrAAWMJQYgxdLZ9l57nmTdpQC',
+    '300wprivate': 'https://drive.google.com/uc?export=download&confirm=yes&id=1fYv-Ie7n14eTD0ROxJYcn6SXZY5QU9SM',
+    'merlrav': 'https://drive.google.com/uc?export=download&confirm=yes&id=1GKS1x0tpsTVivPZUk_yrSiMhwEAcAkg6',
+}
+# cofw68 has no weights of its own: ModelConfig maps it to 'spiga_300wprivate.pt',
+# so it shares the 300wprivate download.
+MODELS_URL['cofw68'] = MODELS_URL['300wprivate']
+
+
+class ModelConfig(object):
+    """Inference settings: weight file and URL, preprocessing values and output sizes."""
+
+    def __init__(self, dataset_name=None, load_model_url=True):
+        """
+        @param dataset_name: Optional dataset name; when given, the dataset-dependent
+                             options are filled in by update_with_dataset().
+        @param load_model_url: When True, update_with_dataset() also looks up the
+                               weights URL in MODELS_URL.
+        """
+        # Model configuration
+        self.model_weights = None
+        self.model_weights_path = None
+        self.load_model_url = load_model_url
+        self.model_weights_url = None
+        # Pretreatment
+        self.focal_ratio = 1.5  # Camera matrix focal length ratio.
+        self.target_dist = 1.6  # Target distance zoom in/out around face.
+        self.image_size = (256, 256)
+        # Outputs
+        self.ftmap_size = (64, 64)
+        # Dataset
+        self.dataset = None
+
+        if dataset_name is None:
+            return
+        self.update_with_dataset(dataset_name)
+
+    def update_with_dataset(self, dataset_name):
+        """Set the dataset structure, weights file name and (optionally) weights URL."""
+        options = {'dataset': DatabaseStruct(dataset_name)}
+
+        # cofw68 is test only and reuses the 300wprivate weights file.
+        if dataset_name == 'cofw68':
+            options['model_weights'] = 'spiga_300wprivate.pt'
+        else:
+            options['model_weights'] = 'spiga_%s.pt' % dataset_name
+
+        if self.load_model_url:
+            options['model_weights_url'] = MODELS_URL[dataset_name]
+
+        self.update(options)
+
+    def update(self, params_dict):
+        """Assign every item of ``params_dict`` as an attribute.
+
+        A key that is neither a public instance attribute nor any existing
+        attribute of the object raises ``Warning``; items processed before the
+        offending key have already been applied.
+        """
+        known = self.state_dict()
+        for name, value in params_dict.items():
+            if name not in known and not hasattr(self, name):
+                raise Warning('Unknown option: {}: {}'.format(name, value))
+            setattr(self, name, value)
+
+    def state_dict(self):
+        """Return an OrderedDict of the instance attributes not starting with '_'."""
+        return OrderedDict(
+            (name, getattr(self, name))
+            for name in self.__dict__.keys()
+            if not name.startswith('_')
+        )
diff --git a/SPIGA/spiga/inference/framework.py b/SPIGA/spiga/inference/framework.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0f1f992539506b5743cf57ea30c5c8efdb6263f
--- /dev/null
+++ b/SPIGA/spiga/inference/framework.py
@@ -0,0 +1,145 @@
+from spiga.inference.config import ModelConfig
+from spiga.models.spiga import SPIGA
+import spiga.inference.pretreatment as pretreat
+import os
+import pkg_resources
+import copy
+import torch
+import numpy as np
+
+# Paths
+# Default weights directory: the 'models/weights' resource of the 'spiga' package.
+# SPIGAFramework falls back to it when model_cfg.model_weights_path is None.
+weights_path_dft = pkg_resources.resource_filename('spiga', 'models/weights')
+
+
+class SPIGAFramework:
+    """Inference wrapper around a SPIGA network.
+
+    It crops the faces given by bounding boxes, runs the model and maps the
+    predicted landmarks back to the coordinates of the original boxes.
+    """
+
+    def __init__(self, model_cfg: ModelConfig, gpus=None, load3DM=True):
+        """
+        @param model_cfg: ModelConfig with the dataset, weights and image-size settings.
+        @param gpus: List of GPU ids; only the first id is used. Defaults to [0].
+        @param load3DM: When True, preload the 3D model and camera matrix that
+                        pretreat() needs to build the network inputs.
+        """
+        # Parameters
+        self.model_cfg = model_cfg
+        # None sentinel instead of a mutable default list shared between instances.
+        self.gpus = [0] if gpus is None else gpus
+
+        # Pretreatment initialization
+        self.transforms = pretreat.get_transformers(self.model_cfg)
+
+        # SPIGA model
+        self.model_inputs = ['image', "model3d", "cam_matrix"]
+        self.model = SPIGA(num_landmarks=model_cfg.dataset.num_landmarks,
+                           num_edges=model_cfg.dataset.num_edges)
+
+        # Load weights and set model
+        weights_path = self.model_cfg.model_weights_path
+        if weights_path is None:
+            weights_path = weights_path_dft
+
+        # The rest of this class falls back to the CPU when CUDA is missing, so
+        # the checkpoint must be remapped to the CPU in that case as well.
+        map_location = None if torch.cuda.is_available() else 'cpu'
+        if self.model_cfg.load_model_url:
+            model_state_dict = torch.hub.load_state_dict_from_url(self.model_cfg.model_weights_url,
+                                                                  model_dir=weights_path,
+                                                                  map_location=map_location,
+                                                                  file_name=self.model_cfg.model_weights)
+        else:
+            weights_file = os.path.join(weights_path, self.model_cfg.model_weights)
+            model_state_dict = torch.load(weights_file, map_location=map_location)
+
+        self.model.load_state_dict(model_state_dict)
+        if torch.cuda.is_available():
+            self.model = self.model.cuda(self.gpus[0])
+        self.model.eval()
+        print('SPIGA model loaded!')
+
+        # Load 3D model and camera intrinsic matrix
+        if load3DM:
+            loader_3DM = pretreat.AddModel3D(model_cfg.dataset.ldm_ids,
+                                             ftmap_size=model_cfg.ftmap_size,
+                                             focal_ratio=model_cfg.focal_ratio,
+                                             totensor=True)
+            params_3DM = self._data2device(loader_3DM())
+            self.model3d = params_3DM['model3d']
+            self.cam_matrix = params_3DM['cam_matrix']
+
+    def inference(self, image, bboxes):
+        """
+        @param image: Raw image
+        @param bboxes: List of bounding boxes found on the image [[x,y,w,h],...]
+        @return: features dict {'landmarks': list with shape (num_bbox, num_landmarks, 2) and x,y referred to image size
+                                'headpose': list with shape (num_bbox, 6) euler->[:3], trl->[3:]}
+        """
+        batch_crops, crop_bboxes = self.pretreat(image, bboxes)
+        outputs = self.net_forward(batch_crops)
+        features = self.postreatment(outputs, crop_bboxes, bboxes)
+        return features
+
+    def pretreat(self, image, bboxes):
+        """Crop one sample per bounding box and build the batched network inputs.
+
+        Requires the instance to have been created with load3DM=True, because the
+        stored 3D model and camera matrix are repeated once per bounding box.
+
+        @param image: Raw image
+        @param bboxes: List of bounding boxes
+        @return: ([batch_images, batch_model3D, batch_cam_matrix], crop_bboxes)
+        """
+        crop_bboxes = []
+        crop_images = []
+        for bbox in bboxes:
+            # Deep copies keep the caller's image and bbox untouched by the transforms.
+            sample = {'image': copy.deepcopy(image),
+                      'bbox': copy.deepcopy(bbox)}
+            sample_crop = self.transforms(sample)
+            crop_bboxes.append(sample_crop['bbox'])
+            crop_images.append(sample_crop['image'])
+
+        # Images to tensor and device
+        batch_images = torch.tensor(np.array(crop_images), dtype=torch.float)
+        batch_images = self._data2device(batch_images)
+        # Batch 3D model and camera intrinsic matrix
+        batch_model3D = self.model3d.unsqueeze(0).repeat(len(bboxes), 1, 1)
+        batch_cam_matrix = self.cam_matrix.unsqueeze(0).repeat(len(bboxes), 1, 1)
+
+        # SPIGA inputs
+        model_inputs = [batch_images, batch_model3D, batch_cam_matrix]
+        return model_inputs, crop_bboxes
+
+    def net_forward(self, inputs):
+        """Run the network on the prepared inputs and return its raw outputs."""
+        outputs = self.model(inputs)
+        return outputs
+
+    def postreatment(self, output, crop_bboxes, bboxes):
+        """Convert raw network outputs into plain Python lists.
+
+        @param output: Network output dict; the 'Landmarks' and 'Pose' entries are used when present.
+        @param crop_bboxes: Bounding boxes after cropping, one [x,y,w,h] row per sample.
+        @param bboxes: Original bounding boxes, one [x,y,w,h] row per sample.
+        @return: features dict with the optional keys 'landmarks' and 'headpose'.
+        """
+        features = {}
+        crop_bboxes = np.array(crop_bboxes)
+        bboxes = np.array(bboxes)
+
+        if 'Landmarks' in output:
+            # Only the last entry of the 'Landmarks' sequence is used.
+            landmarks = output['Landmarks'][-1].cpu().detach().numpy()
+            # Landmark axis first so the per-sample bbox rows broadcast over it.
+            landmarks = landmarks.transpose((1, 0, 2))
+            landmarks = landmarks * self.model_cfg.image_size
+            # Normalise by the crop box, then rescale into the original box.
+            landmarks_norm = (landmarks - crop_bboxes[:, 0:2]) / crop_bboxes[:, 2:4]
+            landmarks_out = (landmarks_norm * bboxes[:, 2:4]) + bboxes[:, 0:2]
+            landmarks_out = landmarks_out.transpose((1, 0, 2))
+            features['landmarks'] = landmarks_out.tolist()
+
+        # Pose output
+        if 'Pose' in output:
+            pose = output['Pose'].cpu().detach().numpy()
+            features['headpose'] = pose.tolist()
+
+        return features
+
+    def select_inputs(self, batch):
+        """Pick the entries named in self.model_inputs from a batch, as float data on the device."""
+        inputs = []
+        for ft_name in self.model_inputs:
+            data = batch[ft_name]
+            inputs.append(self._data2device(data.type(torch.float)))
+        return inputs
+
+    def _data2device(self, data):
+        """Move data to the first configured GPU when CUDA is available.
+
+        Lists and dicts are processed recursively and updated in place; any
+        other object is returned unchanged without CUDA, or moved with its
+        .cuda() method otherwise.
+        """
+        # The branches are exclusive: a container must never reach the .cuda()
+        # call below, which only the contained items support.
+        if isinstance(data, list):
+            for data_id, v_data in enumerate(data):
+                data[data_id] = self._data2device(v_data)
+            return data
+        if isinstance(data, dict):
+            for k, v in data.items():
+                data[k] = self._data2device(v)
+            return data
+        if not torch.cuda.is_available():
+            return data
+        with torch.no_grad():
+            return data.cuda(device=self.gpus[0], non_blocking=True)
diff --git a/SPIGA/spiga/inference/pretreatment.py b/SPIGA/spiga/inference/pretreatment.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2099056356301d2f9300b344f6273f4c75b5b31
--- /dev/null
+++ b/SPIGA/spiga/inference/pretreatment.py
@@ -0,0 +1,31 @@
+from torchvision import transforms
+import numpy as np
+from PIL import Image
+import cv2
+
+from spiga.data.loaders.transforms import TargetCrop, ToOpencv, AddModel3D
+
+
+def get_transformers(data_config):
+    """Compose the inference pre-processing pipeline.
+
+    Steps, in order: ``Opencv2Pil`` -> ``TargetCrop`` -> ``ToOpencv`` ->
+    ``NormalizeAndPermute``.
+
+    :param data_config: object exposing ``image_size`` and ``target_dist``,
+        both forwarded to ``TargetCrop``.
+    :return: a ``transforms.Compose`` wrapping the four steps.
+    """
+    pipeline = [Opencv2Pil()]
+    pipeline.append(TargetCrop(data_config.image_size, data_config.target_dist))
+    pipeline += [ToOpencv(), NormalizeAndPermute()]
+    return transforms.Compose(pipeline)
+
+
+class NormalizeAndPermute:
+    """Sample transform: cast the image to float, move axis 2 first, divide by 255."""
+
+    def __call__(self, sample):
+        """Replace ``sample['image']`` in place and return the same ``sample``."""
+        # (d0, d1, d2) -> (d2, d0, d1); presumably HWC -> CHW.
+        pixels = np.array(sample['image'], dtype=float).transpose((2, 0, 1))
+        sample['image'] = pixels / 255
+        return sample
+
+
+class Opencv2Pil:
+    """Sample transform: convert a BGR array image into a PIL image."""
+
+    def __call__(self, sample):
+        """Replace ``sample['image']`` in place and return the same ``sample``."""
+        rgb = cv2.cvtColor(sample['image'], cv2.COLOR_BGR2RGB)
+        sample['image'] = Image.fromarray(rgb)
+        return sample
+
diff --git a/SPIGA/spiga/models/__init__.py b/SPIGA/spiga/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/models/cnn/__init__.py b/SPIGA/spiga/models/cnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/models/cnn/cnn_multitask.py b/SPIGA/spiga/models/cnn/cnn_multitask.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55694b513231d57fe2456b34cf2b65d82c7140e
--- /dev/null
+++ b/SPIGA/spiga/models/cnn/cnn_multitask.py
@@ -0,0 +1,94 @@
+from torch import nn
+from spiga.models.cnn.layers import Conv, Residual
+from spiga.models.cnn.hourglass import HourglassCore
+from spiga.models.cnn.coord_conv import AddCoordsTh
+from spiga.models.cnn.transform_e2p import E2Ptransform
+
+
+class MultitaskCNN(nn.Module):
+    """Stacked-hourglass CNN returning per-stack visual features and pose cores.
+
+    ``forward`` returns a dict with two lists, one entry per hourglass stack:
+    ``'VisualField'`` (output of ``hgs_out``) and ``'HGcore'`` (output of
+    ``hgs_core``; left empty when ``pose_req`` is false).
+    """
+
+    def __init__(self, nstack=4, num_landmarks=98, num_edges=15, pose_req=True, **kwargs):
+        super().__init__()
+
+        # Parameters
+        self.img_res = 256                  # WxH input resolution
+        self.ch_dim = 256                   # Default channel dimension
+        self.out_res = 64                   # WxH output resolution
+        self.nstack = nstack                # Hourglass modules stacked
+        self.num_landmarks = num_landmarks  # Number of landmarks
+        self.num_edges = num_edges          # Number of edges subsets (eyeR, eyeL, nose, etc)
+        self.pose_required = pose_req       # Multitask flag
+
+        ch = self.ch_dim
+        n_links = self.nstack - 1           # Connections between consecutive stacks
+
+        # Image preprocessing (AddCoordsTh with_r=True appends three channels,
+        # hence the 6 input channels of the first convolution)
+        self.pre = nn.Sequential(
+            AddCoordsTh(x_dim=self.img_res, y_dim=self.img_res, with_r=True),
+            Conv(6, 64, 7, 2, bn=True, relu=True),
+            Residual(64, 128),
+            Conv(128, 128, 2, 2, bn=True, relu=True),
+            Residual(128, 128),
+            Residual(128, ch),
+        )
+
+        # Hourglass modules
+        self.hgs = nn.ModuleList([HourglassCore(4, ch) for _ in range(self.nstack)])
+        self.hgs_out = nn.ModuleList([
+            nn.Sequential(Residual(ch, ch), Conv(ch, ch, 1, bn=True, relu=True))
+            for _ in range(nstack)])
+        if self.pose_required:
+            self.hgs_core = nn.ModuleList([
+                nn.Sequential(
+                    Residual(ch, ch),
+                    Conv(ch, ch, 2, 2, bn=True, relu=True),
+                    Residual(ch, ch),
+                    Conv(ch, ch, 2, 2, bn=True, relu=True),
+                ) for _ in range(nstack)])
+
+        # Attention module (ADnet style)
+        def sigmoid_head(out_ch):
+            return nn.Sequential(Conv(ch, out_ch, 1, relu=False, bn=False), nn.Sigmoid())
+
+        self.outs_points = nn.ModuleList([sigmoid_head(self.num_landmarks) for _ in range(n_links)])
+        self.outs_edges = nn.ModuleList([sigmoid_head(self.num_edges) for _ in range(n_links)])
+        self.E2Ptransform = E2Ptransform(self.num_landmarks, self.num_edges, out_dim=self.out_res)
+
+        self.outs_features = nn.ModuleList(
+            [Conv(ch, self.num_landmarks, 1, relu=False, bn=False) for _ in range(n_links)])
+
+        # Stacked Hourglass inputs (nstack > 1)
+        self.merge_preds = nn.ModuleList(
+            [Conv(self.num_landmarks, ch, 1, relu=False, bn=False) for _ in range(n_links)])
+        self.merge_features = nn.ModuleList(
+            [Conv(ch, ch, 1, relu=False, bn=False) for _ in range(n_links)])
+
+    def forward(self, imgs):
+        """Run every hourglass stack and collect its feature maps.
+
+        :param imgs: image batch fed to ``self.pre``.
+        :return: dict with the lists ``'VisualField'`` and ``'HGcore'``.
+        """
+        x = self.pre(imgs)
+        visual_field, hg_core = [], []
+        outputs = {'VisualField': visual_field,
+                   'HGcore': hg_core}
+
+        last_stack = self.nstack - 1
+        # The same list is threaded through all stacks, so it keeps growing;
+        # the negative index below always addresses the current stack's entries.
+        core_raw = []
+        for i in range(self.nstack):
+            # Hourglass
+            hourglass = self.hgs[i]
+            hg, core_raw = hourglass(x, core=core_raw)
+            if self.pose_required:
+                hg_core.append(self.hgs_core[i](core_raw[-hourglass.n]))
+            hg = self.hgs_out[i](hg)
+
+            # Visual features
+            visual_field.append(hg)
+
+            # The last stack has no successor to prepare an input for
+            if i == last_stack:
+                continue
+
+            # Attentional modules
+            points = self.outs_points[i](hg)
+            edges_ext = self.E2Ptransform(self.outs_edges[i](hg))
+            point_edges = points * edges_ext
+
+            # Landmarks
+            preds = self.outs_features[i](hg) * point_edges
+
+            # Next stacked input
+            x = x + self.merge_preds[i](preds) + self.merge_features[i](hg)
+
+        return outputs
diff --git a/SPIGA/spiga/models/cnn/coord_conv.py b/SPIGA/spiga/models/cnn/coord_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..45e970ffd3a5b9b0d94870d4193afe64a1222b35
--- /dev/null
+++ b/SPIGA/spiga/models/cnn/coord_conv.py
@@ -0,0 +1,58 @@
+import torch
+import torch.nn as nn
+
+
+class AddCoordsTh(nn.Module):
+    """Append normalised coordinate channels (and optionally a radius channel).
+
+    The two coordinate grids have shape ``(1, 1, x_dim, y_dim)`` with values
+    in ``[-1, 1]`` and are stored as frozen parameters.
+    """
+
+    def __init__(self, x_dim=64, y_dim=64, with_r=False):
+        super().__init__()
+        self.x_dim = x_dim
+        self.y_dim = y_dim
+        self.with_r = with_r
+
+        xx_channel, yy_channel = self._prepare_coords()
+        self.xx_channel = nn.parameter.Parameter(xx_channel, requires_grad=False)
+        self.yy_channel = nn.parameter.Parameter(yy_channel, requires_grad=False)
+
+    def _prepare_coords(self):
+        """Build the x grid (varies along dim 2) and y grid (varies along dim 3)."""
+        x_idx = torch.arange(self.x_dim, dtype=torch.int32).view(1, 1, self.x_dim, 1)
+        y_idx = torch.arange(self.y_dim, dtype=torch.int32).view(1, 1, 1, self.y_dim)
+
+        # Index -> [0, 1] -> [-1, 1]
+        xx_channel = x_idx.repeat(1, 1, 1, self.y_dim).float() / (self.x_dim - 1)
+        yy_channel = y_idx.repeat(1, 1, self.x_dim, 1).float() / (self.y_dim - 1)
+        return xx_channel * 2 - 1, yy_channel * 2 - 1
+
+    def forward(self, input_tensor):
+        """
+        input_tensor: (batch, c, x_dim, y_dim)
+        """
+        batch = input_tensor.shape[0]
+        xx = self.xx_channel.repeat(batch, 1, 1, 1)
+        yy = self.yy_channel.repeat(batch, 1, 1, 1)
+
+        channels = [input_tensor, xx, yy]
+        if self.with_r:
+            # Distance of each cell to the point (0.5, 0.5) in grid coordinates
+            channels.append(torch.sqrt(torch.pow(xx - 0.5, 2) + torch.pow(yy - 0.5, 2)))
+        return torch.cat(channels, dim=1)
diff --git a/SPIGA/spiga/models/cnn/hourglass.py b/SPIGA/spiga/models/cnn/hourglass.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3fb01c410c8df97c1ab0e2ab82b0b0ba04dcc91
--- /dev/null
+++ b/SPIGA/spiga/models/cnn/hourglass.py
@@ -0,0 +1,53 @@
+import torch.nn as nn
+
+from spiga.models.cnn.layers import Conv, Deconv, Residual
+
+
+class Hourglass(nn.Module):
+    """Recursive hourglass module with a skip branch and a down/up branch.
+
+    :param n: recursion depth; the innermost level (``n == 1``) uses a plain
+        ``Residual`` instead of a nested hourglass.
+    :param f: number of input/output channels.
+    :param bn: only forwarded to the nested hourglass; not used otherwise.
+    :param increase: extra channels used inside the lower branch.
+    """
+
+    def __init__(self, n, f, bn=None, increase=0):
+        super().__init__()
+        inner = f + increase
+        self.up1 = Residual(f, f)
+        # Lower branch
+        self.pool1 = Conv(f, f, 2, 2, bn=True, relu=True)
+        self.low1 = Residual(f, inner)
+        self.n = n
+        # Recursive hourglass
+        self.low2 = Hourglass(n - 1, inner, bn=bn) if self.n > 1 else Residual(inner, inner)
+        self.low3 = Residual(inner, f)
+        self.up2 = Deconv(f, f, 2, 2, bn=True, relu=True)
+
+    def forward(self, x):
+        """Return the sum of the skip branch and the down/up branch."""
+        skip = self.up1(x)
+        lower = self.low1(self.pool1(x))
+        lower = self.low3(self.low2(lower))
+        return skip + self.up2(lower)
+
+
+class HourglassCore(Hourglass):
+    """Hourglass that also collects the feature maps of its lower branch.
+
+    ``forward`` appends the innermost ``low2`` output and, on every outer
+    level, the ``low3`` output to a list that is returned with the result.
+    """
+
+    def __init__(self, n, f, bn=None, increase=0):
+        super(HourglassCore, self).__init__(n, f, bn=bn, increase=increase)
+        nf = f + increase
+        if self.n > 1:
+            # Replace the plain Hourglass built by the base class so that the
+            # nested levels report their core features too.
+            self.low2 = HourglassCore(n - 1, nf, bn=bn)
+
+    def forward(self, x, core=None):
+        """Run the hourglass and record its core features.
+
+        :param x: input feature map.
+        :param core: list to append the collected feature maps to. A fresh
+            list is created when omitted, so separate calls never share state.
+        :return: tuple ``(output, core)``.
+        """
+        if core is None:
+            core = []
+        up1 = self.up1(x)
+        pool1 = self.pool1(x)
+        low1 = self.low1(pool1)
+        if self.n > 1:
+            low2, core = self.low2(low1, core=core)
+        else:
+            low2 = self.low2(low1)
+            core.append(low2)
+        low3 = self.low3(low2)
+        if self.n > 1:
+            core.append(low3)
+        up2 = self.up2(low3)
+        return up1 + up2, core
diff --git a/SPIGA/spiga/models/cnn/layers.py b/SPIGA/spiga/models/cnn/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..36bcf16c7812a6824b115272ef4a8712015709c4
--- /dev/null
+++ b/SPIGA/spiga/models/cnn/layers.py
@@ -0,0 +1,79 @@
+from torch import nn
+
+
+class Conv(nn.Module):
+    """Bias-free 2D convolution with optional batch norm and ReLU.
+
+    Padding is ``(kernel_size - 1) // 2``. ``forward`` applies the
+    convolution, then batch norm (if enabled), then ReLU (if enabled).
+    """
+
+    def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True):
+        super().__init__()
+        self.inp_dim = inp_dim
+        pad = (kernel_size - 1) // 2
+        self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=pad, bias=False)
+        # Disabled stages stay ``None`` so ``forward`` can skip them.
+        self.relu = nn.ReLU() if relu else None
+        self.bn = nn.BatchNorm2d(out_dim) if bn else None
+
+    def forward(self, x):
+        """Apply conv -> (bn) -> (relu); the channel count must equal ``inp_dim``."""
+        assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim)
+        out = self.conv(x)
+        for stage in (self.bn, self.relu):
+            if stage is not None:
+                out = stage(out)
+        return out
+
+
+class Deconv(nn.Module):
+    """Bias-free 2D transposed convolution with optional batch norm and ReLU.
+
+    ``forward`` applies the transposed convolution, then batch norm (if
+    enabled), then ReLU (if enabled).
+    """
+
+    def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True):
+        super().__init__()
+        self.inp_dim = inp_dim
+        self.deconv = nn.ConvTranspose2d(inp_dim, out_dim, kernel_size=kernel_size,
+                                         stride=stride, bias=False)
+        # Disabled stages stay ``None`` so ``forward`` can skip them.
+        self.relu = nn.ReLU() if relu else None
+        self.bn = nn.BatchNorm2d(out_dim) if bn else None
+
+    def forward(self, x):
+        """Apply deconv -> (bn) -> (relu); the channel count must equal ``inp_dim``."""
+        assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim)
+        out = self.deconv(x)
+        for stage in (self.bn, self.relu):
+            if stage is not None:
+                out = stage(out)
+        return out
+
+
+class Residual(nn.Module):
+    """Residual unit made of three BN -> ReLU -> conv stages.
+
+    The middle stages work on ``out_dim / 2`` channels. The skip path is the
+    identity when ``inp_dim == out_dim`` and a 1x1 ``Conv`` otherwise.
+    """
+
+    def __init__(self, inp_dim, out_dim, kernel=3):
+        super().__init__()
+        mid_dim = int(out_dim / 2)
+        self.relu = nn.ReLU()
+        self.bn1 = nn.BatchNorm2d(inp_dim)
+        self.conv1 = Conv(inp_dim, mid_dim, 1, relu=False)
+        self.bn2 = nn.BatchNorm2d(mid_dim)
+        self.conv2 = Conv(mid_dim, mid_dim, kernel, relu=False)
+        self.bn3 = nn.BatchNorm2d(mid_dim)
+        self.conv3 = Conv(mid_dim, out_dim, 1, relu=False)
+        # Always created (even when unused) so the set of parameters does not
+        # depend on the channel configuration.
+        self.skip_layer = Conv(inp_dim, out_dim, 1, relu=False)
+        self.need_skip = inp_dim != out_dim
+
+    def forward(self, x):
+        """Return ``conv_stack(x) + skip(x)``."""
+        residual = self.skip_layer(x) if self.need_skip else x
+        out = self.conv1(self.relu(self.bn1(x)))
+        out = self.conv2(self.relu(self.bn2(out)))
+        out = self.conv3(self.relu(self.bn3(out)))
+        out += residual
+        return out
diff --git a/SPIGA/spiga/models/cnn/transform_e2p.py b/SPIGA/spiga/models/cnn/transform_e2p.py
new file mode 100644
index 0000000000000000000000000000000000000000..14c42534bf5f608bd34f576c5ac1f6e8c0eb6167
--- /dev/null
+++ b/SPIGA/spiga/models/cnn/transform_e2p.py
@@ -0,0 +1,257 @@
+import torch
+from torch import nn
+
+
+class E2Ptransform(nn.Module):
+    """Edge to Points transformation"""
+
+    def __init__(self, points, edges, out_dim=64):
+        super().__init__()
+        # Constant all-ones map appended to the edge maps in ``forward``
+        self.ones = nn.parameter.Parameter(torch.ones((1, out_dim, out_dim)), requires_grad=False)
+        # Npoint X Nedges+1
+        self.edge2point = nn.parameter.Parameter(self._select_matrix(points, edges),
+                                                 requires_grad=False)
+
+    def forward(self, edges):
+        """Turn per-edge maps into per-point maps, capped at 1.
+
+        :param edges: tensor of shape (B, L, H, W).
+        :return: tensor of shape (B, Npoint, H, W).
+        """
+        B, L, H, W = edges.shape
+        with_ones = torch.cat((edges, self.ones.repeat(B, 1, 1, 1)), 1)
+        # Channels last, as a row vector per pixel: (B, H, W, 1, L+1)
+        per_pixel = with_ones.permute(0, 2, 3, 1).reshape(B, H, W, 1, L+1)
+        mixed = torch.matmul(per_pixel, self.edge2point.transpose(-1, -2))
+        point_edges = mixed.reshape(B, H, W, -1).permute(0, 3, 1, 2)
+        return torch.clamp(point_edges, max=1.)
+
+    def _select_matrix(self, points, edges):
+        """Return the predefined matrix for a (points, edges) pair.
+
+        :raises ValueError: if the combination has no predefined matrix.
+        """
+        known = {
+            (98, 15): WFLW_98x15,
+            (68, 13): W300_68x13,
+            (29, 13): COFW_29x13,
+            (19, 6): AFLW19_19x6,
+        }
+        matrix = known.get((points, edges))
+        if matrix is None:
+            raise ValueError("E2P matrix not implemented")
+        return matrix
+
+
+# Edge-to-point (E2P) matrices, one per supported landmark database
+# Edge-to-point matrix returned by E2Ptransform._select_matrix for
+# points == 98 and edges == 15. One row per point, one column per edge map
+# plus a final column that lines up with the all-ones map E2Ptransform.forward
+# appends to the edge maps.
+WFLW_98x15 = torch.Tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]])
+
+
+# Edge-to-point matrix returned by E2Ptransform._select_matrix for
+# points == 68 and edges == 13. One row per point, one column per edge map
+# plus a final column that lines up with the all-ones map E2Ptransform.forward
+# appends to the edge maps (that final column is all zeros in this table).
+W300_68x13 = torch.Tensor([ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]])
+
+
+# Edge-to-point matrix returned by E2Ptransform._select_matrix for
+# points == 19 and edges == 6. One row per point, one column per edge map
+# plus a final column that lines up with the all-ones map E2Ptransform.forward
+# appends to the edge maps.
+AFLW19_19x6 = torch.Tensor([[1, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0, 0, 0],
+ [0, 0, 0, 1, 0, 0, 0],
+ [0, 0, 0, 1, 0, 0, 0],
+ [0, 0, 0, 1, 0, 0, 0],
+ [0, 0, 0, 0, 1, 0, 0],
+ [0, 0, 0, 0, 1, 0, 0],
+ [0, 0, 0, 0, 1, 0, 0],
+ [0, 0, 0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 0, 0, 1]])
+
+
+# Edge-to-point matrix returned by E2Ptransform._select_matrix for
+# points == 29 and edges == 13. One row per point, one column per edge map
+# plus a final column that lines up with the all-ones map E2Ptransform.forward
+# appends to the edge maps.
+COFW_29x13 = torch.Tensor([ [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]])
diff --git a/SPIGA/spiga/models/gnn/__init__.py b/SPIGA/spiga/models/gnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SPIGA/spiga/models/gnn/gat.py b/SPIGA/spiga/models/gnn/gat.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7816f22a4b3afd77d3f2d3e69bc65e45b026a14
--- /dev/null
+++ b/SPIGA/spiga/models/gnn/gat.py
@@ -0,0 +1,62 @@
+from copy import deepcopy
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from spiga.models.gnn.layers import MLP
+
+
+class GAT(nn.Module):
+    """Attention block: attention message added to a (possibly projected) input.
+
+    When ``input_dim != output_dim`` the input is projected with a one-layer
+    ``MLP`` before the addition, and the head count is lowered to the largest
+    value not above ``num_heads`` that divides ``input_dim``.
+    """
+
+    def __init__(self, input_dim: int, output_dim: int, num_heads=4):
+        super().__init__()
+
+        heads = num_heads
+        self.reshape = None
+        if input_dim != output_dim:
+            candidates = (h for h in range(num_heads, 0, -1) if input_dim % h == 0)
+            heads = next(candidates, num_heads)
+            self.reshape = MLP([input_dim, output_dim])
+
+        self.attention = MessagePassing(input_dim, heads, out_dim=output_dim)
+
+    def forward(self, features):
+        """Return ``(features + message, attention probabilities)``."""
+        # The message is computed from the unprojected features.
+        message, prob = self.attention(features)
+        residual = self.reshape(features) if self.reshape else features
+        return residual + message, prob
+
+
+class MessagePassing(nn.Module):
+    """Self-attention followed by an MLP over ``[features, message]``.
+
+    :param feature_dim: number of input channels.
+    :param num_heads: number of attention heads.
+    :param out_dim: number of output channels; defaults to ``feature_dim``
+        when omitted.
+    """
+
+    def __init__(self, feature_dim: int, num_heads: int, out_dim=None):
+        super().__init__()
+        if out_dim is None:
+            # ``None`` cannot be used as a Conv1d channel count.
+            out_dim = feature_dim
+        self.attn = Attention(num_heads, feature_dim)
+        self.mlp = MLP([feature_dim*2, feature_dim*2, out_dim])
+
+    def forward(self, features):
+        """Return ``(mlp(cat(features, message)), attention probabilities)``."""
+        message, prob = self.attn(features, features, features)
+        return self.mlp(torch.cat([features, message], dim=1)), prob
+
+
+class Attention(nn.Module):
+    """Multi-head dot-product attention using kernel-size-1 ``Conv1d`` projections."""
+
+    def __init__(self, num_heads: int, feature_dim: int):
+        super().__init__()
+        assert feature_dim % num_heads == 0
+        self.dim = feature_dim // num_heads
+        self.num_heads = num_heads
+        self.merge = nn.Conv1d(feature_dim, feature_dim, kernel_size=1)
+        # Query/key/value projections start as exact copies of ``merge``.
+        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])
+
+    def forward(self, query, key, value):
+        """Project the inputs, attend per head and merge the heads again.
+
+        :return: tuple ``(merged output, attention probabilities)``.
+        """
+        batch = query.size(0)
+        per_head = (batch, self.dim, self.num_heads, -1)
+        q = self.proj[0](query).view(*per_head)
+        k = self.proj[1](key).view(*per_head)
+        v = self.proj[2](value).view(*per_head)
+        mixed, prob = self.attention(q, k, v)
+        flat = mixed.contiguous().view(batch, self.dim*self.num_heads, -1)
+        return self.merge(flat), prob
+
+    def attention(self, query, key, value):
+        """Softmax over dot-product scores scaled by the square root of dim 1."""
+        scale = query.shape[1] ** .5
+        scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / scale
+        prob = F.softmax(scores, dim=-1)
+        return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob
diff --git a/SPIGA/spiga/models/gnn/layers.py b/SPIGA/spiga/models/gnn/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..31137f40b51fb73aee26fa5705dd3661b472fe37
--- /dev/null
+++ b/SPIGA/spiga/models/gnn/layers.py
@@ -0,0 +1,12 @@
+from torch import nn
+
+
+def MLP(channels: list):
+    """Build a stack of kernel-size-1 ``Conv1d`` layers.
+
+    Every layer except the last is followed by ``BatchNorm1d`` and ``ReLU``.
+
+    :param channels: channel sizes; layer ``i`` maps ``channels[i]`` to
+        ``channels[i + 1]``.
+    :return: ``nn.Sequential`` with the layers (empty for fewer than two sizes).
+    """
+    last = len(channels) - 1
+    layers = []
+    for idx, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:]), start=1):
+        layers.append(nn.Conv1d(c_in, c_out, kernel_size=1, bias=True))
+        if idx < last:
+            layers += [nn.BatchNorm1d(c_out), nn.ReLU()]
+    return nn.Sequential(*layers)
diff --git a/SPIGA/spiga/models/gnn/pose_proj.py b/SPIGA/spiga/models/gnn/pose_proj.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfa37e9e8bb141671a33d019c206d6612a41333f
--- /dev/null
+++ b/SPIGA/spiga/models/gnn/pose_proj.py
@@ -0,0 +1,77 @@
+import torch
+import math
+
+
+def euler_to_rotation_matrix(euler):
+    """Convert batched Euler angles in degrees to rotation matrices.
+
+    The angles are first remapped to the working coordinate system
+    (``-(a0 - 90)``, ``-a1``, ``-(a2 + 90)``); the result is ``Ry @ Rp @ Rr``.
+    The input tensor is not modified.
+
+    :param euler: 2D tensor whose columns 0, 1 and 2 hold yaw, pitch and roll.
+    :return: tensor of shape ``(euler.shape[0], 3, 3)`` on ``euler``'s device.
+    """
+    # http://euclideanspace.com/maths/geometry/rotations/conversions/eulerToMatrix/index.htm
+    # Change coordinates system on a new tensor; writing into ``euler`` would
+    # silently change the caller's data and make repeated calls disagree.
+    angles = torch.stack((-(euler[:, 0] - 90), -euler[:, 1], -(euler[:, 2] + 90)), dim=1)
+
+    # Convert to radians
+    rad = angles * (math.pi / 180.0)
+    cy, cp, cr = torch.cos(rad).unbind(dim=1)
+    sy, sp, sr = torch.sin(rad).unbind(dim=1)
+
+    # Init R matrix tensors on the same device as the input
+    batch = euler.shape[0]
+    device = euler.device
+    Ry = torch.zeros((batch, 3, 3), device=device)
+    Rp = torch.zeros((batch, 3, 3), device=device)
+    Rr = torch.zeros((batch, 3, 3), device=device)
+
+    # Yaw
+    Ry[:, 0, 0] = cy
+    Ry[:, 0, 2] = sy
+    Ry[:, 1, 1] = 1.
+    Ry[:, 2, 0] = -sy
+    Ry[:, 2, 2] = cy
+
+    # Pitch
+    Rp[:, 0, 0] = cp
+    Rp[:, 0, 1] = -sp
+    Rp[:, 1, 0] = sp
+    Rp[:, 1, 1] = cp
+    Rp[:, 2, 2] = 1.
+
+    # Roll
+    Rr[:, 0, 0] = 1.
+    Rr[:, 1, 1] = cr
+    Rr[:, 1, 2] = -sr
+    Rr[:, 2, 1] = sr
+    Rr[:, 2, 2] = cr
+
+    return torch.matmul(torch.matmul(Ry, Rp), Rr)
+
+
+def projectPoints(pts, rot, trl, cam_matrix):
+    """Project batched 3D points with a perspective camera model.
+
+    :param pts: points, indexed (batch, point, coordinate).
+    :param rot: batched rotation matrices.
+    :param trl: batched translations, one vector per sample.
+    :param cam_matrix: batched matrices multiplied with ``[rot | trl]``.
+    :return: projected points with the homogeneous coordinate dropped.
+    """
+    # Get working device
+    device = pts.device if pts.is_cuda else None
+
+    # Perspective projection model: cam_matrix @ [rot | trl]
+    extrinsics = torch.cat((rot, trl.unsqueeze(2)), 2)
+    proj_matrix = torch.matmul(cam_matrix, extrinsics)
+
+    # Homogeneous landmarks
+    ones = torch.ones(pts.shape[:2], device=device, requires_grad=trl.requires_grad)
+    pts_hom = torch.cat((pts, ones.unsqueeze(2)), 2)
+
+    # Project landmarks (points as columns, then back to rows)
+    projected = torch.matmul(proj_matrix, pts_hom.transpose(1, 2)).transpose(1, 2)
+    projected = projected / projected[:, :, 2:3]  # Lambda = 1
+
+    return projected[:, :, :-1]
diff --git a/SPIGA/spiga/models/gnn/step_regressor.py b/SPIGA/spiga/models/gnn/step_regressor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6396590490546e882eab987d40b1ba9078922c5
--- /dev/null
+++ b/SPIGA/spiga/models/gnn/step_regressor.py
@@ -0,0 +1,43 @@
+import torch.nn as nn
+
+from spiga.models.gnn.layers import MLP
+from spiga.models.gnn.gat import GAT
+
+
+class StepRegressor(nn.Module):
+    """Stack of ``GAT`` blocks followed by an ``OffsetDecoder``.
+
+    :param input_dim: channel count of the incoming embedding.
+    :param feature_dim: channel count used by every GAT block output.
+    :param nstack: number of GAT blocks (must be positive).
+    :param decoding: hidden layer sizes handed to ``OffsetDecoder``.
+    """
+
+    def __init__(self, input_dim: int, feature_dim: int, nstack=4, decoding=[256, 128, 64, 32]):
+        super(StepRegressor, self).__init__()
+        assert nstack > 0
+        self.nstack = nstack
+        self.gat = nn.ModuleList([GAT(input_dim, feature_dim, 4)])
+        for _ in range(nstack-1):
+            self.gat.append(GAT(feature_dim, feature_dim, 4))
+        # Hand over a copy so the shared default list can never be altered.
+        self.decoder = OffsetDecoder(feature_dim, list(decoding))
+
+    def forward(self, embedded, prob_list=None):
+        """Run the GAT stack and decode the result.
+
+        :param embedded: input embedding; its last two dims are swapped before
+            the first GAT block and the output is swapped back.
+        :param prob_list: list that the attention probabilities of every block
+            are appended to. A fresh list is created when omitted, so separate
+            calls never share (and keep alive) each other's tensors.
+        :return: tuple ``(offset, prob_list)``.
+        """
+        if prob_list is None:
+            prob_list = []
+        embedded = embedded.transpose(-1, -2)
+        for i in range(self.nstack):
+            embedded, prob = self.gat[i](embedded)
+            prob_list.append(prob)
+        offset = self.decoder(embedded)
+        return offset.transpose(-1, -2), prob_list
+
+
+class OffsetDecoder(nn.Module):
+    """MLP head mapping ``feature_dim`` channels to 2 output channels."""
+
+    def __init__(self, feature_dim, layers):
+        super().__init__()
+        channels = [feature_dim] + layers + [2]
+        self.decoder = MLP(channels)
+
+    def forward(self, embedded):
+        """Apply the decoder MLP to ``embedded``."""
+        decoded = self.decoder(embedded)
+        return decoded
+
+
+class RelativePositionEncoder(nn.Module):
+    """MLP encoder mapping ``input_dim`` channels to ``feature_dim`` channels."""
+
+    def __init__(self, input_dim, feature_dim, layers):
+        super().__init__()
+        channels = [input_dim] + layers + [feature_dim]
+        self.encoder = MLP(channels)
+
+    def forward(self, feature):
+        """Encode ``feature``; its last two dims are swapped around the MLP."""
+        channels_first = feature.transpose(-1, -2)
+        encoded = self.encoder(channels_first)
+        return encoded.transpose(-1, -2)
diff --git a/SPIGA/spiga/models/spiga.py b/SPIGA/spiga/models/spiga.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c72a36f2a45a4344665980bcf90e94a62766e2c
--- /dev/null
+++ b/SPIGA/spiga/models/spiga.py
@@ -0,0 +1,171 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import spiga.models.gnn.pose_proj as pproj
+from spiga.models.cnn.cnn_multitask import MultitaskCNN
+from spiga.models.gnn.step_regressor import StepRegressor, RelativePositionEncoder
+
+
+class SPIGA(nn.Module):
+    """Cascaded landmark regressor built from a CNN backbone, a pose head and GAT steps.
+
+    A ``MultitaskCNN`` backbone produces feature maps; a linear head predicts a
+    6-channel pose from them; the 3D model is projected with that pose to get
+    initial 2D points, which are then refined over ``steps`` cascaded
+    ``StepRegressor`` modules.
+
+    Args:
+        num_landmarks: Number of landmarks; forwarded to the backbone and used
+            to size the relative-position encoding (``2 * (num_landmarks - 1)``).
+        num_edges: Forwarded to the backbone.
+        steps: Number of cascaded regression steps.
+        **kwargs: Accepted and ignored.
+    """
+
+    def __init__(self, num_landmarks=98, num_edges=15, steps=3, **kwargs):
+
+        super(SPIGA, self).__init__()
+
+        # Model parameters
+        self.steps = steps          # Cascaded regressors
+        self.embedded_dim = 512     # GAT input channel
+        self.nstack = 4             # Number of stacked GATs per step
+        self.kwindow = 7            # Output cropped window dimension (kernel)
+        self.swindow = 0.25         # Scale of the cropped window at first step (Dft. 25% w.r.t the input featuremap)
+        # Maximum offset per step: half of the window scale, which halves at every step.
+        self.offset_ratio = [self.swindow/(2**step)/2 for step in range(self.steps)]
+
+        # CNN parameters
+        self.num_landmarks = num_landmarks
+        self.num_edges = num_edges
+
+        # Initialize backbone
+        self.visual_cnn = MultitaskCNN(num_landmarks=self.num_landmarks, num_edges=self.num_edges)
+        # Features dimensions
+        self.img_res = self.visual_cnn.img_res
+        self.visual_res = self.visual_cnn.out_res
+        self.visual_dim = self.visual_cnn.ch_dim
+
+        # Initialize Pose head
+        self.channels_pose = 6
+        self.pose_fc = nn.Linear(self.visual_cnn.ch_dim, self.channels_pose)
+
+        # Initialize feature extractors:
+        # Relative positional encoder (one per step)
+        shape_dim = 2 * (self.num_landmarks - 1)
+        self.shape_encoder = nn.ModuleList(
+            [RelativePositionEncoder(shape_dim, self.embedded_dim, [256, 256]) for _ in range(self.steps)])
+        # Diagonal mask used to compute relative positions (True everywhere except i == j)
+        diagonal_mask = (torch.ones(self.num_landmarks, self.num_landmarks) - torch.eye(self.num_landmarks)).type(torch.bool)
+        self.diagonal_mask = nn.parameter.Parameter(diagonal_mask, requires_grad=False)
+
+        # Visual feature extractor
+        conv_window = []
+        theta_S = []
+        for step in range(self.steps):
+            # S matrix per step
+            WH = self.visual_res                        # Width/height of ftmap
+            Wout = self.swindow / (2 ** step) * WH      # Width/height of the window
+            K = self.kwindow                            # Kernel or resolution of the window
+            scale = K / WH * (Wout - 1) / (K - 1)       # Scale of the affine transformation
+            # Rescale matrix S
+            theta_S_stp = torch.tensor([[scale, 0], [0, scale]])
+            theta_S.append(nn.parameter.Parameter(theta_S_stp, requires_grad=False))
+
+            # Convolution with kernel == window size collapses each crop to Cx1x1
+            conv_window.append(nn.Conv2d(self.visual_dim, self.embedded_dim, self.kwindow))
+
+        self.theta_S = nn.ParameterList(theta_S)
+        self.conv_window = nn.ModuleList(conv_window)
+
+        # Initialize GAT modules
+        self.gcn = nn.ModuleList([StepRegressor(self.embedded_dim, 256, self.nstack) for _ in range(self.steps)])
+
+    def forward(self, data):
+        """Run the backbone and the cascaded refinement steps.
+
+        Args:
+            data: Indexable input consumed by ``backbone_forward``
+                (``data[0]`` images, ``data[1]`` 3D model, ``data[2]`` camera matrix).
+
+        Returns:
+            The backbone feature dict extended with ``'Pose'``, ``'Landmarks'``
+            (list with the points after each step) and ``'GATProb'`` (list
+            filled by the step regressors).
+        """
+        # Inputs: Visual features and points projections
+        pts_proj, features = self.backbone_forward(data)
+        # Visual field
+        visual_field = features['VisualField'][-1]
+
+        # Params compute only once
+        gat_prob = []
+        features['Landmarks'] = []
+        for step in range(self.steps):
+            # Features generation
+            embedded_ft = self.extract_embedded(pts_proj, visual_field, step)
+
+            # GAT inference; the offset is clamped to [-1, 1] by hardtanh
+            offset, gat_prob = self.gcn[step](embedded_ft, gat_prob)
+            offset = F.hardtanh(offset)
+
+            # Update coordinates
+            pts_proj = pts_proj + self.offset_ratio[step] * offset
+            features['Landmarks'].append(pts_proj.clone())
+
+        features['GATProb'] = gat_prob
+        return features
+
+    def backbone_forward(self, data):
+        """Run the CNN backbone, estimate the pose and project the 3D model.
+
+        Returns:
+            Tuple ``(pts_proj, features)``: the projected points divided by
+            ``self.visual_res`` and the backbone feature dict with ``'Pose'`` added.
+        """
+        # Inputs: Image and model3D
+        imgs = data[0]
+        model3d = data[1]
+        cam_matrix = data[2]
+
+        # HourGlass Forward
+        features = self.visual_cnn(imgs)
+
+        # Head pose estimation
+        pose_raw = features['HGcore'][-1]
+        B, L, _, _ = pose_raw.shape
+        # The reshape only succeeds when the two trailing dimensions are 1x1.
+        pose = pose_raw.reshape(B, L)
+        pose = self.pose_fc(pose)
+        features['Pose'] = pose.clone()
+
+        # Project model 3D: first three pose channels are Euler angles, the rest translation
+        euler = pose[:, 0:3]
+        trl = pose[:, 3:]
+        rot = pproj.euler_to_rotation_matrix(euler)
+        pts_proj = pproj.projectPoints(model3d, rot, trl, cam_matrix)
+        pts_proj = pts_proj / self.visual_res
+
+        return pts_proj, features
+
+    def extract_embedded(self, pts_proj, receptive_field, step):
+        """Return the sum of the visual and the relative-position embeddings for ``step``."""
+        # Visual features
+        visual_ft = self.extract_visual_embedded(pts_proj, receptive_field, step)
+        # Shape features
+        shape_ft = self.calculate_distances(pts_proj)
+        shape_ft = self.shape_encoder[step](shape_ft)
+        # Addition
+        return visual_ft + shape_ft
+
+    def extract_visual_embedded(self, pts_proj, receptive_field, step):
+        """Crop a ``kwindow`` x ``kwindow`` window around every point and embed it.
+
+        Args:
+            pts_proj: Points of shape BxLx2 (original note: range [0, 1]).
+            receptive_field: Feature map of shape BxCxHxW to sample from.
+            step: Index selecting the per-step scale matrix and convolution.
+
+        Returns:
+            Tensor of shape BxLxCout.
+        """
+        # Affine matrix generation
+        B, L, _ = pts_proj.shape                            # Pts_proj range:[0,1]
+        centers = pts_proj + 0.5 / self.visual_res          # BxLx2
+        centers = centers.reshape(B * L, 2)                 # B*Lx2
+        theta_trl = (-1 + centers * 2).unsqueeze(-1)        # B*Lx2x1, mapped to [-1, 1]
+        theta_s = self.theta_S[step]                        # 2x2
+        theta_s = theta_s.expand(B * L, 2, 2)               # B*Lx2x2 (view, no copy)
+        theta = torch.cat((theta_s, theta_trl), -1)         # B*Lx2x3
+
+        # Generate crop grid
+        C = receptive_field.shape[1]
+        # align_corners=False is what the default already resolves to; passing
+        # it explicitly avoids a UserWarning on every call.
+        grid = F.affine_grid(theta, (B * L, C, self.kwindow, self.kwindow), align_corners=False)
+        grid = grid.reshape(B, L, self.kwindow * self.kwindow, 2)
+
+        # Crop windows
+        crops = F.grid_sample(receptive_field, grid, padding_mode="border", align_corners=False)  # BxCxLxK*K
+        crops = crops.transpose(1, 2)                       # BxLxCxK*K
+        crops = crops.reshape(B * L, C, self.kwindow, self.kwindow)
+
+        # Flatten features
+        visual_ft = self.conv_window[step](crops)
+        Cout = visual_ft.shape[1]
+        visual_ft = visual_ft.reshape(B, L, Cout)
+
+        return visual_ft
+
+    def calculate_distances(self, pts_proj):
+        """Return, for each point, its coordinate differences to every other point.
+
+        Args:
+            pts_proj: Points of shape BxLx2.
+
+        Returns:
+            Tensor of shape BxLx(2*(L-1)); entry ``[b, i]`` holds
+            ``pts[b, i] - pts[b, j]`` for all ``j != i``, flattened.
+        """
+        B, L, _ = pts_proj.shape                            # BxLx2
+        # Broadcasting yields dist[b, i, j] = pts[b, i] - pts[b, j] without copies.
+        dist = pts_proj.unsqueeze(2) - pts_proj.unsqueeze(1)
+        # Drop the i == j entries.
+        dist_wo_self = dist[:, self.diagonal_mask, :].reshape(B, L, -1)
+        return dist_wo_self
+
+
+
+
+
+
+
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..989104829628d28e387946e4bfd816c7feed5e9f
--- /dev/null
+++ b/app.py
@@ -0,0 +1,86 @@
+import os
+import sys
+from pathlib import Path
+import uuid
+import gradio as gr
+# Import the SPIGA demo entry point; if the package is not installed yet,
+# install the bundled copy in editable mode and retry.
+try:
+    from spiga.demo.app import video_app
+except ImportError:
+    import subprocess
+
+    # Use the running interpreter's pip so the package lands in this
+    # environment. Best effort: if installation fails, the import below
+    # raises and reports the real problem.
+    subprocess.run([sys.executable, "-m", "pip", "install", "-e", "./SPIGA[demo]"], check=False)
+    sys.path.append(os.path.abspath("./SPIGA"))
+    from spiga.demo.app import video_app
+
+
+def predict(video_in, image_in_video=None, image_in_img=None):
+    """Process an uploaded video, or pass an image straight through.
+
+    Args:
+        video_in: Path of the input video, or ``None``.
+        image_in_video: Path of a webcam image, or ``None``.
+        image_in_img: Path of an uploaded image, or ``None``.
+
+    The image arguments are optional so the function can be wired to an
+    event that only supplies the video input.
+
+    Returns:
+        The image path unchanged when an image is given; otherwise the path
+        of the video written under ``/tmp``.
+
+    Raises:
+        gr.Error: If no input at all is provided.
+    """
+    if video_in is None and image_in_video is None and image_in_img is None:
+        raise gr.Error("Please upload a video or image.")
+    if image_in_video or image_in_img:
+        print("image", image_in_video, image_in_img)
+        # Images are returned as-is; the webcam image takes precedence.
+        return image_in_video or image_in_img
+
+    # Video input: move it to /tmp under a unique name before processing.
+    video_in = Path(video_in)
+    output_path = Path("/tmp")
+    video_file_name = str(uuid.uuid4())
+    new_video_path = output_path / f"{video_file_name}{video_in.suffix}"
+    video_in.rename(new_video_path)
+
+    video_app(str(new_video_path),
+              # Choices=['wflw', '300wpublic', '300wprivate', 'merlrav']
+              spiga_dataset='wflw',
+              # Choices=['RetinaSort', 'RetinaSort_Res50']
+              tracker='RetinaSort',
+              save=True,
+              output_path=output_path,
+              visualize=False,
+              plot=['fps', 'face_id', 'landmarks', 'headpose'])
+    # NOTE(review): the [:-4] slice presumably mirrors how video_app derives
+    # its output file name from the input name - confirm before replacing it
+    # with Path.stem, which differs for suffixes that are not 3 characters.
+    video_output_path = f"{output_path}/{new_video_path.name[:-4]}.mp4"
+
+    return video_output_path
+
+
+def toggle(choice):
+    """Return two component updates showing exactly one of a pair of inputs.
+
+    The first update is visible when ``choice`` is ``"webcam"``, the second
+    one otherwise; both reset their value to ``None``.
+    """
+    webcam_selected = choice == "webcam"
+    first = gr.update(visible=webcam_selected, value=None)
+    second = gr.update(visible=not webcam_selected, value=None)
+    return first, second
+
+
+def _predict_image(uploaded_image, webcam_image):
+    """Adapt the image tab inputs (upload, webcam) to ``predict``'s parameter order."""
+    return predict(None, webcam_image, uploaded_image)
+
+
+# UI definition: one tab for videos and one for images, each with a webcam /
+# upload selector, an output component and a "Run" button.
+# NOTE(review): the nesting of the buttons relative to the rows/columns is
+# assumed - confirm against the intended layout.
+with gr.Blocks() as blocks:
+    gr.Markdown("### Video or Image? WebCam or Upload?")
+    with gr.Tab("Video") as tab:
+        with gr.Row():
+            with gr.Column():
+                video_or_file_opt = gr.Radio(["webcam", "upload"], value="webcam",
+                                             label="How would you like to upload your video?")
+                video_in = gr.Video(source="webcam", include_audio=False)
+                # Switch the video source in place and clear any previous value.
+                video_or_file_opt.change(fn=lambda s: gr.update(source=s, value=None), inputs=video_or_file_opt,
+                                         outputs=video_in, queue=False)
+            with gr.Column():
+                video_out = gr.Video()
+        run_btn = gr.Button("Run")
+        run_btn.click(fn=predict, inputs=[video_in], outputs=[video_out])
+        gr.Examples(fn=predict, examples=[], inputs=[
+            video_in], outputs=[video_out])
+
+    with gr.Tab("Image"):
+        with gr.Row():
+            with gr.Column():
+                image_or_file_opt = gr.Radio(["webcam", "file"], value="webcam",
+                                             label="How would you like to upload your image?")
+                image_in_video = gr.Image(source="webcam", type="filepath")
+                image_in_img = gr.Image(
+                    source="upload", visible=False, type="filepath")
+
+                image_or_file_opt.change(fn=toggle, inputs=[image_or_file_opt],
+                                         outputs=[image_in_video, image_in_img], queue=False)
+            with gr.Column():
+                image_out = gr.Image()
+        run_btn = gr.Button("Run")
+        # Inputs arrive as (upload, webcam); the adapter routes them to the
+        # image parameters of predict instead of its video parameter.
+        run_btn.click(fn=_predict_image, inputs=[
+            image_in_img, image_in_video], outputs=[image_out])
+        gr.Examples(fn=_predict_image, examples=[], inputs=[
+            image_in_img, image_in_video], outputs=[image_out])
+
+blocks.launch()
diff --git a/packages.txt b/packages.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a3e85e905383c312f205afe014e8c601ac973a77
--- /dev/null
+++ b/packages.txt
@@ -0,0 +1,2 @@
+rustc
+cargo
\ No newline at end of file
diff --git a/pre-requirements.txt b/pre-requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e512b5e6f9ff21da877ab7baf6f9a65680818e20
--- /dev/null
+++ b/pre-requirements.txt
@@ -0,0 +1 @@
+pip==23.0.1
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ee7aa7b1b972be9757e3410d2ee0b6816768a1ca
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,12 @@
+gradio
+matplotlib>=3.2.1
+numpy>=1.18.2
+opencv-python>=4.2.0.32
+Pillow>=7.0.0
+torch>=1.4.0
+torchvision>=0.5.0
+torchaudio
+scipy
+scikit-learn
+retinaface-py>=0.0.2
+sort-tracker-py>= 1.0.2
\ No newline at end of file