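【Posted】: 2022-01-22 22:25:15
【Question】:
I'm trying to make predictions with a Keras model, but I run into a problem when calling fit. My goal is to forecast the BNB/USDT price 30 minutes ahead.
The error I get is: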
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: Incompatible shapes: [32,30] vs. [32,30,1]
[[{{node loss/dense_loss/SquaredDifference}}]]
[[training/Adam/gradients/gradients/lstm_1/while/ReadVariableOp/Enter_grad/b_acc_3/_125]]
(1) Invalid argument: Incompatible shapes: [32,30] vs. [32,30,1]
[[{{node loss/dense_loss/SquaredDifference}}]]
Here is the code:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LSTM
from binance.client import Client
import csv
import tensorflow as tf
pd.options.mode.chained_assignment = None
tf.random.set_random_seed(0)
api = {'key':'...','secret':'...'}
# client = Client(api['key'], api['secret'])
# length_data = "2 day"
# klines = client.get_historical_klines("BNBUSDT", Client.KLINE_INTERVAL_1MINUTE, length_data + " UTC")
# with open('./bnbusdt_price_train_test.csv', 'w') as f:
#     writer = csv.writer(f)
#     writer.writerow(['timestamp','open','max','min','close'])
#     for sub in klines:
#         writer.writerow([sub[0], sub[1], sub[2], sub[3], sub[4]])
df = pd.read_csv('./bnbusdt_price_train_test.csv')
df['Date'] = pd.to_datetime(df.timestamp, unit='ms')
df = df.sort_values('Date')  # sort_values returns a new DataFrame, so keep the result
y = df['close'].fillna(method='ffill')
y = y.values.reshape(-1, 1)
scaler = MinMaxScaler(feature_range=(0, 1))
scaler = scaler.fit(y)
y = scaler.transform(y)
n_lookback = 60
n_forecast = 30
X = []
Y = []
for i in range(n_lookback, len(y) - n_forecast + 1):
    X.append(y[i - n_lookback: i])
    Y.append(y[i: i + n_forecast])
X = np.array(X)
Y = np.array(Y)
model = Sequential()
model.add(LSTM(units=50, return_sequences=True, input_shape=(n_lookback, 1)))
model.add(LSTM(units=50))
model.add(Dense(n_forecast))
model.compile(loss='mean_squared_error', optimizer='adam')
model.fit(X, Y, epochs=1, batch_size=32, verbose=0)
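The CSV I load contains:
- timestamp (ms)
- open price
- high price
- low price
- close price
I tried changing my 3D input to 2D, but then I get a different error at model.add.
Do you have any ideas?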
【Discussion】:
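The shapes in the error message point at the target rather than the input. Dense(n_forecast) produces an output of shape (batch, 30), while each Y window is sliced from a column vector and so has shape (30, 1), giving (batch, 30, 1). The LSTM input should stay 3D (samples, timesteps, features), which is probably why flattening X led to a different error. A minimal sketch of the windowing with a 2D target follows. It uses NumPy only, and the synthetic y array is an assumption standing in for the scaled close prices:

```python
import numpy as np

n_lookback = 60
n_forecast = 30

# Stand-in for the scaled close prices (assumption): a column vector of
# shape (n_samples, 1), like the output of MinMaxScaler.transform.
y = np.linspace(0.0, 1.0, 200).reshape(-1, 1)

X, Y = [], []
for i in range(n_lookback, len(y) - n_forecast + 1):
    X.append(y[i - n_lookback: i])      # (60, 1): keep the feature axis for the LSTM
    Y.append(y[i: i + n_forecast, 0])   # (30,): drop the feature axis for the target
X = np.array(X)
Y = np.array(Y)

print(X.shape)  # (111, 60, 1)
print(Y.shape)  # (111, 30)
```

With Y shaped (samples, n_forecast), the target should line up with the output of Dense(n_forecast) in model.fit(X, Y, ...). If you would rather leave the loop unchanged, np.squeeze(Y, axis=-1) after building the array should give the same result.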
Tags: python pandas numpy tensorflow keras