Commit a2aae2de authored by Konstantinos Papadopoulos

Initial commit

### Code version (Git Hash) and PyTorch version
### Dataset used
### Expected behavior
### Actual behavior
### Steps to reproduce the behavior
### Other comments
Copyright (c) 2018, Multimedia Laboratory, The Chinese University of Hong Kong
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Spatial Temporal Graph Convolutional Networks (ST-GCN)
A graph convolutional network for skeleton based action recognition.
<div align="center">
<img src="resource/info/pipeline.png">
</div>
This repository holds the codebase, dataset and models for the paper:
**Spatial Temporal Graph Convolutional Networks for Skeleton-Based Action Recognition**, Sijie Yan, Yuanjun Xiong and Dahua Lin, AAAI 2018.
[[Arxiv Preprint]](https://arxiv.org/abs/1801.07455)
## News & Updates
- Feb. 21, 2019 - We provide pretrained models and training scripts for the **NTU-RGB+D** and **Kinetics-skeleton** datasets, so that you can reproduce the performance reported in the paper.
- June 5, 2018 - A demo for feature visualization and skeleton-based action recognition is released.
- June 1, 2018 - We updated our codebase and completed the PyTorch 0.4.0 migration.
## Visualization of ST-GCN in Action
Our demo for skeleton based action recognition:
<p align="center">
<img src="resource/info/demo_video.gif", width="1200">
</p>
ST-GCN is able to exploit local patterns and correlations from human skeletons.
The figures below show the neural response magnitude of each node in the last layer of our ST-GCN.
<table style="width:100%; table-layout:fixed;">
<tr>
<td><img width="150px" src="resource/info/S001C001P001R001A044_w.gif"></td>
<td><img width="150px" src="resource/info/S003C001P008R001A008_w.gif"></td>
<td><img width="150px" src="resource/info/S002C001P010R001A017_w.gif"></td>
<td><img width="150px" src="resource/info/S003C001P008R001A002_w.gif"></td>
<td><img width="150px" src="resource/info/S001C001P001R001A051_w.gif"></td>
</tr>
<tr>
<td><font size="1">Touch head<font></td>
<td><font size="1">Sitting down<font></td>
<td><font size="1">Take off a shoe<font></td>
<td><font size="1">Eat meal/snack<font></td>
<td><font size="1">Kick other person<font></td>
</tr>
<tr>
<td><img width="150px" src="resource/info/hammer_throw_w.gif"></td>
<td><img width="150px" src="resource/info/clean_and_jerk_w.gif"></td>
<td><img width="150px" src="resource/info/pull_ups_w.gif"></td>
<td><img width="150px" src="resource/info/tai_chi_w.gif"></td>
<td><img width="150px" src="resource/info/juggling_balls_w.gif"></td>
</tr>
<tr>
<td><font size="1">Hammer throw<font></td>
<td><font size="1">Clean and jerk<font></td>
<td><font size="1">Pull ups<font></td>
<td><font size="1">Tai chi<font></td>
<td><font size="1">Juggling ball<font></td>
</tr>
</table>
The first row of the above results is from the **NTU-RGB+D** dataset, and the second row is from **Kinetics-skeleton**.
## Prerequisites
Our codebase is based on **Python3** (>=3.5). There are a few dependencies required to run the code. The major libraries we depend on are
- [PyTorch](http://pytorch.org/) (Release version 0.4.0)
- [Openpose@92cdcad](https://github.com/yysijie/openpose) (Optional: for demo only)
- FFmpeg (Optional: for demo only), which can be installed by `sudo apt-get install ffmpeg`
- Other Python libraries can be installed by `pip install -r requirements.txt`
### Installation
```
cd torchlight; python setup.py install; cd ..
```
### Get pretrained models
We provide the pretrained model weights of our **ST-GCN**. The model weights can be downloaded by running the script
```
bash tools/get_models.sh
```
<!-- The downloaded models will be stored under ```./models```. -->
You can also obtain models from [GoogleDrive](https://drive.google.com/drive/folders/1IYKoSrjeI3yYJ9bO0_z_eDo92i7ob_aF) or [BaiduYun](https://pan.baidu.com/s/1dwKG2TLvG-R1qeIiE4MjeA#list/path=%2FShare%2FAAAI18%2Fst-gcn%2Fmodels&parentPath=%2FShare), and manually put them into ```./models```.
## Demo
To visualize how ST-GCN exploits local correlations and local patterns, we compute the feature vector magnitude of each node in the final spatial temporal graph and overlay it on the original video. **Openpose** should be installed and ready for extracting human skeletons from videos. The skeleton-based action recognition results are also shown on the video.
Run the demo by this command:
```
python main.py demo --openpose <path to openpose build directory> [--video <path to your video> --device <gpu0> <gpu1>]
```
A video like the one shown above will be generated and saved under ```data/demo_result/```.
## Data Preparation
We experimented on two skeleton-based action recognition datasets: **Kinetics-skeleton** and **NTU RGB+D**.
### Kinetics-skeleton
[Kinetics](https://deepmind.com/research/open-source/open-source-datasets/kinetics/) is a video-based dataset for action recognition which only provides raw video clips without skeleton data. To obtain the joint locations, we first resized all videos to a resolution of 340x256 and converted the frame rate to 30 fps. Then, we extracted skeletons from each frame in Kinetics with [Openpose](https://github.com/CMU-Perceptual-Computing-Lab/openpose). The extracted skeleton data, which we call **Kinetics-skeleton** (7.5GB), can be directly downloaded from [GoogleDrive](https://drive.google.com/open?id=1SPQ6FmFsjGg3f59uCWfdUWI-5HJM_YhZ) or [BaiduYun](https://pan.baidu.com/s/1dwKG2TLvG-R1qeIiE4MjeA#list/path=%2FShare%2FAAAI18%2Fkinetics-skeleton&parentPath=%2FShare).
After uncompressing, rebuild the database by this command:
```
python tools/kinetics_gendata.py --data_path <path to kinetics-skeleton>
```
### NTU RGB+D
NTU RGB+D can be downloaded from [their website](http://rose1.ntu.edu.sg/datasets/actionrecognition.asp).
Only the **3D skeletons** (5.8GB) modality is required for our experiments. After downloading it, use this command to build the database for training or evaluation:
```
python tools/ntu_gendata.py --data_path <path to nturgbd+d_skeletons>
```
where ```<path to nturgbd+d_skeletons>``` points to the 3D skeletons modality of the NTU RGB+D dataset you downloaded.
## Testing Pretrained Models
<!-- ### Evaluation
Once datasets ready, we can start the evaluation. -->
To evaluate the ST-GCN model pretrained on **Kinetics-skeleton**, run
```
python main.py recognition -c config/st_gcn/kinetics-skeleton/test.yaml
```
For **cross-view** evaluation in **NTU RGB+D**, run
```
python main.py recognition -c config/st_gcn/ntu-xview/test.yaml
```
For **cross-subject** evaluation in **NTU RGB+D**, run
```
python main.py recognition -c config/st_gcn/ntu-xsub/test.yaml
```
<!-- Similary, the configuration file for testing baseline models can be found under the ```./config/baseline```. -->
To speed up evaluation with multi-GPU inference, or to reduce memory usage by changing the batch size, set ```--test_batch_size``` and ```--device``` as follows:
```
python main.py recognition -c <config file> --test_batch_size <batch size> --device <gpu0> <gpu1> ...
```
### Results
The expected **Top-1 accuracy** of the provided models is shown here:
| Model| Kinetics-<br>skeleton (%)|NTU RGB+D <br> Cross View (%) |NTU RGB+D <br> Cross Subject (%) |
| :------| :------: | :------: | :------: |
|Baseline[1]| 20.3 | 83.1 | 74.3 |
|**ST-GCN** (Ours)| **31.6**| **88.8** | **81.6** |
[1] Kim, T. S., and Reiter, A. 2017. Interpretable 3d human action analysis with temporal convolutional networks. In BNMW CVPRW.
## Training
To train a new ST-GCN model, run
```
python main.py recognition -c config/st_gcn/<dataset>/train.yaml [--work_dir <work folder>]
```
where the ```<dataset>``` must be ```ntu-xsub```, ```ntu-xview``` or ```kinetics-skeleton```, depending on the dataset you want to use.
The training results, including **model weights**, configurations and logging files, will be saved under ```./work_dir``` by default, or under ```<work folder>``` if you specify it.
You can modify the training parameters such as ```work_dir```, ```batch_size```, ```step```, ```base_lr``` and ```device``` in the command line or configuration files. The order of priority is: command line > config file > default parameter. For more information, use ```main.py -h```.
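For example, here is a hedged sketch of overriding a few of these parameters on the command line (the values are illustrative only, not recommended settings):
```
python main.py recognition -c config/st_gcn/kinetics-skeleton/train.yaml --work_dir ./work_dir/my_run --batch_size 32 --base_lr 0.01 --device 0 1
```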
Finally, a custom model can be evaluated with the command mentioned above:
```
python main.py recognition -c config/st_gcn/<dataset>/test.yaml --weights <path to model weights>
```
## Citation
Please cite the following paper if you use this repository in your research.
```
@inproceedings{stgcn2018aaai,
title = {Spatial Temporal Graph Convolutional Networks for Skeleton-Based Action Recognition},
author = {Sijie Yan and Yuanjun Xiong and Dahua Lin},
booktitle = {AAAI},
year = {2018},
}
```
## Contact
For any question, feel free to contact
```
Sijie Yan : ys016@ie.cuhk.edu.hk
Yuanjun Xiong : bitxiong@gmail.com
```
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import argparse
def parse_args():
parser = argparse.ArgumentParser(description='Training script')
# General arguments
parser.add_argument('-d', '--dataset', default='h36m', type=str, metavar='NAME', help='target dataset') # h36m or humaneva
parser.add_argument('-k', '--keypoints', default='cpn_ft_h36m_dbb', type=str, metavar='NAME', help='2D detections to use')
parser.add_argument('-str', '--subjects-train', default='S1,S5,S6,S7,S8', type=str, metavar='LIST',
help='training subjects separated by comma')
parser.add_argument('-ste', '--subjects-test', default='S9,S11', type=str, metavar='LIST', help='test subjects separated by comma')
parser.add_argument('-sun', '--subjects-unlabeled', default='', type=str, metavar='LIST',
help='unlabeled subjects separated by comma for self-supervision')
parser.add_argument('-a', '--actions', default='*', type=str, metavar='LIST',
help='actions to train/test on, separated by comma, or * for all')
parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
help='checkpoint directory')
parser.add_argument('--checkpoint-frequency', default=10, type=int, metavar='N',
help='create a checkpoint every N epochs')
parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME',
help='checkpoint to resume (file name)')
parser.add_argument('--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)')
parser.add_argument('--render', action='store_true', help='visualize a particular video')
parser.add_argument('--by-subject', action='store_true', help='break down error by subject (on evaluation)')
parser.add_argument('--export-training-curves', action='store_true', help='save training curves as .png images')
# Model arguments
parser.add_argument('-s', '--stride', default=1, type=int, metavar='N', help='chunk size to use during training')
parser.add_argument('-e', '--epochs', default=60, type=int, metavar='N', help='number of training epochs')
parser.add_argument('-b', '--batch-size', default=1024, type=int, metavar='N', help='batch size in terms of predicted frames')
parser.add_argument('-drop', '--dropout', default=0.25, type=float, metavar='P', help='dropout probability')
parser.add_argument('-lr', '--learning-rate', default=0.001, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('-lrd', '--lr-decay', default=0.95, type=float, metavar='LR', help='learning rate decay per epoch')
parser.add_argument('-no-da', '--no-data-augmentation', dest='data_augmentation', action='store_false',
help='disable train-time flipping')
parser.add_argument('-no-tta', '--no-test-time-augmentation', dest='test_time_augmentation', action='store_false',
help='disable test-time flipping')
parser.add_argument('-arc', '--architecture', default='3,3,3', type=str, metavar='LAYERS', help='filter widths separated by comma')
parser.add_argument('--causal', action='store_true', help='use causal convolutions for real-time processing')
parser.add_argument('-ch', '--channels', default=1024, type=int, metavar='N', help='number of channels in convolution layers')
# Experimental
parser.add_argument('--subset', default=1, type=float, metavar='FRACTION', help='reduce dataset size by fraction')
parser.add_argument('--downsample', default=1, type=int, metavar='FACTOR', help='downsample frame rate by factor (semi-supervised)')
parser.add_argument('--warmup', default=1, type=int, metavar='N', help='warm-up epochs for semi-supervision')
parser.add_argument('--no-eval', action='store_true', help='disable epoch evaluation while training (small speed-up)')
parser.add_argument('--dense', action='store_true', help='use dense convolutions instead of dilated convolutions')
parser.add_argument('--disable-optimizations', action='store_true', help='disable optimized model for single-frame predictions')
parser.add_argument('--linear-projection', action='store_true', help='use only linear coefficients for semi-supervised projection')
parser.add_argument('--no-bone-length', action='store_false', dest='bone_length_term',
help='disable bone length term in semi-supervised settings')
parser.add_argument('--no-proj', action='store_true', help='disable projection for semi-supervised setting')
# Visualization
parser.add_argument('--viz-subject', type=str, metavar='STR', help='subject to render')
parser.add_argument('--viz-action', type=str, metavar='STR', help='action to render')
parser.add_argument('--viz-camera', type=int, default=0, metavar='N', help='camera to render')
parser.add_argument('--viz-video', type=str, metavar='PATH', help='path to input video')
parser.add_argument('--viz-skip', type=int, default=0, metavar='N', help='skip first N frames of input video')
parser.add_argument('--viz-output', type=str, metavar='PATH', help='output file name (.gif or .mp4)')
parser.add_argument('--viz-bitrate', type=int, default=3000, metavar='N', help='bitrate for mp4 videos')
parser.add_argument('--viz-no-ground-truth', action='store_true', help='do not show ground-truth poses')
parser.add_argument('--viz-limit', type=int, default=-1, metavar='N', help='only render first N frames')
parser.add_argument('--viz-downsample', type=int, default=1, metavar='N', help='downsample FPS by a factor N')
parser.add_argument('--viz-size', type=int, default=5, metavar='N', help='image size')
parser.set_defaults(bone_length_term=True)
parser.set_defaults(data_augmentation=True)
parser.set_defaults(test_time_augmentation=True)
args = parser.parse_args()
# Check invalid configuration
if args.resume and args.evaluate:
print('Invalid flags: --resume and --evaluate cannot be set at the same time')
exit()
if args.export_training_curves and args.no_eval:
print('Invalid flags: --export-training-curves and --no-eval cannot be set at the same time')
exit()
return args
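# Hedged example invocation (the driver script name is an assumption, not defined in this file;
# flag names follow the argument definitions above):
#   python run.py -d h36m -k cpn_ft_h36m_dbb -arc 3,3,3,3,3 -e 80 -b 1024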
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import numpy as np
import torch
from common.utils import wrap
from common.quaternion import qrot, qinverse
def normalize_screen_coordinates(X, w, h):
assert X.shape[-1] == 2
# Normalize so that [0, w] is mapped to [-1, 1], while preserving the aspect ratio
return X/w*2 - [1, h/w]
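# Worked example (hedged, for illustration): with w=1920, h=1080, the pixel (0, 0) maps to
# (-1.0, -0.5625) and (1920, 1080) maps to (1.0, 0.5625); x spans [-1, 1] while y is scaled
# by the same factor 2/w, which preserves the aspect ratio.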
def image_coordinates(X, w, h):
assert X.shape[-1] == 2
# Reverse camera frame normalization
return (X + [1, h/w])*w/2
def world_to_camera(X, R, t):
Rt = wrap(qinverse, R) # Invert rotation
return wrap(qrot, np.tile(Rt, (*X.shape[:-1], 1)), X - t) # Rotate and translate
def camera_to_world(X, R, t):
return wrap(qrot, np.tile(R, (*X.shape[:-1], 1)), X) + t
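# Note (assumed from the qrot/qinverse imports above): R is a per-camera rotation given as a
# quaternion rather than a matrix; it is tiled over the leading dimensions of X, while the
# translation t broadcasts directly in the subtraction/addition.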
def project_to_2d(X, camera_params):
"""
Project 3D points to 2D using the Human3.6M camera projection function.
This is a differentiable and batched reimplementation of the original MATLAB script.
Arguments:
X -- 3D points in *camera space* to transform (N, *, 3)
    camera_params -- intrinsic parameters (N, 2+2+3+2=9)
"""
assert X.shape[-1] == 3
assert len(camera_params.shape) == 2
assert camera_params.shape[-1] == 9
assert X.shape[0] == camera_params.shape[0]
while len(camera_params.shape) < len(X.shape):
camera_params = camera_params.unsqueeze(1)
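    # Intrinsics layout (matching the (N, 2+2+3+2=9) packing noted in the docstring):
    # f = focal length, c = principal point, k = radial distortion coefficients,
    # p = tangential distortion coefficients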
f = camera_params[..., :2]
c = camera_params[..., 2:4]
k = camera_params[..., 4:7]
p = camera_params[..., 7:]
XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1)
r2 = torch.sum(XX[..., :2]**2, dim=len(XX.shape)-1, keepdim=True)
radial = 1 + torch.sum(k * torch.cat((r2, r2**2, r2**3), dim=len(r2.shape)-1), dim=len(r2.shape)-1, keepdim=True)
tan = torch.sum(p*XX, dim=len(XX.shape)-1, keepdim=True)
XXX = XX*(radial + tan) + p*r2
return f*XXX + c
def project_to_2d_linear(X, camera_params):
"""
Project 3D points to 2D using only linear parameters (focal length and principal point).
Arguments:
X -- 3D points in *camera space* to transform (N, *, 3)
    camera_params -- intrinsic parameters (N, 2+2+3+2=9)
"""
assert X.shape[-1] == 3
assert len(camera_params.shape) == 2
assert camera_params.shape[-1] == 9
assert X.shape[0] == camera_params.shape[0]
while len(camera_params.shape) < len(X.shape):
camera_params = camera_params.unsqueeze(1)
f = camera_params[..., :2]
c = camera_params[..., 2:4]
XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1)
return f*XX + c
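# Hedged usage sketch (shapes follow the docstrings above; the tensors here are illustrative):
#   import torch
#   X = torch.randn(4, 17, 3)        # (N, J, 3) joints in camera space
#   params = torch.randn(4, 9)       # (N, 9) intrinsics packed as f, c, k, p
#   uv = project_to_2d(X, params)    # (N, J, 2) projected 2D coordinates
#   uv_lin = project_to_2d_linear(X, params)  # distortion-free variant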
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from itertools import zip_longest
import numpy as np
class ChunkedGenerator:
"""
Batched data generator, used for training.
The sequences are split into equal-length chunks and padded as necessary.
Arguments:
batch_size -- the batch size to use for training
cameras -- list of cameras, one element for each video (optional, used for semi-supervised training)
poses_3d -- list of ground-truth 3D poses, one element for each video (optional, used for supervised training)
poses_2d -- list of input 2D keypoints, one element for each video
chunk_length -- number of output frames to predict for each training example (usually 1)
pad -- 2D input padding to compensate for valid convolutions, per side (depends on the receptive field)
causal_shift -- asymmetric padding offset when causal convolutions are used (usually 0 or "pad")
shuffle -- randomly shuffle the dataset before each epoch
random_seed -- initial seed to use for the random generator
augment -- augment the dataset by flipping poses horizontally
kps_left and kps_right -- list of left/right 2D keypoints if flipping is enabled
joints_left and joints_right -- list of left/right 3D joints if flipping is enabled
"""
def __init__(self, batch_size, cameras, poses_3d, poses_2d,
chunk_length, pad=0, causal_shift=0,
shuffle=True, random_seed=1234,
augment=False, kps_left=None, kps_right=None, joints_left=None, joints_right=None,
endless=False):
assert poses_3d is None or len(poses_3d) == len(poses_2d), (len(poses_3d), len(poses_2d))
assert cameras is None or len(cameras) == len(poses_2d)
# Build lineage info
pairs = [] # (seq_idx, start_frame, end_frame, flip) tuples
for i in range(len(poses_2d)):
            assert poses_3d is None or poses_3d[i].shape[0] == poses_2d[i].shape[0]
n_chunks = (poses_2d[i].shape[0] + chunk_length - 1) // chunk_length
offset = (n_chunks * chunk_length - poses_2d[i].shape[0]) // 2
bounds = np.arange(n_chunks+1)*chunk_length - offset
            augment_vector = np.full(len(bounds) - 1, False, dtype=bool)
            pairs += zip(np.repeat(i, len(bounds) - 1), bounds[:-1], bounds[1:], augment_vector)
            if augment:
                pairs += zip(np.repeat(i, len(bounds) - 1), bounds[:-1], bounds[1:], ~augment_vector)
# Initialize buffers
if cameras is not None:
self.batch_cam = np.empty((batch_size, cameras[0].shape[-1]))
if poses_3d is not None:
self.batch_3d = np.empty((batch_size, chunk_length, poses_3d[0].shape[-2], poses_3d[0].shape[-1]))
self.batch_2d = np.empty((batch_size, chunk_length + 2*pad, poses_2d[0].shape[-2], poses_2d[0].shape[-1]))
self.num_batches = (len(pairs) + batch_size - 1) // batch_size
self.batch_size = batch_size
self.random = np.random.RandomState(random_seed)
self.pairs = pairs
self.shuffle = shuffle
self.pad = pad
self.causal_shift = causal_shift
self.endless = endless
self.state = None
self.cameras = cameras
self.poses_3d = poses_3d
self.poses_2d = poses_2d
self.augment = augment
self.kps_left = kps_left
self.kps_right = kps_right
self.joints_left = joints_left
self.joints_right = joints_right
def num_frames(self):
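        # Note: rounded up to a whole number of batches, so this can slightly
        # exceed the true number of training chunks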
return self.num_batches * self.batch_size
def random_state(self):
return self.random
def set_random_state(self, random):
self.random = random
def augment_enabled(self):
return self.augment
def next_pairs(self):
if self.state is None:
if self.shuffle:
pairs = self.random.permutation(self.pairs)
else:
pairs = self.pairs
return 0, pairs
else:
return self.state
def next_epoch(self):
enabled = True
while enabled:
start_idx, pairs = self.next_pairs()
for b_i in range(start_idx, self.num_batches):
chunks = pairs[b_i*self.batch_size : (b_i+1)*self.batch_size]
for i, (seq_i, start_3d, end_3d, flip) in enumerate(chunks):
start_2d = start_3d - self.pad - self.causal_shift
end_2d = end_3d + self.pad - self.causal_shift
# 2D poses
seq_2d = self.poses_2d[seq_i]
low_2d = max(start_2d, 0)
high_2d = min(end_2d, seq_2d.shape[0])
pad_left_2d = low_2d - start_2d
pad_right_2d = end_2d - high_2d
if pad_left_2d != 0 or pad_right_2d != 0:
self.batch_2d[i] = np.pad(seq_2d[low_2d:high_2d], ((pad_left_2d, pad_right_2d), (0, 0), (0, 0)), 'edge')
else:
self.batch_2d[i] = seq_2d[low_2d:high_2d]
if flip:
# Flip 2D keypoints
self.batch_2d[i, :, :, 0] *= -1
self.batch_2d[i, :, self.kps_left + self.kps_right] = self.batch_2d[i, :, self.kps_right + self.kps_left]
# 3D poses
if self.poses_3d is not None:
seq_3d = self.poses_3d[seq_i]
low_3d = max(start_3d, 0)
high_3d = min(end_3d, seq_3d.shape[0])
pad_left_3d = low_3d - start_3d
pad_right_3d = end_3d - high_3d
if pad_left_3d != 0 or pad_right_3d != 0:
self.batch_3d[i] = np.pad(seq_3d[low_3d:high_3d], ((pad_left_3d, pad_right_3d), (0, 0), (0, 0)), 'edge')
else:
self.batch_3d[i] = seq_3d[low_3d:high_3d]
if flip:
# Flip 3D joints
self.batch_3d[i, :, :, 0] *= -1
self.batch_3d[i, :, self.joints_left + self.joints_right] = \
self.batch_3d[i, :, self.joints_right + self.joints_left]
# Cameras
if self.cameras is not None:
self.batch_cam[i] = self.cameras[seq_i]
if flip:
# Flip horizontal distortion coefficients
self.batch_cam[i, 2] *= -1
self.batch_cam[i, 7] *= -1
if self.endless:
self.state = (b_i + 1, pairs)
if self.poses_3d is None and self.cameras is None:
yield None, None, self.batch_2d[:len(chunks)]
elif self.poses_3d is not None and self.cameras is None:
yield None, self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)]
elif self.poses_3d is None:
yield self.batch_cam[:len(chunks)], None, self.batch_2d[:len(chunks)]
else:
yield self.batch_cam[:len(chunks)], self.batch_3d[:len(chunks)], self.batch_2d[:len(chunks)]
if self.endless:
self.state = None
else:
enabled = False
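# Hedged usage sketch for ChunkedGenerator (argument values are illustrative; shapes follow
# the docstring above):
#   gen = ChunkedGenerator(batch_size=128, cameras=None, poses_3d=poses_3d_list,
#                          poses_2d=poses_2d_list, chunk_length=1, pad=121, shuffle=True)
#   for cam, batch_3d, batch_2d in gen.next_epoch():
#       pass  # batch_2d: (B, 1 + 2*pad, J, 2); batch_3d: (B, 1, J, 3); cam is None here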
class UnchunkedGenerator:
"""
Non-batched data generator, used for testing.
Sequences are returned one at a time (i.e. batch size = 1), without chunking.
If data augmentation is enabled, the batches contain two sequences (i.e. batch size = 2),
the second of which is a mirrored version of the first.
Arguments:
cameras -- list of cameras, one element for each video (optional, used for semi-supervised training)
poses_3d -- list of ground-truth 3D poses, one element for each video (optional, used for supervised training)
poses_2d -- list of input 2D keypoints, one element for each video
pad -- 2D input padding to compensate for valid convolutions, per side (depends on the receptive field)
causal_shift -- asymmetric padding offset when causal convolutions are used (usually 0 or "pad")
augment -- augment the dataset by flipping poses horizontally
kps_left and kps_right -- list of left/right 2D keypoints if flipping is enabled
joints_left and joints_right -- list of left/right 3D joints if flipping is enabled
"""
def __init__(self, cameras, poses_3d, poses_2d, pad=0, causal_shift=0,
augment=False, kps_left=None, kps_right=None, joints_left=None, joints_right=None):
assert poses_3d is None or len(poses_3d) == len(poses_2d)
assert cameras is None or len(cameras) == len(poses_2d)
self.augment = augment
self.kps_left = kps_left
self.kps_right = kps_right
self.joints_left = joints_left
self.joints_right = joints_right
self.pad = pad
self.causal_shift = causal_shift
self.cameras = [] if cameras is None else cameras
self.poses_3d = [] if poses_3d is None else poses_3d
self.poses_2d = poses_2d