【问题标题】:Input data preparation for lstm/grulstm/gru 的输入数据准备
【发布时间】:2022-01-14 12:40:53
【问题描述】:

我在理解如何转换我的数据以提供给网络时遇到问题(我认为 lstm 网络会有所帮助,因为我的数据主要是时间序列类型并且还有一些时间信息,所以......)。

这里是数据格式 前 6 列代表一秒钟的数据(larger_corr、shorter_corr、noiseratio、x、y、z) 然后是相应的输出特征,然后是下一秒数据。

但是为了准备训练数据,我怎样才能发送 6 列数据,然后是接下来的 6 列。所有列的长度都是 40。

我不确定我是否表达得足够清楚

如果您需要任何其他信息,请告诉我。

【问题讨论】:

    标签: python tensorflow keras lstm data-preprocessing


    【解决方案1】:

    您可以尝试按如下方式准备数据,但请注意,为了确保可读性,我只使用了 12 列:

    import pandas as pd
    import numpy as np
    import tensorflow as tf
    import tabulate
    np.random.seed(0)
    
    df = pd.DataFrame({
        'larger_corr' : np.random.randn(25),
        'shorter_corr' : np.random.randn(25),
        'noiseratio' : np.random.randn(25),
        'x' : np.random.randn(25),
        'y' : np.random.randn(25),
        'z' : np.random.randn(25),
        'output' : np.random.randint(0,2,25),
        'larger_corr.1' : np.random.randn(25),
        'shorter_corr.1' : np.random.randn(25),
        'noiseratio.1' : np.random.randn(25),
        'x.1' : np.random.randn(25),
        'y.1' : np.random.randn(25),
        'z.1' : np.random.randn(25),
        'output.1' : np.random.randint(0,2,25)
    })
    
    print(df.to_markdown())
    y1, y2 = df.pop('output').to_numpy(), df.pop('output.1').to_numpy()
    data = df.to_numpy()
    x1, x2 = np.array_split(data, 2, axis=1)
    x1 = np.expand_dims(x1, axis=1) # add timestep dimension
    x2 = np.expand_dims(x2, axis=1) # add timestep dimension
    X = np.concatenate([x1, x2])
    Y = np.concatenate([y1, y1])
    print('Shape of X -->', X.shape, 'Shape of labels -->', Y.shape)
    
    |    |   larger_corr |   shorter_corr |   noiseratio |          x |         y |          z |   output |   larger_corr.1 |   shorter_corr.1 |   noiseratio.1 |        x.1 |        y.1 |         z.1 |   output.1 |
    |---:|--------------:|---------------:|-------------:|-----------:|----------:|-----------:|---------:|----------------:|-----------------:|---------------:|-----------:|-----------:|------------:|-----------:|
    |  0 |      1.76405  |     -1.45437   |   -0.895467  | -0.68481   |  1.88315  | -0.149635  |        1 |       0.438871  |       -0.244179  |     -0.891895  | -0.617166  |  1.14367   | -0.936916   |          0 |
    |  1 |      0.400157 |      0.0457585 |    0.386902  | -0.870797  | -1.34776  | -0.435154  |        1 |       0.63826   |        0.475261  |      0.570081  | -1.77556   | -0.188056  | -1.97935    |          0 |
    |  2 |      0.978738 |     -0.187184  |   -0.510805  | -0.57885   | -1.27048  |  1.84926   |        0 |       2.01584   |       -0.714216  |      2.66323   | -1.11821   |  1.24678   |  0.445384   |          0 |
    |  3 |      2.24089  |      1.53278   |   -1.18063   | -0.311553  |  0.969397 |  0.672295  |        0 |      -0.243653  |       -1.18694   |      0.410289  | -1.60639   | -0.253884  | -0.195333   |          1 |
    |  4 |      1.86756  |      1.46936   |   -0.0281822 |  0.0561653 | -1.17312  |  0.407462  |        1 |       1.53384   |        0.608891  |      0.485652  | -0.814676  | -0.870176  | -0.202716   |          1 |
    |  5 |     -0.977278 |      0.154947  |    0.428332  | -1.16515   |  1.94362  | -0.769916  |        1 |       0.76475   |        0.504223  |      1.31153   |  0.321281  |  0.0196537 |  0.219389   |          0 |
    |  6 |      0.950088 |      0.378163  |    0.0665172 |  0.900826  | -0.413619 |  0.539249  |        0 |      -2.45668   |       -0.513996  |     -0.235649  | -0.12393   | -1.11437   | -1.03016    |          0 |
    |  7 |     -0.151357 |     -0.887786  |    0.302472  |  0.465662  | -0.747455 | -0.674333  |        1 |      -1.70365   |        0.818475  |     -1.48018   |  0.0221213 |  0.607842  | -0.929744   |          0 |
    |  8 |     -0.103219 |     -1.9808    |   -0.634322  | -1.53624   |  1.92294  |  0.0318306 |        1 |       0.420153  |        1.1566    |     -0.0214848 | -0.321287  |  0.457237  | -2.55857    |          1 |
    |  9 |      0.410599 |     -0.347912  |   -0.362741  |  1.48825   |  1.48051  | -0.635846  |        1 |      -0.298149  |       -0.803689  |      1.05279   |  0.692618  |  0.875539  |  1.6495     |          0 |
    | 10 |      0.144044 |      0.156349  |   -0.67246   |  1.89589   |  1.86756  |  0.676433  |        1 |       0.263602  |       -0.551562  |     -0.117402  | -0.353524  |  0.346481  |  0.611738   |          0 |
    | 11 |      1.45427  |      1.23029   |   -0.359553  |  1.17878   |  0.906045 |  0.576591  |        1 |       0.731266  |       -0.332414  |      1.82851   |  0.81229   | -0.454874  | -1.05194    |          1 |
    | 12 |      0.761038 |      1.20238   |   -0.813146  | -0.179925  | -0.861226 | -0.208299  |        1 |       0.22807   |        1.84452   |     -0.0166771 | -1.14179   |  0.198095  | -0.754946   |          0 |
    | 13 |      0.121675 |     -0.387327  |   -1.72628   | -1.07075   |  1.91006  |  0.396007  |        0 |      -2.02852   |       -0.422776  |      1.87011   | -0.287549  |  0.391408  |  0.623188   |          1 |
    | 14 |      0.443863 |     -0.302303  |    0.177426  |  1.05445   | -0.268003 | -1.09306   |        0 |       0.96619   |        0.487659  |     -0.380307  |  1.31554   | -3.17786   |  0.00470758 |          0 |
    | 15 |      0.333674 |     -1.04855   |   -0.401781  | -0.403177  |  0.802456 | -1.49126   |        1 |      -0.186922  |       -0.375828  |      0.428698  |  0.685781  | -0.956575  | -0.899891   |          0 |
    | 16 |      1.49408  |     -1.42002   |   -1.6302    |  1.22245   |  0.947252 |  0.439392  |        0 |      -0.472325  |        0.227851  |      0.361896  |  0.524599  | -0.0312749 |  0.129242   |          1 |
    | 17 |     -0.205158 |     -1.70627   |    0.462782  |  0.208275  | -0.15501  |  0.166673  |        1 |       1.93666   |        0.703789  |      0.467568  | -0.793387  |  1.03272   |  0.979693   |          1 |
    | 18 |      0.313068 |      1.95078   |   -0.907298  |  0.976639  |  0.614079 |  0.635031  |        0 |       1.47734   |       -0.7978    |     -1.51803   | -0.237881  | -1.21562   |  0.328375   |          0 |
    | 19 |     -0.854096 |     -0.509652  |    0.0519454 |  0.356366  |  0.922207 |  2.38314   |        0 |      -0.0848901 |       -0.6759    |     -1.89304   |  0.569498  | -0.318678  |  0.487074   |          0 |
    | 20 |     -2.55299  |     -0.438074  |    0.729091  |  0.706573  |  0.376426 |  0.944479  |        1 |       0.427697  |       -0.922546  |     -0.785087  | -1.51061   |  1.49513   |  0.144842   |          1 |
    | 21 |      0.653619 |     -1.2528    |    0.128983  |  0.0105    | -1.0994   | -0.912822  |        1 |      -0.30428   |       -0.448586  |     -1.60529   | -1.56505   | -0.130251  | -0.0856099  |          1 |
    | 22 |      0.864436 |      0.77749   |    1.1394    |  1.78587   |  0.298238 |  1.11702   |        1 |       0.204625  |        0.181979  |      1.43184   | -3.05123   | -1.20289   |  0.71054    |          1 |
    | 23 |     -0.742165 |     -1.6139    |   -1.23483   |  0.126912  |  1.32639  | -1.31591   |        1 |      -0.0833382 |       -0.220084  |     -1.94219   |  1.55966   |  0.199565  |  0.93096    |          0 |
    | 24 |      2.26975  |     -0.21274   |    0.402342  |  0.401989  | -0.694568 | -0.461585  |        1 |       1.82893   |        0.0249562 |      1.13995   | -2.63101   |  0.393166  |  0.875074   |          0 |
    Shape of X --> (50, 1, 6) Shape of labels --> (50,)
    

    预处理你的数据后,你可以像这样创建一个LSTM模型,其中维度timesteps代表1秒:

    timesteps, features = X.shape[1], X.shape[2]
    input = tf.keras.layers.Input(shape=(timesteps, features))
    x = tf.keras.layers.LSTM(32, return_sequences=False)(input)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(x)
    model = tf.keras.Model(input, output)
    model.compile(optimizer='adam', loss=tf.keras.losses.BinaryCrossentropy())
    print(model.summary())
    model.fit(X, Y, batch_size=10, epochs=5)
    
    Model: "model_1"
    _________________________________________________________________
     Layer (type)                Output Shape              Param #   
    =================================================================
     input_16 (InputLayer)       [(None, 1, 6)]            0         
                                                                     
     lstm_1 (LSTM)               (None, 32)                4992      
                                                                     
     dense_21 (Dense)            (None, 1)                 33        
                                                                     
    =================================================================
    Total params: 5,025
    Trainable params: 5,025
    Non-trainable params: 0
    _________________________________________________________________
    None
    Epoch 1/5
    5/5 [==============================] - 2s 4ms/step - loss: 0.6914
    Epoch 2/5
    5/5 [==============================] - 0s 3ms/step - loss: 0.6852
    Epoch 3/5
    5/5 [==============================] - 0s 3ms/step - loss: 0.6806
    Epoch 4/5
    5/5 [==============================] - 0s 4ms/step - loss: 0.6758
    Epoch 5/5
    5/5 [==============================] - 0s 4ms/step - loss: 0.6705
    <keras.callbacks.History at 0x7f90ca6c6d90>
    

    您还可以在使用 MinMaxScalerStandardScaler 将数据输入模型之前对其进行缩放/规范化,但我将由您来决定。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2020-02-10
      • 1970-01-01
      • 2016-07-05
      • 2019-01-21
      • 2016-11-09
      • 1970-01-01
      • 1970-01-01
      • 2021-12-20
      相关资源
      最近更新 更多