【问题标题】:train_test_split crashing RAM when feeding with (big) numpy arraystrain_test_split 在使用(大)numpy 数组时崩溃 RAM
【发布时间】:2019-12-20 22:55:53
【问题描述】:

来自 Scikit-learn 的train_test_split 方法在使用形状为 (5621, 224, 224, 3)y 的形状为 (5621, 3) 的 numpy 数组输入 X 时会导致 RAM 崩溃并终止执行。

  • X 包含 5621 张 224x224 RGB 数据的图像。
  • y 包含 3 个类别的 5621 个 OneHot 编码标签。

我正在加载一些图像作为训练数据来提供卷积神经网络,但是在拆分为训练数据和测试数据时它崩溃了。是否有其他选项可以加载图像以避免这种内存消耗?

重现步骤:

import numpy as np
from sklearn.model_selection import train_test_split

# Generate dummy data
X = np.random.random((5621, 224, 224, 3))
y = np.random.randint(3, size=(5621, 3))

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, shuffle=True) # Breaks here

我希望输出 3766 个训练样本和 1855 个测试样本,但它会发送 SIGKILL(以及 100% 的 RAM 使用率)并退出执行。

【问题讨论】:

    标签: numpy scikit-learn


    【解决方案1】:

    您确定它是在拆分方法上还是在之前已经发生过?

    您也可以手动拆分:

    X_train = X[:int(len(X)*.8)]
    y_train = y[:int(len(X)*.8)]
    X_test = X[int(len(X)*.8):]
    y_test y[int(len(X)*.8):]
    

    您的数据已经是随机的,所以应该不是排序问题。

    【讨论】:

    • 它不是因为 split 方法而崩溃,而是 numpy 数组的大小。我做了一些测试,加载 Xy 数组大约需要 7GB 的 RAM,因此调用 train_test_split(生成它们的另一个副本)达到 15GB 的已用 RAM,用 Linux free -h
    猜你喜欢
    • 1970-01-01
    • 2015-11-17
    • 2016-10-22
    • 2017-07-18
    • 2020-09-16
    • 2011-07-17
    • 2016-01-24
    • 2015-09-25
    • 2020-08-07
    相关资源
    最近更新 更多