Wednesday, March 7, 2018

Recurrent Neural Network

Introduction


For a normal FNN, the elements in the input vector $\mathbf{x}$ have no intrinsic order.  Therefore, if one uses an FNN to solve computer vision problems, it cannot take advantage of the texture patterns formed by neighboring pixels, as the "neighboring" relationship is lost.  A CNN instead takes advantage of the spatial relationships within a neighborhood and uses them to form features; therefore, it is superior for computer vision problems.  Similarly, when the input is sequence data, which could be a temporal time series or positional data such as a text stream, we do not want to scramble the order of the data points; instead, we should try to recognize sequence patterns, which requires a different architecture than FNN.  You may ask why we do not use a CNN to recognize sequence patterns.  CNN has two limitations.  First, a CNN can only look at patterns of a fixed size (determined by the size of its filters), but sometimes we are interested in long-range patterns.  For instance, to predict weather, we may need to look for patterns not only within the past few days, but over the past few weeks, months, or even years; such long-range patterns cannot be captured well by local CNN filters.  Second, CNN patterns are translationally invariant, which means they are not context sensitive.  Patterns embedded in sequence data are often interpreted differently depending on their context.  For example, the word "Chinese" within a document could mean either Chinese people or the Chinese language; its exact meaning depends on the context.  CNN does not capture context.

So far we have been dealing with inputs of fixed size.  This is true for all the traditional machine learning methods we discussed previously (such as SVM and random forest), as well as for CNN.  Therefore, to classify ImageNet pictures, we need to make sure they all have the same dimensions.  In reality, many inputs are of variable length.  For example, to transcribe speech, the duration of the voice waveform varies depending on the content and the rate of speech.  When Xbox Kinect recognizes a kick action, the kick video sequence comes with a varying number of frames, as there are quick kicks and slow kicks.  To read an email and then classify it as spam or ham, we must handle emails of varying length.  Traditionally, for the Xbox Kinect problem, we apply Bayesian graphical models to construct generative models that explain the video sequence, which is very complicated and relies heavily on oversimplified assumptions.  Therefore, we need a new NN architecture to handle variable input length, as well as mechanisms to discover patterns of varying time spans.

When input elements are fed into our machine one piece at a time, the machine must have a memory mechanism.  If the machine makes a decision based only on the current input token, the word "Chinese" will always be associated with either the language or the people, preventing the true meaning from being assigned based on its context.  In translation, the output sequence must also be of variable length.  The architecture invented to overcome the challenge of variable-length input/output sequences is called the recurrent neural network (RNN).  In this blog, we use simple examples to convince you that an RNN is theoretically capable of capturing patterns in a sequence of arbitrary length.  RNN has applications in language translation, video classification, photo captioning, trend prediction (arguably stock price prediction, although I am not sure that can work even in theory; see Table 4 and the conclusion in [1]), etc.


Architecture


The basic unit of an RNN is a cell, as depicted in Figure 1.  A cell has an internal state, which is a vector $\mathbf{h}$, so a cell can store multiple internal elements.  At time $t$, it reads in the input $\mathbf{x}_{t}$ and mixes it with its current hidden state $\mathbf{h}_{t-1}$ to transition into a new state $\mathbf{h}_{t}$; $\mathbf{h}_{t}$ can then be used to produce an output vector $\mathbf{y}_{t}$.  If $\mathbf{h}_{0}$ is initialized as a zero vector, the cell reads in the inputs $\mathbf{x}$ one at a time and produces a sequence of hidden states $[\mathbf{h}_1,\mathbf{h}_2, \ldots, \mathbf{h}_n]$.  For our purpose, we can assume $\mathbf{y} = \mathbf{h}$ for now.

Figure 1. An RNN cell unrolled through time [2].
The update rule can be written as follows:

$$\begin{equation}
\mathbf{h}_{t} = \phi(\mathbf{h}_{t-1}\mathbf{W}_{h}+ \mathbf{x}_{t}\mathbf{W}_x + \mathbf{b}),
\end{equation}$$

where $\mathbf{W}_x$, $\mathbf{W}_h$ and $\mathbf{b}$ are cell parameters to be determined by optimization, and $\phi$ is an activation function.  $\mathbf{h}_{t}$ is the hidden state of the cell, i.e., it is the impression of all the data the RNN cell has viewed so far: $[\mathbf{x}_{1}, \mathbf{x}_{2}, \ldots, \mathbf{x}_{t}]$.  After we watch a movie, the snapshot of our brain is $\mathbf{h}_t$, which contains the effects of the whole movie.  This vector can then be used as the input to classify the movie as a good one or a bad one.
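To make Equation 1 concrete, here is a minimal NumPy sketch of a SimpleRNN forward pass. It assumes $\phi = \tanh$ and keeps the row-vector convention of Equation 1 ($\mathbf{h}_{t-1}\mathbf{W}_h$); the function name `simple_rnn`, the sizes, and the random weights are illustrative only and not any library's API. The same parameters are reused at every step, which is what lets one cell consume sequences of any length.

```python
import numpy as np

def simple_rnn(xs, W_h, W_x, b, phi=np.tanh):
    """Apply h_t = phi(h_{t-1} W_h + x_t W_x + b) over a sequence.

    xs: (T, n_in) array, one row per time step.
    W_h: (n_h, n_h), W_x: (n_in, n_h), b: (n_h,).
    Returns all hidden states h_1..h_T as a (T, n_h) array; h_0 is zero.
    """
    h = np.zeros(W_h.shape[0])
    hs = []
    for x_t in xs:
        h = phi(h @ W_h + x_t @ W_x + b)  # row-vector convention of Equation 1
        hs.append(h)
    return np.array(hs)

rng = np.random.default_rng(0)
T, n_in, n_h = 5, 2, 4                    # hypothetical sizes
xs = rng.normal(size=(T, n_in))
W_h = rng.normal(scale=0.5, size=(n_h, n_h))
W_x = rng.normal(scale=0.5, size=(n_in, n_h))
b = np.zeros(n_h)

hs = simple_rnn(xs, W_h, W_x, b)
h_last = hs[-1]                           # the "impression" of the whole sequence
print(hs.shape)                           # (5, 4)
```

Running it on a prefix such as `xs[:3]` reproduces the first three hidden states, since $\mathbf{h}_t$ depends only on the inputs up to time $t$.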

Now let us try to understand why this simple recurrent architecture works.  I have not read any "proof" as elegant as the one we discussed for FNN in a previous blog post, which made it rather convincing that an FNN can model an arbitrary function.  For RNN, the best I can do is to present some special cases to convince you that RNN is indeed quite capable.

Memory


RNN cells can memorize data.  Imagine the input is the time series of daily stock prices, and the output $y$ is the most recent three-day average, or any prediction of tomorrow's price based on the past three days.  In the prediction at time $t$, we only receive a one-day input $\mathbf{x}_{t}$, so there needs to be a way for the RNN to retain $\mathbf{x}_{t-1}$ and $\mathbf{x}_{t-2}$, i.e., it should have a short-term memory in order to use three pieces of data to generate the output $y_t$.

Let us look at the special case where the stock price $x_t$ is a scalar.  The matrices below act on column vectors, so we write the update with the weight matrices multiplying from the left:
$$\begin{equation} \begin{aligned} \mathbf{h}_0 &= {[ 0, \; 0, \; 0]}^\intercal, \\ \mathbf{h}_t &= \phi(\mathbf{W}_{h}\mathbf{h}_{t-1}+\mathbf{W}_x x_{t} +\mathbf{b}), \\ \phi(x) &= x, \\ \mathbf{W}_h &= \begin{bmatrix} 0 & 1 & 0 \\ 0 & 0 & 1 \\ 0 & 0 & 0 \end{bmatrix}, \\ \mathbf{W}_x &= {[ 0, \; 0, \; 1 ]}^\intercal, \\ \mathbf{b} &= {[ 0, \; 0, \; 0 ]}^\intercal. \end{aligned} \end{equation}$$
Our hidden state contains three elements, reserved to store the stock prices of the past three days.  Following the time sequence, we obtain the following $\mathbf{h}$ sequence from Equation 2:

$$\begin{equation} \mathbf{h} = \left\{ \mathbf{h}_0=\begin{bmatrix} 0 \\ 0 \\ 0 \end{bmatrix}, \mathbf{h}_1=\begin{bmatrix} 0 \\ 0 \\ x_1 \end{bmatrix}, \mathbf{h}_2=\begin{bmatrix} 0 \\ x_1 \\ x_2 \end{bmatrix}, \mathbf{h}_3=\begin{bmatrix} x_1 \\ x_2 \\ x_3 \end{bmatrix}, \\ \mathbf{h}_4=\begin{bmatrix} x_2 \\ x_3 \\ x_4 \end{bmatrix}, \ldots, \mathbf{h}_t=\begin{bmatrix} x_{t-2} \\ x_{t-1} \\ x_t \end{bmatrix}, \ldots \right\}. \end{equation}$$
In this way, the RNN retains a short-term memory.  In real applications we could use the stock prices of the past three days as the feature vector input at each time point; here we choose a scalar input to illustrate the memory mechanism.
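The shift-register behavior can be checked numerically. The sketch below assumes NumPy and multiplies the column vector $\mathbf{h}_{t-1}$ from the left, i.e. $\mathbf{h}_t = \mathbf{W}_h\mathbf{h}_{t-1} + \mathbf{W}_x x_t + \mathbf{b}$, which is the ordering that produces the shift in Equation 3 with $\mathbf{W}_h$ as printed. The prices are made up, and `memory_cell` is a hypothetical helper name.

```python
import numpy as np

# Memory-cell weights, applied to column vectors with identity activation.
W_h = np.array([[0, 1, 0],
                [0, 0, 1],
                [0, 0, 0]], dtype=float)   # shifts every stored price up by one slot
W_x = np.array([0, 0, 1], dtype=float)     # writes the newest price into the last slot
b = np.zeros(3)

def memory_cell(prices):
    """Return the hidden states h_1..h_T for a scalar price series."""
    h = np.zeros(3)                        # h_0
    states = []
    for x_t in prices:
        h = W_h @ h + W_x * x_t + b        # phi(x) = x
        states.append(h)
    return np.array(states)

prices = [10.0, 11.0, 12.5, 12.0, 13.0]    # made-up daily prices
states = memory_cell(prices)
print(states[-1])          # the last three prices: 12.5, 12.0, 13.0
print(states[-1].mean())   # three-day average read directly from h_t
```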

Running Average and Derivative


Let us consider another special case, where the hidden state has three elements, reserved to store the running average $h_{\Sigma}$, the running derivative $h_{\nabla}$, and a memory of the previous input $h_{mem}$, respectively:

$$\begin{equation} \begin{aligned} \mathbf{h} &= {[ h_{\Sigma}, \; h_{\nabla}, \; h_{mem} ]}^\intercal, \\
\mathbf{h}_0 &= {[ 0, \; 0, \; 0]}^\intercal, \\ \mathbf{h}_t &= \phi(\mathbf{W}_{h}\mathbf{h}_{t-1}+\mathbf{W}_x x_{t} +\mathbf{b}), \\ \phi(x) &= x, \\ \mathbf{W}_h &= \begin{bmatrix} \beta_1 & 0 & 0 \\ 0 & \beta_2 & -(1-\beta_2) \\ 0 & 0 & 0 \end{bmatrix}, \\ \mathbf{W}_x &= {[ 1-\beta_1, \; 1-\beta_2, \; 1 ]}^\intercal, \\ \mathbf{b} &= {[ 0, \; 0, \; 0 ]}^\intercal. \end{aligned} \end{equation}$$
Since $h_{mem,t-1} = x_{t-1}$, the recurrence simplifies to:
$$\begin{equation}
\begin{bmatrix} h_{\Sigma,t} \\ h_{\nabla,t} \\ h_{mem,t} \end{bmatrix} =
\begin{bmatrix} \beta_1 h_{\Sigma,t-1} + (1-\beta_1)x_t \\
\beta_2 h_{\nabla,t-1} + (1-\beta_2) (x_t - x_{t-1}) \\ x_t \end{bmatrix}.
\end{equation}$$

Following the time sequence, we obtain the following $\mathbf{h}$ sequence:

$$\begin{equation} \mathbf{h} = \left\{ \mathbf{h}_0=\begin{bmatrix} 0 \\ 0 \\ 0 \end{bmatrix}, \mathbf{h}_1=\begin{bmatrix} (1-\beta_1)x_1 \\ (1-\beta_2)x_1 \\ x_1 \end{bmatrix}, \\
\mathbf{h}_2=\begin{bmatrix} (1-\beta_1)\left(\beta_1  x_1+  x_2 \right) \\ (1-\beta_2)\left( \beta_2(x_1-0)+ (x_2-x_1)\right) \\ x_2 \end{bmatrix}, \ldots, \\
\mathbf{h}_t=\begin{bmatrix} (1-\beta_1)\left(\beta_1^{t-1}x_1+\beta_1^{t-2}x_2+\cdots + x_t\right) \\  (1-\beta_2)\left(\beta_2^{t-1}(x_1-0)+\beta_2^{t-2}(x_2-x_1)+ \cdots + (x_t-x_{t-1})\right)  \\ x_t \end{bmatrix}\right\}. \end{equation}$$
Now we can see how the three elements of the hidden state work.  $h_{mem}$ simply memorizes the current input, so that $x_{t-1}$ is available alongside $x_t$ at time $t$.  $h_{\Sigma}$ computes an exponentially weighted running average, similar to the one used in momentum optimization: the current average is first decayed by a factor $\beta_1$, and the new input, weighted by $(1-\beta_1)$, is added to update it.  $h_{\nabla}$ computes a running first-order derivative in the same way: the current derivative is decayed by a factor $\beta_2$, and the new derivative estimate $(x_{t}-x_{t-1})$, weighted by $(1-\beta_2)$, is added to update the derivative memory.  We assume $\beta_1$ and $\beta_2$ are decay rates within $[0, 1]$.  When $\beta$ approaches one, the cell sees data far into the past, and a new input has little weight to change the average or derivative.  When $\beta$ approaches zero, the cell only sees the effect of very recent inputs.  If we create many RNN cells with different $\beta$ values, we obtain a range of cells that capture the temporal average or derivative over time windows of varying length.  If we stack such memory cells, say one $h_{\nabla}^{(2)}$ on top of the output of another $h_{\nabla}^{(1)}$, we can get higher-order derivatives with a multi-layer RNN architecture (deep RNN).
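A small simulation can confirm that the recurrence matches the closed form of Equation 6 and show how $\beta$ sets the length of the time window. This NumPy sketch assumes column vectors with the weight matrix multiplying from the left, as required for Equation 5 to hold; the helper names and input sequences are hypothetical.

```python
import numpy as np

def avg_deriv_cell(xs, beta1, beta2):
    """Run the running-average/derivative cell over a scalar sequence.

    Column vectors, identity activation; returns rows [h_sigma, h_nabla, h_mem]."""
    W_h = np.array([[beta1, 0.0, 0.0],
                    [0.0, beta2, -(1 - beta2)],
                    [0.0, 0.0, 0.0]])
    W_x = np.array([1 - beta1, 1 - beta2, 1.0])
    h = np.zeros(3)                               # h_0
    states = []
    for x_t in xs:
        h = W_h @ h + W_x * x_t
        states.append(h)
    return np.array(states)

def closed_form(xs, beta1, beta2):
    """The h_t column of Equation 6 for the last step t = len(xs)."""
    xs = np.asarray(xs, dtype=float)
    t = len(xs)
    powers1 = beta1 ** np.arange(t - 1, -1, -1)   # beta1^{t-1}, ..., beta1^0
    powers2 = beta2 ** np.arange(t - 1, -1, -1)
    diffs = np.diff(np.concatenate([[0.0], xs]))  # x_1 - 0, x_2 - x_1, ...
    return np.array([(1 - beta1) * powers1 @ xs,
                     (1 - beta2) * powers2 @ diffs,
                     xs[-1]])

xs = [1.0, 2.0, 4.0, 7.0, 11.0]                   # made-up input sequence
states = avg_deriv_cell(xs, beta1=0.9, beta2=0.5)
print(np.allclose(states[-1], closed_form(xs, 0.9, 0.5)))   # True

step_input = [0.0] * 10 + [1.0] * 3              # input jumps from 0 to 1 three steps ago
for beta in (0.1, 0.9):
    print(beta, avg_deriv_cell(step_input, beta, 0.5)[-1, 0])
# beta = 0.1 gives about 0.999 (recent inputs dominate); beta = 0.9 gives about 0.271 (long memory)
```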

The thought experiments above are far from a proof, but they nevertheless offer us some confidence that, with enough cells and enough layers, an RNN is able to capture many temporal patterns.  In fact, it has been shown that an RNN can simulate an arbitrary computer program [3]:

Thus RNNs can be thought of as essentially describing computer programs.  In fact, it has been shown that RNNs are turing complete (for more information refer to the article: On the Computational Power of Neural Nets, by H. T. Siegelmann and E. D. Sontag, proceeding of the fifth annual workshop on computational learning theory, ACM, 1992.) in the sense that given the proper weights, they can simulate arbitrary programs.

When the activations of the hidden state encode different temporal patterns (just as feature neurons do in a CNN), they can certainly serve as powerful features for downstream predictions.  If our thought experiments are representative, the memory exercise tells us that an RNN cell needs at least $m$ hidden elements if we want to predict the future based on the past $m$ time points.  Therefore, RNN cells need reasonable complexity to be useful.  The second exercise tells us that the activities captured by a cell contain terms that scale as $\beta^t$.  If, during optimization, $\beta$ becomes greater than one, these terms explode and make optimization impossible, as early steps make an exponentially growing contribution.  If $\beta$ becomes too small, the early steps make virtually no contribution; this is called the vanishing gradient problem, and it prevents the RNN from looking far into the past.  Therefore, this RNN architecture is actually very hard to train.  It is also inefficient when we need a large number of memory elements, for example, to use stock prices over the past few years to predict tomorrow's price.  This classical RNN architecture is now called SimpleRNN; in practice it has been replaced by more efficient alternatives, such as LSTM (long short-term memory), introduced later.
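The exploding/vanishing argument can be seen with the scalar running-average cell $h_t = \beta h_{t-1} + (1-\beta)x_t$, for which $\partial h_T/\partial h_0 = \beta^T$. The plain-Python sketch below estimates that derivative by finite differences; it illustrates only this simplified linear model, not the full gradient of a SimpleRNN with $\tanh$ and matrix weights, and the helper names are hypothetical.

```python
def run_cell(beta, h0, xs):
    """Scalar linear cell from the second exercise: h_t = beta*h_{t-1} + (1-beta)*x_t."""
    h = h0
    for x_t in xs:
        h = beta * h + (1 - beta) * x_t
    return h

def grad_wrt_h0(beta, T, eps=1e-6):
    """Finite-difference estimate of d h_T / d h_0, with all inputs set to zero."""
    xs = [0.0] * T
    return (run_cell(beta, eps, xs) - run_cell(beta, 0.0, xs)) / eps

for beta in (0.5, 0.9, 1.0, 1.1):
    g = grad_wrt_h0(beta, 100)
    print(f"beta = {beta}: d h_100 / d h_0 = {g:.3e} (beta**100 = {beta ** 100:.3e})")
```

Over 100 steps the early state's influence collapses toward zero for $\beta < 1$ and grows into the thousands for $\beta = 1.1$, which is the training difficulty described above.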


RNN Applications


RNNs come in a few variants that make them suitable for a number of very interesting use cases.

Case A, each input produces an output.  In stock price prediction, for example, each new input $x_t$ (today's price) is combined with the data seen so far (encoded in $\mathbf{h}_{t-1}$) to predict tomorrow's price.  We can simply ignore the first few outputs, as they are based on very few inputs, and start using the predictions only after the RNN has seen enough historical data points.

Case B, all outputs are ignored except the last one.  An RNN reads a movie review from beginning to end; its last hidden state $\mathbf{h}_{last}$ encodes the content of the whole review, so it can be used as the feature representation of the review and passed to an FNN that classifies the review as positive or negative, or regresses it into a movie rating.

Case C, there is only one input and the RNN generates a sequence of outputs.  Given a photo $\mathbf{x}_0$, the RNN encodes the photo into a hidden state, and this hidden state then recurrently outputs a sequence of words to be used as the caption.  This sounds like magic, but it has actually been shown to work reasonably well; please watch the nice TED talk by Fei-Fei Li [4], if you have not.

Case D, the RNN takes an input sequence and then outputs a new sequence.  This is what is behind language translation, where an English sentence is fed to an RNN whose hidden state at the end encodes the meaning of the sentence; the RNN then spits out a new sentence in Chinese.  Case D is actually the combination of Case B and Case C.

Figure 2. Four different RNN use cases [5].
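Case A and Case B differ only in which hidden states are read out. The NumPy sketch below assumes one SimpleRNN (Equation 1 with $\tanh$) with random, untrained weights and a hypothetical linear readout `w_out`: Case A maps every hidden state to a prediction, while Case B feeds only the last hidden state to a logistic classifier. Cases C and D would additionally feed generated outputs back in as inputs, which is omitted here.

```python
import numpy as np

rng = np.random.default_rng(42)
n_in, n_h = 3, 8                       # hypothetical input and hidden sizes

# One set of SimpleRNN parameters (Equation 1 with tanh), random and untrained
W_x = rng.normal(scale=0.3, size=(n_in, n_h))
W_h = rng.normal(scale=0.3, size=(n_h, n_h))
b = np.zeros(n_h)
w_out = rng.normal(size=n_h)           # hypothetical linear readout

def run_rnn(xs):
    """Return the hidden states h_1..h_T for inputs xs of shape (T, n_in)."""
    h = np.zeros(n_h)
    hs = []
    for x_t in xs:
        h = np.tanh(h @ W_h + x_t @ W_x + b)
        hs.append(h)
    return np.array(hs)

def case_a(xs):
    """Case A (many-to-many): one output per input step."""
    return run_rnn(xs) @ w_out

def case_b(xs):
    """Case B (many-to-one): only h_last feeds a logistic classifier."""
    h_last = run_rnn(xs)[-1]
    return 1.0 / (1.0 + np.exp(-(h_last @ w_out)))

short_seq = rng.normal(size=(4, n_in))
long_seq = rng.normal(size=(11, n_in))
print(case_a(short_seq).shape, case_a(long_seq).shape)  # (4,) (11,)
print(case_b(short_seq), case_b(long_seq))              # one score in (0, 1) each
```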


LSTM


We mentioned above that the SimpleRNN architecture has many practical issues; in particular, it cannot handle many time steps.  With many time steps, the network suffers from either the vanishing or the exploding gradient problem, so the RNN cannot retain longer-term memory.  LSTM addresses these challenges by enabling longer short-term memories, hence the name Long Short-Term Memory.

Figure 3. Anatomy of an LSTM cell [6].
Figure 4. The update rules of LSTM [7].
The LSTM cell still feels somewhat like magic to me, so I am not very confident about what I write next.
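Since the update rules of Figure 4 are only available as an image, here is a sketch of one step of the standard LSTM cell (without peephole connections) in NumPy, using column vectors. The gate names follow the text ($\mathbf{i}$, $\mathbf{f}$, $\mathbf{o}$, $\mathbf{g}$); the dictionary layout of the parameters and the name `lstm_step` are our own illustrative choices, and the parameterization in [7] may differ in details such as how the weights are stacked.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, U, b):
    """One step of a standard LSTM cell (no peepholes), column-vector convention.

    W[k]: (n_c, n_in) input weights, U[k]: (n_c, n_c) recurrent weights,
    b[k]: (n_c,) biases, for k in 'i' (input), 'f' (forget), 'o' (output), 'g' (new data).
    """
    i = sigmoid(W['i'] @ x_t + U['i'] @ h_prev + b['i'])   # input gate
    f = sigmoid(W['f'] @ x_t + U['f'] @ h_prev + b['f'])   # forget gate
    o = sigmoid(W['o'] @ x_t + U['o'] @ h_prev + b['o'])   # output gate
    g = np.tanh(W['g'] @ x_t + U['g'] @ h_prev + b['g'])   # candidate new data
    c = f * c_prev + i * g        # keep part of the old memory, add gated new data
    h = o * np.tanh(c)            # output a gated view of the memory
    return h, c

n_in, n_c = 4, 2
W = {k: np.zeros((n_c, n_in)) for k in 'ifog'}
U = {k: np.zeros((n_c, n_c)) for k in 'ifog'}
b = {k: np.zeros(n_c) for k in 'ifog'}
h, c = lstm_step(np.ones(n_in), np.zeros(n_c), np.array([2.0, -4.0]), W, U, b)
print(c)   # every gate is sigmoid(0) = 0.5 and g = tanh(0) = 0, so c = 0.5 * c_prev
```

With all weights zero the memory is simply halved. The toy example that follows keeps the same cell structure but swaps the activations for simpler ones, so the arithmetic can be done by hand.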

The example I am using is inspired by a tool called LSTMVis [8].  Let us consider an LSTM cell parsing the following string, where each numeric token sits at a depth defined by the enclosing parentheses (highlighted in three colors, one per depth), and the ultimate goal of the LSTM is to sum all numbers at depth two and beyond (the blue and green numbers).

( 4.5 ( 2.5 ( 3.5 5.5 ) 3.2 ( 1.0 ) ) 2.8 6.8 9.0 ) )

Each left parenthesis increases the depth by one, and each right parenthesis decreases it by one.  Our cell vector $\mathbf{c}$ naturally contains two elements, one for counting the depth and another for storing the sum, i.e., ${[ c_{depth}, c_{sum} ]}^\intercal$.  Our input vector $\mathbf{x}$ has four elements, ${[ x_l, x_r, x_n, x_v ]}^\intercal$, where $x_l$, $x_r$ and $x_n$ form a one-hot encoding indicating whether the input is a left parenthesis, a right parenthesis, or a number, respectively.  When the input is a number, $x_v$ stores its numerical value.  Therefore, the string "(2.5 4.2)" is represented as four inputs:

$$\begin{equation} \begin{bmatrix}1\\0\\0\\0\end{bmatrix},  \begin{bmatrix}0\\0\\1\\2.5\end{bmatrix}, \begin{bmatrix}0\\0\\1\\4.2\end{bmatrix}, \begin{bmatrix}0\\1\\0\\0\end{bmatrix}. \end{equation}$$

Here we only look at the special case where the output gate is always open, i.e., $\mathbf{o} = {[1, 1]}^\intercal$, and the output state is $\mathbf{h} = \mathbf{c}$.

Let us first consider the case where there is no feedback loop, i.e., $\mathbf{h}_{t-1}$ is not fed back as an input at time $t$.  The input gate is determined by:
$$\begin{equation} \mathbf{i} = \rm{sign} \left(\begin{bmatrix} 1 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \end{bmatrix}\cdot \mathbf{x}\right). \end{equation}$$
Applying Equation 8 to the inputs in Equation 7, we obtain a sequence of input gates:

$$\begin{equation} \begin{bmatrix}1\\0\end{bmatrix},  \begin{bmatrix}0\\1\end{bmatrix}, \begin{bmatrix}0\\1\end{bmatrix}, \begin{bmatrix}1\\0\end{bmatrix}. \end{equation}$$

The input gate is ${[1, 0]}^\intercal$ when the cell sees either "(" or ")"; therefore, the input gate is open for $c_{depth}$, so that it can be updated, and closed for $c_{sum}$, as no update on that element is needed.  Here we replace the logistic gate activation $\sigma(\cdot)$ by a simple step function, which we denote $\rm{sign}(\cdot)$:
$$\begin{equation} \rm{sign}(x) = \begin{cases} 1, \; \rm{if} \; x \gt 0 \\ 0,  \; \rm{if} \;x \le 0 \end{cases} \end{equation}$$
On the other hand, when the cell sees a number, the input gate is ${[0, 1]}^\intercal$, which prepares $c_{sum}$ for an update and protects $c_{depth}$ from being modified.

The forget gate can be kept open all the time, i.e., ${[1, 1]}^\intercal$, for our simple example.

The new data is calculated with the $\tanh(\cdot)$ activation replaced by the identity, so that numeric values pass through unchanged:
$$\begin{equation} \mathbf{g} = \begin{bmatrix} 1 & -1 & 0 & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix}\cdot \mathbf{x}. \end{equation}$$
Applying Equation 11 to the inputs in Equation 7, we get the sequence:
$$\begin{equation} \begin{bmatrix}1\\0\end{bmatrix},  \begin{bmatrix}0\\2.5\end{bmatrix}, \begin{bmatrix}0\\4.2\end{bmatrix}, \begin{bmatrix}-1\\0\end{bmatrix}. \end{equation}$$

This means $\mathbf{g}$ is ${[1, 0]}^\intercal$ for "(", ${[-1, 0]}^\intercal$ for ")", and ${[0, x_v]}^\intercal$ for number inputs.

Therefore, the final update rules for $\mathbf{c}$ are the following (recall from Figure 4 that the fully open forget gate simply passes $\mathbf{c}_{t-1}$ through):

$$\begin{equation} \begin{aligned}
c_{depth,t} &= \begin{cases} c_{depth,t-1} + 1, \rm{if \; input \; is \; (} \\
c_{depth,t-1} - 1, \rm{if \; input \; is \; )} \\
c_{depth,t-1}, \rm{if \; input \; is \; numeric}\end{cases}, \\[2ex]
c_{sum, t} &= \begin{cases} c_{sum, t-1}, \rm{if \; input \; is \; ( \; or )} \\
c_{sum, t-1}+x_v, \rm{if \; input \; is \; numeric} \end{cases}. \end{aligned} \end{equation}$$
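The no-feedback cell can be simulated directly. This NumPy sketch assumes the step function of Equation 10 for the gates and the identity for $\mathbf{g}$ (the choice that reproduces the values in Equation 12), with the forget gate fixed at ${[1, 1]}^\intercal$; `encode` and `run_no_feedback` are hypothetical helper names.

```python
import numpy as np

def step(z):
    """The 'sign' of Equation 10: 1 where z > 0, otherwise 0."""
    return (np.asarray(z) > 0).astype(float)

def encode(tokens):
    """Encode tokens as [x_l, x_r, x_n, x_v] rows, as in Equation 7."""
    rows = []
    for tok in tokens:
        if tok == '(':
            rows.append([1, 0, 0, 0])
        elif tok == ')':
            rows.append([0, 1, 0, 0])
        else:
            rows.append([0, 0, 1, float(tok)])
    return np.array(rows, dtype=float)

W_i = np.array([[1, 1, 0, 0],
                [0, 0, 1, 0]], dtype=float)   # input gate weights (Equation 8)
W_g = np.array([[1, -1, 0, 0],
                [0, 0, 0, 1]], dtype=float)   # new-data weights, identity activation

def run_no_feedback(tokens):
    c = np.zeros(2)                  # [c_depth, c_sum]
    f = np.ones(2)                   # forget gate kept open
    for x in encode(tokens):
        i = step(W_i @ x)            # input gate
        g = W_g @ x                  # new data
        c = f * c + i * g            # LSTM cell update
    return c

tokens = "( 2.5 4.2 )".split()
print([step(W_i @ x).tolist() for x in encode(tokens)])  # the gates of Equation 9
print(run_no_feedback(tokens))       # depth back to 0, every number summed
```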

So far, we are summing all numbers read from the input, regardless of their depths, but our ultimate goal is to sum only the numbers whose depth is greater than one.  Since the input gate for $c_{sum}$ should now open only when $c_{depth} \gt 1$, we feed our output $\mathbf{h}_{t-1}$ (equal to $\mathbf{c}_{t-1}$) back to the next time step to influence the input gate calculation.  Our new input gate formula is updated from Equation 8 into the following:

$$\begin{equation} \mathbf{i} = \rm{sign} \left(\begin{bmatrix} 1 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \end{bmatrix}\cdot \mathbf{x} + \begin{bmatrix} 0 & 0 \\ 1 & 0 \end{bmatrix}\begin{bmatrix}c_{depth} \\ c_{sum}\end{bmatrix} - \begin{bmatrix} 0 \\ 2\end{bmatrix} \right). \end{equation}$$

The input gate for $c_{depth}$ behaves the same as before, as the two added terms have no effect on the top element.  When the input is a number, the input gate for $c_{sum}$ is controlled by the value of $c_{depth}-1$, i.e., it only opens for an update when $c_{depth} \gt 1$, exactly what we want.  (For a parenthesis, the gate for $c_{sum}$ may open when $c_{depth} \gt 2$, but this is harmless because the corresponding element of $\mathbf{g}$ is zero.)  Feeding the output back into the next step brings in context (memory), so that the same input $\mathbf{x}_t$ can produce different gate values, and our cell is updated in a context-sensitive manner.
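Adding the feedback term of Equation 14 turns the same loop into a depth-aware summer. The NumPy sketch below assumes step-function gates (Equation 10), an identity activation for $\mathbf{g}$, and open forget and output gates, so that $\mathbf{h}_{t-1} = \mathbf{c}_{t-1}$; `depth_sum` is a hypothetical helper, and the test string is a made-up, balanced example rather than the string above.

```python
import numpy as np

def step(z):
    return (np.asarray(z) > 0).astype(float)   # 'sign' of Equation 10

def encode(tok):
    """[x_l, x_r, x_n, x_v] for a single token."""
    if tok == '(':
        return np.array([1.0, 0, 0, 0])
    if tok == ')':
        return np.array([0, 1.0, 0, 0])
    return np.array([0, 0, 1.0, float(tok)])

W_i = np.array([[1, 1, 0, 0], [0, 0, 1, 0]], dtype=float)   # input part of the gate
U_i = np.array([[0, 0], [1, 0]], dtype=float)               # feedback from h_{t-1} = c_{t-1}
b_i = np.array([0.0, -2.0])
W_g = np.array([[1, -1, 0, 0], [0, 0, 0, 1]], dtype=float)  # identity activation for g

def depth_sum(text):
    c = np.zeros(2)                        # [c_depth, c_sum]
    for tok in text.split():
        x = encode(tok)
        i = step(W_i @ x + U_i @ c + b_i)  # Equation 14: context-dependent input gate
        c = c + i * (W_g @ x)              # forget gate stays open
    return c

print(depth_sum("( 1 ( 2 ( 3 ) ) 4 )"))    # depth ends at 0, sum = 2 + 3 = 5
```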

With this capability of longer-term memory, an LSTM network can deal with very long sequences.  Why can LSTM avoid the vanishing/exploding gradient problem that troubles the SimpleRNN?  A good explanation is offered in a blog post [9].  The main nested product involved in back-propagation is $\frac{\partial \mathbf{c}_{t}}{\partial \mathbf{c}_{t-1}}$.  Based on Figure 4, if we ignore the dependence of the gates and of $\mathbf{g}$ on $\mathbf{h}_{t-1}$, the addition term containing $\mathbf{i}$ and $\mathbf{g}$ makes no contribution, and the derivative is basically $\mathbf{f}$, the forget gate.  The values of $\mathbf{f}$ (outputs of a logistic function) are never greater than one; therefore, there is no exploding problem.  When memory needs to be retained, the $\mathbf{f}$ values should be close to one, so the vanishing gradient problem is also diminished.  The gradient is cut off when old memory needs to be wiped out ($\mathbf{f}$ is set to zero).
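Under this simplification (gates treated as constants, so that $\partial \mathbf{c}_t/\partial \mathbf{c}_{t-1} = \mathbf{f}_t$), the gradient along the cell-state path over $T$ steps is the product of the forget-gate values. The NumPy sketch below checks this by finite differences for a scalar cell; the forget-gate values and the gated new data $\mathbf{i}_t\mathbf{g}_t$ are made up, and the helper names are hypothetical.

```python
import numpy as np

def run_cell(c0, forget, new_data):
    """Cell-state recurrence c_t = f_t * c_{t-1} + (i_t * g_t), with the
    gated new data i_t * g_t supplied directly as a fixed sequence."""
    c = c0
    for f_t, ig_t in zip(forget, new_data):
        c = f_t * c + ig_t
    return c

def grad_wrt_c0(forget, new_data, eps=1e-6):
    """Finite-difference estimate of d c_T / d c_0."""
    return (run_cell(eps, forget, new_data) - run_cell(0.0, forget, new_data)) / eps

T = 100
rng = np.random.default_rng(0)
new_data = rng.normal(size=T)          # arbitrary i_t * g_t values
keep = np.full(T, 0.99)                # forget gate near one: memory retained
reset = keep.copy()
reset[50] = 0.0                        # forget gate closes once: memory wiped

print(grad_wrt_c0(keep, new_data), np.prod(keep))    # both about 0.366
print(grad_wrt_c0(reset, new_data), np.prod(reset))  # both 0: gradient cut at the reset
```

The gradient depends only on the forget gates, not on the new data, and it is bounded by one as long as every $f_t \le 1$.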

Through the few thought experiments discussed here, hopefully we now have more faith in recurrent networks.  As LSTM is so successful, when people say they use an RNN, they usually mean an LSTM by default, and the traditional RNN is called SimpleRNN instead.  Another RNN variant, called the Gated Recurrent Unit (GRU), has been shown to be simpler than LSTM while performing comparably, but we do not dig into that for now.  For readers who speak Chinese, a nice video on RNN is available here [10].

I am very excited by the RNN architecture, because it seems to explain how we dream.  In the video here (20:00 - 31:00), it was shown that we can feed an RNN with Shakespeare, C code, or mathematical articles, and the RNN then seems able to spit out Shakespeare, C code, and math articles.  The RNN architecture is outlined in Figure 5: one input character triggers the output of the next character, which in turn triggers the next one, and so on.  Starting from a single character, the RNN generates a sequence of characters.  If we train it on Shakespeare, the output generated by the RNN is compared to the desired output (the actual next character in the text), and the cross-entropy loss function is used to train the weights of the RNN.  With enough iterations, the network starts to generate text that shares many syntactic features of the training sequence.

Figure 5. The input is one character at a time, one-hot encoded.  The hidden state is updated accordingly and then used to generate logits for the output characters (softmaxed into probabilities).  One output character is emitted as the result.  This is a sequence-to-sequence application, as depicted in Figure 2.A. (credit: [11])
Figures 6 and 7 show sample outputs where the RNN generates Shakespeare and Linux C code.
Figure 6.  With sufficient training samples, the RNN can generate text that syntactically looks like Shakespeare.
Figure 7.  With sufficient training samples, the RNN can generate text that syntactically looks like C code. (Credit: [11])
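The data flow of Figure 5 can be sketched in NumPy as follows. It assumes a SimpleRNN cell (Equation 1 with $\tanh$), a linear layer producing logits, and a tiny made-up corpus; the weights are random and untrained, so the loss sits near $\log V$ for a vocabulary of size $V$ and the samples are gibberish. Training the weights by back-propagation through time is omitted, and all names are hypothetical.

```python
import numpy as np

text = "hello world"                        # tiny made-up training corpus
vocab = sorted(set(text))
idx = {ch: k for k, ch in enumerate(vocab)}
V, H = len(vocab), 16                       # vocabulary size, hidden size (hypothetical)

rng = np.random.default_rng(0)
W_x = rng.normal(scale=0.1, size=(V, H))    # untrained weights: data flow only
W_h = rng.normal(scale=0.1, size=(H, H))
W_y = rng.normal(scale=0.1, size=(H, V))

def one_hot(ch):
    v = np.zeros(V)
    v[idx[ch]] = 1.0
    return v

def softmax(z):
    e = np.exp(z - z.max())                 # subtract the max for numerical stability
    return e / e.sum()

def sequence_loss(s):
    """Average cross-entropy of predicting each next character of s."""
    h = np.zeros(H)
    loss = 0.0
    for cur, nxt in zip(s[:-1], s[1:]):
        h = np.tanh(h @ W_h + one_hot(cur) @ W_x)   # update the hidden state
        p = softmax(h @ W_y)                        # logits -> probabilities
        loss -= np.log(p[idx[nxt]])
    return loss / (len(s) - 1)

def sample(seed, n):
    """Generate n characters, feeding each output back in as the next input."""
    h = np.zeros(H)
    ch, out = seed, [seed]
    for _ in range(n):
        h = np.tanh(h @ W_h + one_hot(ch) @ W_x)
        ch = vocab[rng.choice(V, p=softmax(h @ W_y))]
        out.append(ch)
    return "".join(out)

print(sequence_loss(text), np.log(V))  # close to each other before training
print(sample("h", 20))
```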

The generated output looks sensible in its local structure but nonsensical over longer ranges.  This conveniently explains how we dream: fragments of dreams seem quite logical and movie-like, but these fragments are loosely connected and do not serve one theme.  When the fragments accidentally make some sense, you are extremely impressed by the power of dreaming.  It is like how people chat, moving from one topic to the next, logical within each transition but accomplishing nothing as a whole.

When the elements within the hidden states were examined, some of them appeared to represent interesting sequence patterns, similar to the depth and sum elements in our previous example.  Figure 8 shows a few examples of such neurons; this reminds us that an RNN is able to learn meaningful neurons, just as CNN neurons can represent complex patterns such as faces or text.

Figure 8. Three example elements in the hidden state, each capturing a distinct sequence pattern.  A. A neuron that is activated when it encounters a left quote and deactivated when it encounters a right quote, very similar to our previous $c_{depth}$ neuron. B. A neuron that appears to count the number of characters it reads and is reset when it encounters a line break.  C. A neuron that counts the indentation depth of the C code.  (Credit: [11])


References



  1. http://lup.lub.lu.se/luur/download?func=downloadFile&recordOId=8911069&fileOId=8911070
  2. Aurélien Géron, Hands-On Machine Learning with Scikit-Learn & TensorFlow, p. 373.
  3. Antonio Gulli, Deep Learning with Keras: Implementing deep learning models and neural networks with the power of Python, p. 214.
  4. https://www.ted.com/talks/fei_fei_li_how_we_re_teaching_computers_to_understand_pictures
  5. Aurélien Géron, Hands-On Machine Learning with Scikit-Learn & TensorFlow, p. 385.
  6. Aurélien Géron, Hands-On Machine Learning with Scikit-Learn & TensorFlow, p. 403.
  7. Aurélien Géron, Hands-On Machine Learning with Scikit-Learn & TensorFlow, p. 405.
  8. http://lstm.seas.harvard.edu/
  9. http://harinisuresh.com/2016/10/09/lstms/
  10. https://www.youtube.com/watch?v=xCGidAeyS4M
  11. http://cs231n.stanford.edu/slides/2016/winter1516_lecture10.pdf
