We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent b156a6a commit 2c57b00Copy full SHA for 2c57b00
time_sequence_prediction/train.py
@@ -22,7 +22,7 @@ def forward(self, input, future = 0):
22
h_t2 = torch.zeros(input.size(0), 51, dtype=torch.double)
23
c_t2 = torch.zeros(input.size(0), 51, dtype=torch.double)
24
25
- for input_t in input.chunk(input.size(1), dim=1):
+ for input_t in input.split(1, dim=1):
26
h_t, c_t = self.lstm1(input_t, (h_t, c_t))
27
h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
28
output = self.linear(h_t2)
0 commit comments