Skip to content

Commit 7c57e52

Browse files
fuzihaofzhsoumith
authored andcommitted
Add a time sequence prediction example (pytorch#118)
1 parent ac5b745 commit 7c57e52

File tree

3 files changed

+106
-0
lines changed

3 files changed

+106
-0
lines changed

time_sequence_prediction/README.md

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Time Sequence Prediction
2+
This is a toy example for beginners to start with. It is helpful for learning both pytorch and time sequence prediction. Two LSTMCell units are used in this example to learn some sine wave signals starting at different phases. After learning the sine waves, the network tries to predict the signal values in the future. The results is shown in the picture below.
3+
4+
## Usage
5+
6+
```
7+
python generate_sine_wave.py
8+
python train.py
9+
```
10+
11+
## Result
12+
The initial signal and the predicted results are shown in the image. We first give some initial signals (full line). The network will subsequently give some predicted results (dash line). It can be concluded that the network can generate new sine waves.
13+
![image](https://cloud.githubusercontent.com/assets/1419566/24184438/e24f5280-0f08-11e7-8f8b-4d972b527a81.png)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import math
2+
import numpy as np
3+
import torch
4+
T = 20
5+
L = 1000
6+
N = 100
7+
np.random.seed(2)
8+
x = np.empty((N, L), 'int64')
9+
x[:] = np.array(range(L)) + np.random.randint(-4*T, 4*T, N).reshape(N, 1)
10+
data = np.sin(x / 1.0 / T).astype('float64')
11+
torch.save(data, open('traindata.pt', 'wb'))
12+

time_sequence_prediction/train.py

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from __future__ import print_function
2+
import torch
3+
import torch.nn as nn
4+
from torch.autograd import Variable
5+
import torch.optim as optim
6+
import numpy as np
7+
import matplotlib
8+
matplotlib.use('Agg')
9+
import matplotlib.pyplot as plt
10+
11+
class Sequence(nn.Module):
12+
def __init__(self):
13+
super(Sequence, self).__init__()
14+
self.lstm1 = nn.LSTMCell(1, 51)
15+
self.lstm2 = nn.LSTMCell(51, 1)
16+
17+
def forward(self, input, future = 0):
18+
outputs = []
19+
h_t = Variable(torch.zeros(input.size(0), 51).double(), requires_grad=False)
20+
c_t = Variable(torch.zeros(input.size(0), 51).double(), requires_grad=False)
21+
h_t2 = Variable(torch.zeros(input.size(0), 1).double(), requires_grad=False)
22+
c_t2 = Variable(torch.zeros(input.size(0), 1).double(), requires_grad=False)
23+
24+
for i, input_t in enumerate(input.chunk(input.size(1), dim=1)):
25+
h_t, c_t = self.lstm1(input_t, (h_t, c_t))
26+
h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))
27+
outputs += [c_t2]
28+
for i in range(future):# if we should predict the future
29+
h_t, c_t = self.lstm1(c_t2, (h_t, c_t))
30+
h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))
31+
outputs += [c_t2]
32+
outputs = torch.stack(outputs, 1).squeeze(2)
33+
return outputs
34+
35+
36+
37+
if __name__ == '__main__':
38+
# set ramdom seed to 0
39+
np.random.seed(0)
40+
torch.manual_seed(0)
41+
# load data and make training set
42+
data = torch.load(open('traindata.pt'))
43+
input = Variable(torch.from_numpy(data[3:, :-1]), requires_grad=False)
44+
target = Variable(torch.from_numpy(data[3:, 1:]), requires_grad=False)
45+
# build the model
46+
seq = Sequence()
47+
seq.double()
48+
criterion = nn.MSELoss()
49+
# use LBFGS as optimizer since we can load the whole data to train
50+
optimizer = optim.LBFGS(seq.parameters())
51+
#begin to train
52+
for i in range(15):
53+
print('STEP: ', i)
54+
def closure():
55+
optimizer.zero_grad()
56+
out = seq(input)
57+
loss = criterion(out, target)
58+
print('loss:', loss.data.numpy()[0])
59+
loss.backward()
60+
return loss
61+
optimizer.step(closure)
62+
# begin to predict
63+
future = 1000
64+
pred = seq(input[:3], future = future)
65+
y = pred.data.numpy()
66+
# draw the result
67+
plt.figure(figsize=(30,10))
68+
plt.title('Predict future values for time sequences\n(Dashlines are predicted values)', fontsize=30)
69+
plt.xlabel('x', fontsize=20)
70+
plt.ylabel('y', fontsize=20)
71+
plt.xticks(fontsize=20)
72+
plt.yticks(fontsize=20)
73+
def draw(yi, color):
74+
plt.plot(np.arange(input.size(1)), yi[:input.size(1)], color, linewidth = 2.0)
75+
plt.plot(np.arange(input.size(1), input.size(1) + future), yi[input.size(1):], color + ':', linewidth = 2.0)
76+
draw(y[0], 'r')
77+
draw(y[1], 'g')
78+
draw(y[2], 'b')
79+
plt.savefig('predict%d.pdf'%i)
80+
plt.close()
81+

0 commit comments

Comments
 (0)