A Detailed Record of Converting dynamic_rnn to nn.GRU

Preface

I ran into a bit of trouble today while converting a piece of TensorFlow code to PyTorch. After several rounds of debugging I finally figured out how the conversion should be done, so I'm writing it down here.

Let's go straight to the code. To make sure the final results are identical, I initialize all of the layers' weights to zero here.

import torch
import torch.nn as nn
import numpy as np
import tensorflow as tf
from tensorflow.keras import initializers

input = np.random.rand(3, 1, 5)
hidden = np.random.rand(3, 5)

print("input: ", input.shape)
print(input)
print("hidden: ", hidden.shape)
print(hidden)

print("="*20, ' tensorflow result ', "="*20)
# cell with zeros initializer
cell = tf.compat.v1.nn.rnn_cell.GRUCell(5, kernel_initializer=initializers.Zeros(), bias_initializer=initializers.Zeros())
tf_output, tf_state = tf.compat.v1.nn.dynamic_rnn(cell, input, initial_state=hidden)
print(tf_output) # (batch size, time steps, features)
print(tf_state) # (batch size, features) at the final time step
print('\n')

print("="*20, ' rnn cell result ', "="*20)
# rnn cell
pytorch_rnn_cell = nn.GRUCell(5, 5)
for k, v in pytorch_rnn_cell.state_dict().items():
    torch.nn.init.constant_(v, 0)
pytorch_input_cell = torch.from_numpy(input).permute(1, 0, 2).float() # (time steps, batch size, features)
pytorch_hidden_cell = torch.from_numpy(hidden).float() # (batch size, features)
pytorch_output_cell = []
for i in range(pytorch_input_cell.size(0)): # manually step the cell through each time step
    pytorch_hidden_cell = pytorch_rnn_cell(pytorch_input_cell[i], pytorch_hidden_cell)
    pytorch_output_cell.append(pytorch_hidden_cell)
print(pytorch_output_cell)
print('\n')

print("="*20, ' rnn result ', "="*20)
# rnn
pytorch_rnn = nn.GRU(5, 5)
for k, v in pytorch_rnn.state_dict().items():
    torch.nn.init.constant_(v, 0)
pytorch_input = torch.from_numpy(input).permute(1, 0, 2).float() # (time steps, batch size, feature size)
pytorch_hidden = torch.from_numpy(hidden).unsqueeze(0).float() # (num layers * num directions, batch size, hidden size)
pytorch_output, pytorch_state = pytorch_rnn(pytorch_input, pytorch_hidden)
print(pytorch_output, pytorch_output.shape)
print(pytorch_state, pytorch_state.shape)
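
Before looking at the output, it is worth spelling out the three conversion points the code above encodes. TF's dynamic_rnn is batch-major by default, so its input is (batch size, time steps, features); nn.GRU with the default batch_first=False expects (time steps, batch size, features), hence the permute(1, 0, 2). The initial hidden state for nn.GRU needs an extra leading (num layers * num directions) axis, hence the unsqueeze(0). And nn.GRUCell has no time axis at all, so you loop over the time steps yourself. A minimal sketch of the input conversion, assuming a single-layer, unidirectional GRU (the helper name tf_to_torch_gru_inputs is mine, not a library function):

def tf_to_torch_gru_inputs(np_input, np_hidden):
    # dynamic_rnn default layout: (batch size, time steps, features)
    # nn.GRU default layout:      (time steps, batch size, features)
    x = torch.from_numpy(np_input).permute(1, 0, 2).float()
    # nn.GRU expects h0 as (num layers * num directions, batch size, hidden size)
    h0 = torch.from_numpy(np_hidden).unsqueeze(0).float()
    return x, h0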

The final results are as follows:

input:  (3, 1, 5)
[[[0.98175333 0.59281082 0.47678967 0.70612923 0.73616147]]

 [[0.8363702  0.85099391 0.75740424 0.30633335 0.20097122]]

 [[0.60316062 0.21921029 0.16052985 0.25654177 0.40698399]]]
hidden:  (3, 5)
[[0.46976021 0.19681885 0.59240364 0.79540728 0.27608136]
 [0.39461795 0.29340918 0.4515729  0.6921841  0.44068605]
 [0.89315058 0.72514622 0.2925488  0.45433305 0.59910906]]
==================== tensorflow result ====================
tf.Tensor(
[[[0.23488011 0.09840942 0.29620182 0.39770364 0.13804068]]

 [[0.19730898 0.14670459 0.22578645 0.34609205 0.22034303]]

 [[0.44657529 0.36257311 0.1462744  0.22716653 0.29955453]]], shape=(3, 1, 5), dtype=float64)
tf.Tensor(
[[0.23488011 0.09840942 0.29620182 0.39770364 0.13804068]
 [0.19730898 0.14670459 0.22578645 0.34609205 0.22034303]
 [0.44657529 0.36257311 0.1462744  0.22716653 0.29955453]], shape=(3, 5), dtype=float64)


==================== rnn cell result ====================
[tensor([[0.2349, 0.0984, 0.2962, 0.3977, 0.1380],
        [0.1973, 0.1467, 0.2258, 0.3461, 0.2203],
        [0.4466, 0.3626, 0.1463, 0.2272, 0.2996]], grad_fn=<AddBackward0>)]


==================== rnn result ====================
tensor([[[0.2349, 0.0984, 0.2962, 0.3977, 0.1380],
         [0.1973, 0.1467, 0.2258, 0.3461, 0.2203],
         [0.4466, 0.3626, 0.1463, 0.2272, 0.2996]]], grad_fn=<StackBackward>) torch.Size([1, 3, 5])
tensor([[[0.2349, 0.0984, 0.2962, 0.3977, 0.1380],
         [0.1973, 0.1467, 0.2258, 0.3461, 0.2203],
         [0.4466, 0.3626, 0.1463, 0.2272, 0.2996]]], grad_fn=<StackBackward>) torch.Size([1, 3, 5])

Process finished with exit code 0
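
All three results match, and the zero-weight trick makes this easy to sanity-check by hand: with every weight and bias zeroed, both GRU gates come out as sigmoid(0) = 0.5 and the candidate state as tanh(0) = 0, so each step computes h_t = (1 - z) * n + z * h = 0.5 * h. The printed outputs are indeed exactly half of the initial hidden state. Finally, if you would rather keep TF's batch-major layout instead of permuting, nn.GRU also accepts batch_first=True; here is a quick sketch under that assumption (the variable names are mine):

pytorch_rnn_bf = nn.GRU(5, 5, batch_first=True)
for k, v in pytorch_rnn_bf.state_dict().items():
    torch.nn.init.constant_(v, 0)
# input stays (batch size, time steps, features); no permute needed
bf_output, bf_state = pytorch_rnn_bf(torch.from_numpy(input).float(),
                                     torch.from_numpy(hidden).unsqueeze(0).float())
# bf_output is (batch size, time steps, hidden size), matching dynamic_rnn's layout;
# bf_state is still (num layers * num directions, batch size, hidden size)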