Skip to content

Commit de54403

Browse files
committed
fix the input shape to match (:, timesteps, feature) of keras' lstm
1 parent d3bb6b7 commit de54403

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

StockPricesPredictionProject/pricePredictionLSTM.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,12 @@ def create_dataset(dataset, look_back=1):
5050
testX, testY = create_dataset(test, look_back)
5151

5252
# reshape input to be [samples, time steps, features]
53-
trainX = np.reshape(trainX, (trainX.shape[0], 1, trainX.shape[1]))
54-
testX = np.reshape(testX, (testX.shape[0], 1, testX.shape[1]))
53+
trainX = np.reshape(trainX, (trainX.shape[0], trainX.shape[1], 1))
54+
testX = np.reshape(testX, (testX.shape[0], testX.shape[1], 1))
5555

5656
# create and fit the LSTM network, optimizer=adam, 25 neurons, dropout 0.1
5757
model = Sequential()
58-
model.add(LSTM(25, input_shape=(1, look_back)))
58+
model.add(LSTM(25, input_shape=(look_back, 1)))
5959
model.add(Dropout(0.1))
6060
model.add(Dense(1))
6161
model.compile(loss='mse', optimizer='adam')

0 commit comments

Comments
 (0)