Deep Learning PyTorch Course, Implementing GRU Cell

In deep learning, Recurrent Neural Networks (RNNs) are widely used to model sequential data such as time series or natural language. Among them, the Gated Recurrent Unit (GRU) is a variant of the RNN developed to address the long-term dependency problem, and it has a structure similar to Long Short-Term Memory (LSTM). In this post, we will explain the fundamental concepts of GRU and how to implement it using PyTorch.

1. What is GRU?

GRU is an architecture proposed by Kyunghyun Cho and colleagues in 2014 that determines the current state by combining the current input with the previous state, using a simpler and less computationally expensive design than LSTM. GRU uses two primary gates:

  • Reset Gate: Determines how much of the previous hidden state to forget when computing the candidate state.
  • Update Gate: Determines how much of the previous state to carry over versus how much of the new candidate state to use.

The main equations of GRU are as follows:

1.1 Equation Definition

1. For the input vector x_t and the previous hidden state h_{t-1}, we define the reset gate r_t and the update gate z_t.

r_t = σ(W_r * x_t + U_r * h_{t-1})
z_t = σ(W_z * x_t + U_z * h_{t-1})

Here, W_r and W_z are weights applied to the input, U_r and U_z are weights applied to the previous state, and σ is the sigmoid function.

2. The candidate hidden state ĥ_t and the new hidden state h_t are computed as follows.

ĥ_t = tanh(W_h * x_t + U_h * (r_t * h_{t-1}))
h_t = (1 - z_t) * h_{t-1} + z_t * ĥ_t

Here, W_h and U_h are additional weights, and ĥ_t (written as h_hat_t in the code below) is the candidate state that combines the current input with the reset-scaled previous state.
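
To make the equations concrete, here is a quick numeric illustration with arbitrary small sizes and randomly chosen weights (everything here is made up for the example); the gate values lie in (0, 1), and h_t is an interpolation between the previous state and the candidate state.

import torch

torch.manual_seed(0)

x_t = torch.randn(2)       # input vector (input_size = 2)
h_prev = torch.randn(2)    # previous hidden state (hidden_size = 2)

# Randomly chosen weights, purely for illustration
W_r, U_r = torch.randn(2, 2), torch.randn(2, 2)
W_z, U_z = torch.randn(2, 2), torch.randn(2, 2)
W_h, U_h = torch.randn(2, 2), torch.randn(2, 2)

r_t = torch.sigmoid(W_r @ x_t + U_r @ h_prev)            # reset gate, values in (0, 1)
z_t = torch.sigmoid(W_z @ x_t + U_z @ h_prev)            # update gate, values in (0, 1)
h_hat_t = torch.tanh(W_h @ x_t + U_h @ (r_t * h_prev))   # candidate state
h_t = (1 - z_t) * h_prev + z_t * h_hat_t                 # interpolation

print("r_t:", r_t)
print("z_t:", z_t)
print("h_t:", h_t)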

2. Advantages of GRU

  • With a simpler structure, it has fewer parameters than LSTM, allowing for faster training (a quick parameter-count comparison follows this list).
  • Because it can still learn long-term dependencies well, it performs strongly on various NLP tasks.
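
As a rough check of the first point, the built-in PyTorch layers can be used to compare parameter counts; because a GRU has three gate-sized weight blocks versus the LSTM's four, a GRU layer has roughly three quarters of the parameters of an LSTM layer with the same sizes. The sizes below are small and arbitrary, chosen only for the example.

import torch.nn as nn

gru = nn.GRU(input_size=5, hidden_size=3)
lstm = nn.LSTM(input_size=5, hidden_size=3)

gru_params = sum(p.numel() for p in gru.parameters())
lstm_params = sum(p.numel() for p in lstm.parameters())

print("GRU parameters:", gru_params)    # three gate-sized blocks
print("LSTM parameters:", lstm_params)  # four gate-sized blocks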

3. Implementing GRU Cell

Now, let’s implement the GRU cell using PyTorch. The sample code below makes the basic operation of a GRU cell explicit.

3.1 GRU Cell Implementation

import torch
import torch.nn as nn

class GRUSimple(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(GRUSimple, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        
        # Weight initialization
        self.Wz = nn.Parameter(torch.Tensor(hidden_size, input_size))
        self.Uz = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        self.Wr = nn.Parameter(torch.Tensor(hidden_size, input_size))
        self.Ur = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        self.Wh = nn.Parameter(torch.Tensor(hidden_size, input_size))
        self.Uh = nn.Parameter(torch.Tensor(hidden_size, hidden_size))

        self.reset_parameters()

    def reset_parameters(self):
        # Uniform initialization in [-1/sqrt(hidden_size), 1/sqrt(hidden_size)]
        stdv = 1.0 / self.hidden_size ** 0.5
        for param in self.parameters():
            param.data.uniform_(-stdv, stdv)

    def forward(self, x_t, h_prev):
        r_t = torch.sigmoid(self.Wr @ x_t + self.Ur @ h_prev)  # reset gate
        z_t = torch.sigmoid(self.Wz @ x_t + self.Uz @ h_prev)  # update gate
        h_hat_t = torch.tanh(self.Wh @ x_t + self.Uh @ (r_t * h_prev))  # candidate state

        # Interpolate between the previous state and the candidate state
        h_t = (1 - z_t) * h_prev + z_t * h_hat_t
        return h_t

The code above implements a simple GRU cell. The __init__ method stores the input and hidden state sizes and defines the weight parameters. The reset_parameters method initializes the weights. In the forward method, the new hidden state is computed from the current input and the previous state.

3.2 Testing GRU Cell

Now, let’s write a sample code to test the GRU cell.

input_size = 5
hidden_size = 3
x_t = torch.randn(input_size)  # Generate random input
h_prev = torch.zeros(hidden_size)  # Initial hidden state

gru_cell = GRUSimple(input_size, hidden_size)
h_t = gru_cell(x_t, h_prev)

print("Current hidden state h_t:", h_t)

The above code allows us to check the operation of the GRU cell. It generates random input, sets the initial hidden state to 0, and then outputs the current hidden state h_t through the GRU cell.
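
For reference, PyTorch also provides a built-in nn.GRUCell that performs the same one-step computation. It additionally includes bias terms and uses its own weight layout, so its numeric output will differ from GRUSimple; the snippet below is only a minimal comparison sketch.

# Built-in GRU cell; add a batch dimension to the unbatched tensors
builtin_cell = nn.GRUCell(input_size, hidden_size)
h_t_builtin = builtin_cell(x_t.unsqueeze(0), h_prev.unsqueeze(0))

print("nn.GRUCell hidden state:", h_t_builtin)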

4. RNN Model Using GRU

Now, let’s build a complete RNN model using the GRU cell.

class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(GRUModel, self).__init__()
        self.gru = GRUSimple(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h_t = torch.zeros(self.gru.hidden_size)  # Initial hidden state

        for t in range(x.size(0)):
            h_t = self.gru(x[t], h_t)  # Use GRU at each time step
        output = self.fc(h_t)  # Convert the last hidden state to output
        return output

The GRUModel class above builds a model that processes sequential data with the GRU cell. The forward method iterates over the input sequence, using the GRU cell to update the hidden state at each time step. The last hidden state is then passed through a linear layer to produce the final output.
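
The loop above processes a single, unbatched sequence step by step. For comparison, the built-in nn.GRU layer performs the same sequential processing internally and returns all hidden states at once; below is a minimal sketch of an equivalent model (its weights will of course differ from GRUSimple). The built-in layer also supports batching, stacked layers, and bidirectionality, which the hand-written loop does not.

class GRUModelBuiltin(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(GRUModelBuiltin, self).__init__()
        self.gru = nn.GRU(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # x: (seq_len, input_size); nn.GRU expects (seq_len, batch, input_size)
        outputs, h_n = self.gru(x.unsqueeze(1))
        # h_n has shape (num_layers, batch, hidden_size) = (1, 1, hidden_size)
        return self.fc(h_n[0, 0])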

4.1 Testing RNN Model

Now, let’s test the GRU model.

input_size = 5
hidden_size = 3
output_size = 2
seq_length = 10

x = torch.randn(seq_length, input_size)  # Generate random sequence data

model = GRUModel(input_size, hidden_size, output_size)
output = model(x)

print("Model output:", output)

The code above allows us to observe the process in which the GRU model generates output for the given sequence data.
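
To go one step further, the model can be trained like any other PyTorch module. Below is a minimal training sketch with a hypothetical random target and a regression-style MSE loss, purely to show that gradients flow through the custom GRU cell.

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

target = torch.randn(output_size)  # hypothetical target, just for the sketch

for epoch in range(5):
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print(f"epoch {epoch}: loss = {loss.item():.4f}")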

5. Application of GRU

GRU is utilized in various fields. In particular, it is effectively used in natural language processing (NLP) tasks, including machine translation, sentiment analysis, text generation, and many other applications. Recurrent structures like GRU provide powerful advantages in modeling continuous temporal dependencies.
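
As one illustration of how a GRU is typically embedded in an NLP model, here is a small sentiment-classifier sketch using the built-in nn.Embedding and nn.GRU layers. The vocabulary size, dimensions, and the binary-sentiment task are all assumptions made for this example.

class SentimentGRU(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size, num_classes):
        super(SentimentGRU, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(embed_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, token_ids):
        # token_ids: (batch, seq_len) integer indices
        embedded = self.embedding(token_ids)   # (batch, seq_len, embed_dim)
        _, h_n = self.gru(embedded)            # h_n: (1, batch, hidden_size)
        return self.fc(h_n.squeeze(0))         # (batch, num_classes)

# Example usage with assumed sizes
model_nlp = SentimentGRU(vocab_size=1000, embed_dim=16, hidden_size=32, num_classes=2)
dummy_tokens = torch.randint(0, 1000, (4, 12))  # batch of 4 sequences of length 12
print(model_nlp(dummy_tokens).shape)            # torch.Size([4, 2])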

Since GRU is simpler than LSTM yet often performs comparably, it is important to choose between them based on the characteristics of the data and the nature of the problem.

6. Conclusion

In this post, we explored the fundamental concepts of GRU and implemented a GRU cell and an RNN model using PyTorch. GRU is a useful structure for processing complex sequential data and can be integrated into various deep learning models. Understanding GRU provides insight into natural language processing and time series analysis and helps in solving practical problems.

Now, we hope you will also apply GRU to your projects!

Author: Deep Learning Researcher

Date: October 2023