Javatpoint Logo

Long Short-Term Memory (LSTM) RNN in TensorFlow

Long short-term memory (LSTM) is an artificial recurrent neural network (RNN) architecture used in the field of deep learning. It was proposed in 1997 by Sepp Hochreiter and Jürgen Schmidhuber. Unlike standard feed-forward neural networks, LSTM has feedback connections. It can process not only single data points (such as images) but also entire sequences of data (such as speech or video).

For example, LSTM is applicable to tasks such as unsegmented, connected handwriting recognition and speech recognition.

A common LSTM unit is composed of a cell, an input gate, an output gate, and a forget gate. The cell remembers values over arbitrary time intervals, and the three gates regulate the flow of information into and out of the cell. This makes LSTM well-suited to classifying, processing, and making predictions based on time series data, where there can be lags of unknown duration between important events.
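
For reference, the computations inside such a unit are usually written as follows. This is the common formulation that includes a forget gate. The notation is ours rather than the tutorial's:

- σ is the logistic sigmoid and ⊙ is element-wise multiplication.
- [h_{t-1}, x_t] is the concatenation of the previous hidden state and the current input.
- The W and b terms are learned weights and biases.

```latex
\begin{aligned}
f_t &= \sigma\left(W_f [h_{t-1}, x_t] + b_f\right) \\
i_t &= \sigma\left(W_i [h_{t-1}, x_t] + b_i\right) \\
\tilde{C}_t &= \tanh\left(W_C [h_{t-1}, x_t] + b_C\right) \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \\
o_t &= \sigma\left(W_o [h_{t-1}, x_t] + b_o\right) \\
h_t &= o_t \odot \tanh(C_t)
\end{aligned}
```

The forget gate f_t scales the old cell state. The input gate i_t scales the new candidate values. The output gate o_t decides how much of the squashed cell state becomes the hidden state h_t.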

Long Short-Term Memory (LSTM) networks are a modified version of recurrent neural networks that make it easier to remember past data in memory.

1. Input gate: It determines which values from the input should be used to modify the memory. A sigmoid function decides which values to let through, outputting values between 0 and 1. A tanh function weights the values that are passed, assigning them a level of importance ranging from -1 to 1.

2. Forget gate: It determines which details should be discarded from the block, using a sigmoid function. It looks at the previous hidden state (h_{t-1}) and the current input (x_t). For each number in the cell state C_{t-1}, it outputs a number between 0 (omit this) and 1 (keep this).

3. Output gate: The input and the memory of the block are used to decide the output. A sigmoid function decides which values to let through, outputting values between 0 and 1. A tanh function applied to the cell state weights the values, assigning them a level of importance ranging from -1 to 1. This result is then multiplied by the output of the sigmoid.
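
The three gates above can be sketched as a single LSTM time step in plain Python. This is a framework-free illustration, not TensorFlow's implementation. The following are assumptions chosen for readability:

- the function names;
- the parameter layout, with one weight matrix and bias per gate acting on the concatenation [h_{t-1}, x_t];
- the use of Python lists as vectors.

```python
import math


def sigmoid(z):
    # Numerically stable logistic function; output lies in (0, 1).
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def affine(W, b, v):
    # Compute W v + b, with W given as a list of rows.
    return [sum(w * x for w, x in zip(row, v)) + bj for row, bj in zip(W, b)]


def lstm_step(x_t, h_prev, c_prev, params):
    """Run one LSTM time step and return (h_t, c_t).

    params maps "f", "i", "c", "o" to (W, b); each W has one row per hidden
    unit and len(h_prev) + len(x_t) columns, acting on [h_prev, x_t].
    """
    z = list(h_prev) + list(x_t)  # concatenation [h_{t-1}, x_t]
    f = [sigmoid(a) for a in affine(*params["f"], z)]        # forget gate: 0 = omit, 1 = keep
    i = [sigmoid(a) for a in affine(*params["i"], z)]        # input gate: how much new info to admit
    c_hat = [math.tanh(a) for a in affine(*params["c"], z)]  # candidate values in (-1, 1)
    o = [sigmoid(a) for a in affine(*params["o"], z)]        # output gate
    c_t = [fk * ck + ik * gk for fk, ck, ik, gk in zip(f, c_prev, i, c_hat)]
    h_t = [ok * math.tanh(ck) for ok, ck in zip(o, c_t)]     # tanh(cell) scaled by the output gate
    return h_t, c_t


# One hidden unit, one input feature, all weights and biases zero.
zero = ([[0.0, 0.0]], [0.0])
params = {"f": zero, "i": zero, "c": zero, "o": zero}
h, c = lstm_step([1.0], [0.0], [2.0], params)
print(h, c)
```

With all weights and biases at zero, every gate outputs sigmoid(0) = 0.5 and the candidate is tanh(0) = 0. The cell state is therefore simply halved, which is the forget gate at work.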

The diagram represents a full RNN cell. The cell takes the current input of the sequence, x_t, outputs the current hidden state, h_t, and passes it to the next RNN cell for our input sequence. The inside of an LSTM cell is much more complicated than that of a traditional RNN cell. The conventional RNN cell has a single "internal layer" acting on the previous hidden state (h_{t-1}) and the current input (x_t), whereas the LSTM cell combines the gates described above.
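
For contrast, here is a minimal sketch of the conventional RNN cell described above, unrolled over a short sequence in plain Python. The names rnn_step and rnn_unroll and the weight names W_h, W_x and b are hypothetical. The point is the single tanh layer and the hidden state being passed from one step to the next.

```python
import math


def rnn_step(x_t, h_prev, W_h, W_x, b):
    # The single "internal layer" of a vanilla RNN, per hidden unit j:
    # h_t = tanh(W_h h_{t-1} + W_x x_t + b)
    return [
        math.tanh(
            sum(w * h for w, h in zip(W_h[j], h_prev))
            + sum(w * x for w, x in zip(W_x[j], x_t))
            + b[j]
        )
        for j in range(len(b))
    ]


def rnn_unroll(xs, h0, W_h, W_x, b):
    """Apply the same cell at every step, passing each hidden state on."""
    h, states = h0, []
    for x_t in xs:
        h = rnn_step(x_t, h, W_h, W_x, b)
        states.append(h)
    return states


# One hidden unit and one input feature: the second step sees only the
# first step's hidden state, because its input is 0.
states = rnn_unroll([[1.0], [0.0]], [0.0], W_h=[[1.0]], W_x=[[1.0]], b=[0.0])
print(states)
```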

In the diagram above, we see an "unrolled" LSTM network with an embedding layer, a subsequent LSTM layer, and a sigmoid activation function. Our inputs, in this case the words of a movie review, are fed in sequentially.

The words are first fed into an embedding lookup. When working with a corpus of text data, the vocabulary is usually very large. Representing each word as a one-hot vector the size of the vocabulary would therefore be inefficient.

An embedding is a multidimensional, distributed representation of words in a vector space. These embeddings can be learned with other deep learning techniques such as word2vec. Alternatively, the model can be trained end to end so that it learns the embeddings during training.
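
A rough sketch of the embedding lookup in plain Python follows. It assumes whitespace tokenisation and a randomly initialised table, and the helper names build_vocab, make_embedding and embed are hypothetical. Each word becomes an integer id, and the id selects a row of a small dense table instead of a vocabulary-sized one-hot vector.

```python
import random


def build_vocab(texts):
    """Map each distinct word to an integer id; id 0 is reserved for padding/unknown words."""
    vocab = {}
    for text in texts:
        for word in text.lower().split():
            if word not in vocab:
                vocab[word] = len(vocab) + 1
    return vocab


def make_embedding(vocab_size, dim, seed=0):
    """Random embedding table with one row per id (row 0 included).

    In end-to-end training these rows would be updated along with the other
    weights; they could instead be initialised from pretrained vectors.
    """
    rng = random.Random(seed)
    return [[rng.uniform(-1.0, 1.0) for _ in range(dim)] for _ in range(vocab_size + 1)]


def embed(text, vocab, table):
    """Embedding lookup: word -> integer id -> dense vector (a row of the table)."""
    return [table[vocab.get(word, 0)] for word in text.lower().split()]


vocab = build_vocab(["a great movie", "a bad movie"])
table = make_embedding(len(vocab), dim=4)
vectors = embed("A great movie", vocab, table)
print(vocab)                          # {'a': 1, 'great': 2, 'movie': 3, 'bad': 4}
print(len(vectors), len(vectors[0]))  # 3 4
```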

These embeddings are then fed into the LSTM layer. Its output is passed both to a sigmoid output layer and to the LSTM cell for the next word in the sequence.
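
The sigmoid output layer can be sketched as follows. The sketch assumes that only the hidden state from the last time step is used for the prediction. This is common for sentiment classification, although the text does not say so explicitly. The function name predict, the weights and the hidden-state values are illustrative placeholders.

```python
import math


def sigmoid(z):
    # Numerically stable logistic function; output lies in (0, 1).
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def predict(hidden_states, w, b):
    """Map the LSTM's hidden state at the last time step to P(positive review).

    hidden_states: one hidden-state vector per word, in sequence order.
    w, b: weights and bias of a single-unit dense layer (hypothetical names).
    """
    h_last = hidden_states[-1]  # the state after reading the whole review
    return sigmoid(sum(wi * hi for wi, hi in zip(w, h_last)) + b)


states = [[0.1, -0.2], [0.4, 0.3], [0.9, 0.5]]  # made-up states for a 3-word review
p = predict(states, w=[1.0, 1.0], b=-1.0)
label = 1 if p >= 0.5 else 0
print(p, label)
```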

LSTM Layers

We set up a function that builds the LSTM layers so that the number of layers and their sizes can be chosen dynamically. The function takes a list of LSTM sizes, and the length of the list determines the number of LSTM layers. For example, we will use a list of length 2 containing the sizes 128 and 64. This gives a two-layer LSTM network where the first layer has a hidden size of 128 and the second layer has a hidden size of 64.
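
The idea of deriving the layers from a list of sizes can be sketched without TensorFlow by computing the parameter shapes of each layer. The function name build_lstm_layers and the embedding size of 256 are assumptions. The shapes also assume a fused layout, in which one kernel acts on [x_t, h_{t-1}] for all four gate pre-activations. A real TensorFlow function would create cell objects instead of dictionaries.

```python
def build_lstm_layers(input_size, lstm_sizes):
    """Describe a stack of LSTM layers from a list of hidden sizes.

    The length of lstm_sizes sets the number of layers. Each layer reads the
    previous layer's hidden state; the first layer reads the embeddings.
    """
    layers = []
    in_size = input_size
    for hidden in lstm_sizes:
        # One fused kernel acting on [x_t, h_{t-1}] for the four gate
        # pre-activations (input, forget, candidate, output), plus a bias.
        kernel = (in_size + hidden, 4 * hidden)
        bias = (4 * hidden,)
        params = kernel[0] * kernel[1] + bias[0]
        layers.append({"hidden": hidden, "kernel": kernel, "bias": bias, "params": params})
        in_size = hidden  # the next layer consumes this layer's output
    return layers


for layer in build_lstm_layers(input_size=256, lstm_sizes=[128, 64]):
    print(layer)
# {'hidden': 128, 'kernel': (384, 512), 'bias': (512,), 'params': 197120}
# {'hidden': 64, 'kernel': (192, 256), 'bias': (256,), 'params': 49408}
```

Note how the second layer's kernel has 128 + 64 rows: its input is the first layer's 128-dimensional output, not the embeddings.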

Each LSTM layer is wrapped with dropout. The list of dropout-wrapped LSTM cells is then passed to a TensorFlow MultiRNNCell to stack the layers together.
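
Below is a minimal sketch, in plain Python, of what wrapping each layer's output with dropout and stacking the layers amounts to. It uses inverted dropout with a keep probability (keep_prob); the exact dropout variant is an assumption. The toy layer function stands in for an LSTM layer, and all helper names are hypothetical.

```python
import random


def dropout(values, keep_prob, rng):
    """Inverted dropout: zero each value with probability 1 - keep_prob and
    scale the survivors by 1 / keep_prob so the expected value is unchanged."""
    if keep_prob >= 1.0:
        return list(values)  # dropout disabled, e.g. during evaluation
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in values]


def stacked_forward(x, layers, keep_prob, seed=0):
    """Run x through each layer in turn, applying dropout to every layer's
    output before it is passed to the next layer."""
    rng = random.Random(seed)
    for layer in layers:
        x = dropout(layer(x), keep_prob, rng)
    return x


def double(v):
    # Toy stand-in for an LSTM layer.
    return [2.0 * a for a in v]


print(stacked_forward([1.0, 2.0], [double, double], keep_prob=1.0))  # [4.0, 8.0]
print(stacked_forward([1.0, 2.0], [double, double], keep_prob=0.5))
```

Setting keep_prob to 1.0 at evaluation time turns dropout off without changing the network's structure.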

Loss function, optimizer and accuracy

Finally, we create functions to define the model's loss function, optimizer, and accuracy. The loss and accuracy are simply calculated from the results. Even so, in TensorFlow everything is part of a computation graph, so they are defined as graph operations as well.
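
A plain-Python sketch of the three pieces follows. All function names are hypothetical, and each piece rests on an assumption:

- Loss: the text does not name it, so mean squared error on the sigmoid outputs is assumed here. Binary cross-entropy is the other usual choice.
- Accuracy: each prediction is rounded at 0.5 before comparing it with the label.
- Optimizer: shown as a single plain gradient-descent step, which may differ from the optimizer the tutorial actually uses.

```python
def mse_loss(predictions, labels):
    """Mean squared error between sigmoid outputs and 0/1 labels."""
    return sum((p - y) ** 2 for p, y in zip(predictions, labels)) / len(labels)


def accuracy(predictions, labels):
    """Fraction of predictions that match the label after rounding at 0.5."""
    correct = sum(int((p >= 0.5) == bool(y)) for p, y in zip(predictions, labels))
    return correct / len(labels)


def sgd_update(params, grads, learning_rate):
    """One plain gradient-descent step: param <- param - learning_rate * grad."""
    return [p - learning_rate * g for p, g in zip(params, grads)]


preds = [0.9, 0.2, 0.6, 0.4]
labels = [1, 0, 0, 0]
print(accuracy(preds, labels))                   # 0.75
print(mse_loss([0.5] * 4, [1, 0, 1, 0]))         # 0.25
print(sgd_update([1.0, 2.0], [1.0, -2.0], 0.5))  # [0.5, 3.0]
```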

Building the graph and training

First, we call each of the functions we have defined to construct the network. We then start a TensorFlow session to train the model over a predefined number of epochs using mini-batches. At the end of every epoch, we print the loss, training accuracy, and validation accuracy to monitor the results as we train the model.
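
The epoch and mini-batch structure of the training loop can be sketched in plain Python as below. Here train_step is a hypothetical placeholder for running one optimizer step in a TensorFlow session, and the helper names are illustrative. Validation accuracy is omitted to keep the example short. The printed line loosely mimics the format of the log that follows.

```python
def get_batches(x, y, batch_size):
    """Yield successive (x, y) mini-batches; a final partial batch is dropped."""
    n_batches = len(x) // batch_size
    for i in range(n_batches):
        start = i * batch_size
        yield x[start:start + batch_size], y[start:start + batch_size]


def train(x, y, epochs, batch_size, train_step):
    """Loop over epochs and mini-batches, printing averaged metrics per epoch.

    train_step(xb, yb) is a placeholder for one optimizer step and must
    return (loss, accuracy) for that batch.
    """
    n_batches = len(x) // batch_size
    if n_batches == 0:
        raise ValueError("batch_size is larger than the training set")
    history = []
    for epoch in range(1, epochs + 1):
        losses, accs = [], []
        for xb, yb in get_batches(x, y, batch_size):
            loss, acc = train_step(xb, yb)
            losses.append(loss)
            accs.append(acc)
        mean_loss, mean_acc = sum(losses) / n_batches, sum(accs) / n_batches
        print(f"Epoch: {epoch}/{epochs} Batch: {n_batches}/{n_batches} "
              f"Train Loss: {mean_loss:.3f} Train Accuracy: {mean_acc:.3f}")
        history.append((mean_loss, mean_acc))
    return history


seen = []


def fake_step(xb, yb):
    # Placeholder: records the batch size and returns fixed metrics.
    seen.append(len(xb))
    return 0.25, 0.5


history = train(list(range(10)), [i % 2 for i in range(10)], epochs=2,
                batch_size=3, train_step=fake_step)
```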

Next, we define the model hyperparameters and build a two-layer LSTM network with hidden layer sizes of 128 and 64, respectively.

When training is complete, we use a TensorFlow Saver to save the model parameters for later use.
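
Saving and restoring can be illustrated without TensorFlow by writing named parameter values to a file and reading them back. This JSON-based sketch is a stand-in for the TensorFlow Saver and does not reproduce its file format. The function names, parameter names and values are made up.

```python
import json
import os
import tempfile


def save_params(params, path):
    """Write a dict of named parameter lists to disk as JSON."""
    with open(path, "w") as f:
        json.dump(params, f)


def restore_params(path):
    """Load the parameters back, e.g. before evaluating on the test set."""
    with open(path) as f:
        return json.load(f)


trained = {"output/w": [0.5, -1.0], "output/b": [0.0]}  # made-up values
with tempfile.TemporaryDirectory() as ckpt_dir:
    path = os.path.join(ckpt_dir, "sentiment.json")
    save_params(trained, path)
    restored = restore_params(path)
print(restored == trained)  # True
```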

Epoch: 1/50 Batch: 303/303 Train Loss: 0.247 Train Accuracy: 0.562 Val Accuracy: 0.578
Epoch: 2/50 Batch: 303/303 Train Loss: 0.245 Train Accuracy: 0.583 Val Accuracy: 0.596
Epoch: 3/50 Batch: 303/303 Train Loss: 0.247 Train Accuracy: 0.597 Val Accuracy: 0.617
Epoch: 4/50 Batch: 303/303 Train Loss: 0.240 Train Accuracy: 0.610 Val Accuracy: 0.627
Epoch: 5/50 Batch: 303/303 Train Loss: 0.238 Train Accuracy: 0.620 Val Accuracy: 0.632
Epoch: 6/50 Batch: 303/303 Train Loss: 0.234 Train Accuracy: 0.632 Val Accuracy: 0.642
Epoch: 7/50 Batch: 303/303 Train Loss: 0.230 Train Accuracy: 0.636 Val Accuracy: 0.648
Epoch: 8/50 Batch: 303/303 Train Loss: 0.227 Train Accuracy: 0.641 Val Accuracy: 0.653
Epoch: 9/50 Batch: 303/303 Train Loss: 0.223 Train Accuracy: 0.646 Val Accuracy: 0.656
Epoch: 10/50 Batch: 303/303 Train Loss: 0.221 Train Accuracy: 0.652 Val Accuracy: 0.659


Finally, we check our model results on the test set to make sure they are in line with what we observed during training.

The test accuracy is 72% (0.717). This is in line with our validation accuracy and indicates that our data was appropriately distributed across the training, validation, and test splits.

INFO:tensorflow:Restoring parameters from checkpoints/sentiment.ckpt
Test Accuracy: 0.717
