Recurrent neural network
- yasobel
- Dec 18, 2020
- 3 min read
This post explains how an RNN works.
What is a recurrent neural network?
An RNN is a class of artificial neural networks that learns from a given dataset to predict an output. But unlike other networks, an RNN can be used on sequential data: it can handle inputs and outputs of varying sizes, preserve sequential information, and generalise well over sequences.
One common variant is the vanilla RNN, which can work with sequences of variable length, as we can see in the image:

And it's based on this recurrent process:

ht = fW(ht-1, xt)

where ht is the hidden state at time step t and xt is the input at that step.
At each time step, we use the same function and parameters. For training, the network uses "Backpropagation Through Time" (BPTT) to update the weights, decrease the loss and increase the accuracy.
Unlike normal backpropagation, BPTT performs a more involved calculation of the gradient of the loss, in which it sums the gradients over all time steps.
The loss at a time step t is computed from the output and the hidden state at that step; the gradient with respect to a hidden state then chains back through all the earlier hidden states.
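To make this concrete, here is a minimal sketch of a vanilla RNN forward pass in NumPy. All names (rnn_forward, Wxh, Whh, Why) and the toy dimensions are illustrative assumptions, not a reference implementation:

```python
import numpy as np

def rnn_forward(xs, h0, Wxh, Whh, Why, bh, by):
    """Vanilla RNN forward pass over a sequence of input vectors xs."""
    h = h0
    hs, ys = [], []
    for x in xs:
        # The same weights (Wxh, Whh, bh) are reused at every time step.
        h = np.tanh(Wxh @ x + Whh @ h + bh)   # ht = fW(ht-1, xt)
        y = Why @ h + by                       # output at this time step
        hs.append(h)
        ys.append(y)
    return ys, hs

# Toy dimensions: 3-dim inputs, 5-dim hidden state, 2-dim outputs.
rng = np.random.default_rng(0)
Wxh = rng.normal(size=(5, 3)) * 0.1
Whh = rng.normal(size=(5, 5)) * 0.1
Why = rng.normal(size=(2, 5)) * 0.1
bh, by = np.zeros(5), np.zeros(2)
xs = [rng.normal(size=3) for _ in range(4)]   # a length-4 sequence
ys, hs = rnn_forward(xs, np.zeros(5), Wxh, Whh, Why, bh, by)
```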
However, there are two problems linked to the vanilla RNN, because BPTT repeatedly multiplies the gradient by the same factors (coming from Whh and the tanh derivative):
-> if the factors are greater than 1, the gradient of the loss explodes
-> if the factors are smaller than 1, the gradient of the loss vanishes
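A quick numerical illustration of this compounding effect (the factor values are arbitrary):

```python
# Backpropagating through T steps multiplies roughly T factors together,
# so the gradient magnitude behaves like factor ** T.
for factor in (0.9, 1.1):
    print(factor, factor ** 100)
# 0.9 -> ~2.7e-05  (vanishes)    1.1 -> ~13780.6  (explodes)
```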
So in practice we use another architecture: Long Short-Term Memory (LSTM).
LSTM has the same chain-like structure, but the cell architecture changes: to avoid the vanishing gradient problem, the network manages information and gradient flow differently.
The network is defined by two internal states, the hidden state (ht) and the cell state (ct), as we can see below.

And four gates are used to control the information received by ht and ct:
-> Forget gate ft: what information should we forget or keep from previous states?
-> Input gate it: which values will be written to the current cell?
-> Input modulation gate gt: how much do we write to the current cell?
-> Output gate ot: what will be output from the current cell?
As with the vanilla RNN, LSTM uses BPTT to compute the loss gradients and update the weights, but the gradient is not only propagated through ht, it is also propagated through ct, which solves the vanishing gradient problem.
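Here is a minimal sketch of a single LSTM step, assuming the four gates are computed from one stacked weight matrix W and bias b (hypothetical names and layout):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W, b):
    """One LSTM time step; W has shape (4*H, D+H), one row block per gate."""
    z = W @ np.concatenate([x, h_prev]) + b
    H = h_prev.shape[0]
    f = sigmoid(z[0*H:1*H])   # forget gate: what to keep from c_prev
    i = sigmoid(z[1*H:2*H])   # input gate: which values to write
    g = np.tanh(z[2*H:3*H])   # input modulation gate: what to write
    o = sigmoid(z[3*H:4*H])   # output gate: what to expose as ht
    c = f * c_prev + i * g    # additive update: gradient also flows through c
    h = o * np.tanh(c)
    return h, c

# Toy step: 3-dim input, 5-dim hidden/cell state.
rng = np.random.default_rng(0)
D, H = 3, 5
W = rng.normal(size=(4 * H, D + H)) * 0.1
b = np.zeros(4 * H)
h, c = lstm_step(rng.normal(size=D), np.zeros(H), np.zeros(H), W, b)
```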
What are some applications of RNNs?
-> Language modelling:
We can use an RNN to predict the next word of a sentence given its previous words. In this case, the input is the sentence, and the hidden state processes one word at a time. BPTT is again used to update the weights.
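As a sketch, here is an untrained next-word predictor over a toy vocabulary (the vocabulary, weights and sizes are made up for illustration):

```python
import numpy as np

vocab = ["the", "cat", "sat", "on", "mat"]   # hypothetical toy vocabulary
V, H = len(vocab), 8

rng = np.random.default_rng(1)
Wxh = rng.normal(size=(H, V)) * 0.1   # input -> hidden
Whh = rng.normal(size=(H, H)) * 0.1   # hidden -> hidden (recurrence)
Why = rng.normal(size=(V, H)) * 0.1   # hidden -> next-word scores

def next_word_probs(words):
    """Run the RNN over `words` and return P(next word) over the vocabulary."""
    h = np.zeros(H)
    for w in words:
        x = np.eye(V)[vocab.index(w)]     # one-hot encoding of the word
        h = np.tanh(Wxh @ x + Whh @ h)    # one word per time step
    scores = Why @ h
    exp = np.exp(scores - scores.max())   # softmax over the vocabulary
    return exp / exp.sum()

probs = next_word_probs(["the", "cat"])
print(vocab[int(np.argmax(probs))])       # untrained, so the guess is arbitrary
```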
-> Image generation:
The network can also predict the values of the next pixels of a given image, generating the image one pixel at a time.
-> Sequence to sequence:
This is used in translation, for example. The input and output are both sequences: the encoder uses a many-to-one architecture to summarise the input, and the decoder uses a one-to-many architecture that reads the encoder's output and produces the appropriate sequence in the target language.
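A minimal encoder-decoder sketch of this many-to-one / one-to-many split (untrained, with made-up weight names):

```python
import numpy as np

rng = np.random.default_rng(2)
D, H = 4, 6
We = rng.normal(size=(H, D + H)) * 0.1   # encoder cell weights
Wd = rng.normal(size=(H, H)) * 0.1       # decoder cell weights
Wo = rng.normal(size=(D, H)) * 0.1       # decoder output projection

def encode(xs):
    """Many-to-one: compress the whole input sequence into one vector."""
    h = np.zeros(H)
    for x in xs:
        h = np.tanh(We @ np.concatenate([x, h]))
    return h

def decode(h, steps):
    """One-to-many: unroll output vectors from the encoder's summary."""
    ys = []
    for _ in range(steps):
        h = np.tanh(Wd @ h)
        ys.append(Wo @ h)
    return ys

source = [rng.normal(size=D) for _ in range(5)]
target = decode(encode(source), steps=3)   # output length can differ from input
```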
It can also be used for image captioning (like in the lab), video description, hand gesture recognition, etc.
However, we always have to wait for the encoder to process all of the input data before producing a prediction, so it's not very practical for real-time predictions.
To counter this problem, we add an attention mechanism to the RNN.
Attention can be defined as the ability to decide what to focus on: to be selective about what you're looking at or thinking about.
And the attention over time is called memory.
For the system to be able to focus on specific information, two types of attention are used:
-> Implicit attention:
Since the network has a memory of past events, thanks to the recursive hidden states, it uses this knowledge when making a decision.
Implicit attention can be measured by the sequential Jacobian: the sensitivity of the network's output at one time step to its inputs at every other time step.
-> Explicit attention:
This kind of attention is closer to how a human behaves, and it has several advantages: it reduces computation, makes it easier to understand what the network is doing, and simplifies sequential processing.
To implement the attention mechanism in an RNN, we follow the scheme below:

As we can see, the attention model is applied to extra data and feeds a vector back to the network at the next time step.
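As an illustration, here is a sketch of the soft variant described below: a differentiable weighted average over a set of vectors (all names and sizes are illustrative):

```python
import numpy as np

def soft_attention(query, values):
    """Soft (differentiable) attention: a weighted average of `values`.

    query:  (H,) vector, e.g. the decoder's current hidden state.
    values: (T, H) matrix, e.g. all encoder hidden states.
    """
    scores = values @ query                 # one relevance score per value
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                # softmax -> attention weights
    context = weights @ values              # feedback/context vector
    return context, weights

rng = np.random.default_rng(3)
values = rng.normal(size=(5, 6))   # e.g. 5 encoder states of size 6
query = rng.normal(size=6)
context, weights = soft_attention(query, values)
print(weights)   # sums to 1; the largest weight marks the focused step
```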
And finally, there are two types of attention models:
-> Hard:
Uses reinforcement learning to train the model, since the discrete choice of where to look is not differentiable.
-> Soft:
Uses backpropagation to train the model, since the soft weighting (as in the sketch above) is differentiable.
These attention models can be used for image generation, sentiment analysis, image processing, or image captioning, for example.