Learning recurrent representations
for hierarchical behavior modeling


Abstract
We propose a framework for detecting action patterns from motion sequences and modeling the sensory-motor relationship of animals, using a generative recurrent neural network. The network has a discriminative part (classifying actions) and a generative part (predicting motion), whose recurrent cells are laterally connected, allowing higher levels of the network to represent high level phenomena. We test our framework on two types of data, fruit fly behavior and online handwriting. Our results show that 1) taking advantage of unlabeled sequences, by predicting future motion, significantly improves action detection performance when training labels are scarce, 2) the network learns to represent high level phenomena such as writer identity and fly gender, without supervision, and 3) simulated motion trajectories, generated by treating motion prediction as input to the network, look realistic and may be used to qualitatively evaluate whether the model has learnt generative control rules.

Paper: arXiv
Poster: WiML
Data: coming soon
Code: coming soon



Supplementary videos

Handwriting simulation: Text generated one (dx, dy, z) vector at a time (approximately 20 vectors per character), where z denotes stroke visibility.
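The generation loop can be sketched as follows. This is a hedged illustration only: `model_step`, its toy dynamics, and all parameters are stand-ins for the trained network, which is not shown here.

```python
import numpy as np

# Sketch of autoregressive handwriting generation: at each step the model
# outputs distribution parameters for the next pen offset (dx, dy) and pen
# visibility z, a sample is drawn, and the sample is fed back as input.
# `model_step` is a placeholder, NOT the paper's trained network.
rng = np.random.default_rng(0)

def model_step(prev, h):
    """Placeholder recurrent step; returns toy distribution parameters."""
    h = np.tanh(h + prev)                 # stand-in recurrent update
    mean_dxdy = h[:2]                     # stand-in motion prediction
    p_visible = 1.0 / (1.0 + np.exp(-h[2]))
    return (mean_dxdy, p_visible), h

h = np.zeros(3)
stroke = np.zeros(3)  # (dx, dy, z)
trajectory = []
for _ in range(40):   # ~2 characters at ~20 vectors per character
    (mean_dxdy, p_visible), h = model_step(stroke, h)
    dxdy = rng.normal(mean_dxdy, 0.1)     # sample next pen offset
    z = float(rng.random() < p_visible)   # sample stroke visibility
    stroke = np.array([dxdy[0], dxdy[1], z])
    trajectory.append(stroke)
trajectory = np.stack(trajectory)         # (40, 3) simulated stroke sequence
```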

  


FlyBowl simulation - single agent: The model is initialized with the sequence of motor control and sensory input of a real fly. Once the fly turns red, it becomes a simulated agent. The plot below the arena represents the fly's sensory input; the top half represents how it sees other flies (all but the one circled in white, from which it originated) and the bottom half how it senses the arena boundaries.

  


FlyBowl simulation - multiple agents: One of the videos shows tracks of real fruit flies and the other shows tracks of 20 simulated agents.

     


SynthFly simulation: (left) simulation with our model, (right) simulation with motion regression rather than bin classification. For reference, see the ground truth synthetic fly video here. The plot below the arena represents the agent's sensory input; the top half represents how it sees the obstacle and the bottom half how it senses the arena boundaries. The plot below that shows the probability distribution over wing angles (left: green, right: blue) as predicted by the model, from which we sample during simulation.

     




Supplementary figures

Multimodality: (See Section 3.2) If the distribution of future motion given a state is multimodal, modeling motion prediction deterministically can yield a prediction that belongs to none of the modes observed in training. Here, for example, regression would predict that the fly will go straight. Instead, we bin the motion domain and treat motion prediction as a multi-class classification problem for each dimension, maximizing the probability of the motions observed during training. At inference this yields a probability distribution over possible future motions, which can be sampled from during simulation.
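A minimal sketch of this binning scheme follows; the bin count, value ranges, and the example bimodal distribution are illustrative choices, not the paper's settings.

```python
import numpy as np

# Discretize each motion dimension into bins so prediction becomes
# per-dimension classification; at simulation time, sample a bin from the
# predicted distribution and return its center. N_BINS is illustrative.
N_BINS = 51

def motion_to_bin(value, lo, hi, n_bins=N_BINS):
    """Map a continuous motion value to a bin index (class label)."""
    edges = np.linspace(lo, hi, n_bins + 1)
    return int(np.clip(np.digitize(value, edges) - 1, 0, n_bins - 1))

def sample_motion(probs, lo, hi, rng, n_bins=N_BINS):
    """Sample a continuous motion value from predicted bin probabilities."""
    idx = rng.choice(n_bins, p=probs)
    edges = np.linspace(lo, hi, n_bins + 1)
    centers = 0.5 * (edges[:-1] + edges[1:])
    return centers[idx]

rng = np.random.default_rng(0)
# Example: a bimodal distribution over turning angle (sharp left or right).
# A regression model would average the modes and predict "go straight";
# sampling from the bins picks one mode or the other.
probs = np.zeros(N_BINS)
probs[5] = 0.5   # left-turn mode
probs[45] = 0.5  # right-turn mode
angle = sample_motion(probs, -np.pi, np.pi, rng)
```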

Benefit of diagonal connections: (See Section 5.4) Quantitative measurement of unsupervised character and writer clustering, from a 2-dimensional tSNE mapping of hidden states on IAM-OnDB. The top plots show the mean posterior probability of each character (red) and writer (blue), given clustering at states i = {1,...,6} (counted from bottom left to bottom right in the state visualization below), for our model with (solid) and without (dashed) diagonal connections. The P(S), P(S | a) and P(a | S) distributions corresponding to a specific writer in state 4 of the model with diagonal connections are shown for clarification.
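The clustering measure can be sketched as follows. This is a hedged illustration: `mean_posterior` and its simple counting estimate of P(a | S) are assumptions for exposition, not necessarily the paper's exact computation.

```python
import numpy as np

# Given cluster assignments S of hidden states and attribute labels a
# (e.g. writer identity), estimate P(a | S) from co-occurrence counts and
# report the mean posterior of each sample's true label under its cluster.
def mean_posterior(clusters, labels):
    clusters = np.asarray(clusters)
    labels = np.asarray(labels)
    counts = np.zeros((clusters.max() + 1, labels.max() + 1))
    for s, a in zip(clusters, labels):
        counts[s, a] += 1
    # Normalize rows to get P(a | S); guard against empty clusters.
    totals = np.maximum(counts.sum(axis=1, keepdims=True), 1)
    p_a_given_s = counts / totals
    # Posterior of each sample's own label under its assigned cluster.
    return p_a_given_s[clusters, labels].mean()

# Perfectly separated clusters give a mean posterior of 1.0.
score = mean_posterior([0, 0, 1, 1], [0, 0, 1, 1])
```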







Supplementary training details

In order to efficiently train on long temporal sequences containing sparsely distributed actions, we introduced the following modifications to the training routine:

Maintaining temporal dependency: Recurrent neural networks are in theory able to pass information from the beginning of a sequence to any subsequent time step, but in practice, propagating a loss term backwards over many frames is expensive and results in vanishing gradients. This is generally handled by splitting long sequences into sub-sequences of length s, such that gradients are propagated back by at most s time steps. When sub-sequences are processed sequentially, temporal dependency can be maintained by initializing the hidden states of each sub-sequence with the final state of the preceding one. However, for computational efficiency it is common practice to process samples in batches and in stochastic order, which makes that impossible. One way around this is to set the initial state of each sub-sequence to 0, but doing so breaks the temporal dependency beyond s time steps. Instead, we approximate the true initial state of each sub-sequence by setting it to the latest computed state for that time step (computed at the previous epoch). The approximate state therefore corresponds to a previous version of the model, but as the model converges, so does this approximation. This means that although the loss is only propagated back by s frames, the model can make classifications and predictions based on events that occurred more than s frames ago.
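This stale-state approximation can be sketched as follows; `rnn_forward`, `StateCache`, and the toy dynamics are illustrative stand-ins, not the actual model or training code.

```python
import numpy as np

# Initial hidden states for each sub-sequence are read from a cache holding
# the final state computed for the preceding sub-sequence (possibly at the
# previous epoch, hence "stale"). `rnn_forward` is a toy placeholder.
HIDDEN = 8

def rnn_forward(x, h0):
    """Placeholder recurrent pass: returns per-step hidden states."""
    h = h0
    states = []
    for x_t in x:
        h = np.tanh(0.5 * h + x_t)  # toy update, not the paper's model
        states.append(h)
    return np.stack(states)

class StateCache:
    def __init__(self):
        self.cache = {}  # (seq_id, subseq_idx) -> final hidden state

    def initial_state(self, seq_id, subseq_idx):
        # First sub-sequence, or state not yet computed: fall back to zeros.
        return self.cache.get((seq_id, subseq_idx - 1), np.zeros(HIDDEN))

    def update(self, seq_id, subseq_idx, final_state):
        self.cache[(seq_id, subseq_idx)] = final_state

cache = StateCache()
x = np.ones((4, HIDDEN))  # one sub-sequence of length s = 4
for epoch in range(2):
    # Shown in order for brevity; in practice sub-sequences may be drawn in
    # stochastic batches, so cached states can be one epoch old.
    for subseq_idx in range(3):
        h0 = cache.initial_state("fly_01", subseq_idx)
        states = rnn_forward(x, h0)
        cache.update("fly_01", subseq_idx, states[-1])
```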

Importance sampling: When analyzing the behavior of animals over long durations, it is often the case that behaviors of interest are temporally sparse, so the datasets tend to be extremely imbalanced, with the negative class taking up the majority of frames. To prevent the negative class from dominating the cost, one may train on a subset of the negative samples, or on repeated samples from the rare classes, to artificially balance the data. Our approach is similar, but rather than fixing a sample size for each class and selecting samples randomly, we dynamically update the class and sample probabilities after each epoch, such that classes with low recall are sampled more often, and misclassified frames are more likely to be sampled within each class. This lets us effectively optimize the precision and recall of each class, which cannot be done directly through the cost function, since precision depends on the total number of predicted positives.
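A minimal sketch of this dynamic re-weighting follows; the function names and the exact update rules (inverse recall, a fixed boost for misclassified frames) are illustrative assumptions, not the paper's precise scheme.

```python
import numpy as np

# After each epoch: classes with low recall get higher sampling probability,
# and within a class, currently misclassified frames are up-weighted.
def update_class_probs(recalls, eps=1e-3):
    """Sample classes inversely proportional to their current recall."""
    w = 1.0 / (np.asarray(recalls, dtype=float) + eps)
    return w / w.sum()

def update_frame_probs(correct, boost=4.0):
    """Within a class, up-weight frames the model misclassifies."""
    w = np.where(np.asarray(correct), 1.0, boost)
    return w / w.sum()

# Example: three classes with recalls 0.9, 0.5, 0.1 after an epoch; the
# rare, poorly-recalled class is now sampled most often.
class_probs = update_class_probs([0.9, 0.5, 0.1])
# Example: five frames of one class, frames 1 and 3 misclassified.
frame_probs = update_frame_probs([True, False, True, False, True])
```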