【问题标题】:writing a train_test_split function with numpy用 numpy 编写一个 train_test_split 函数
【发布时间】:2016-06-26 05:25:45
【问题描述】:

我正在尝试使用 numpy 而不是使用 sklearn 的 train_test_split 函数编写我自己的训练测试拆分函数。我将数据分成 70% 的训练和 30% 的测试。我正在使用来自 sklearn 的波士顿住房数据集。

这是数据的形状:

housing_features.shape #(506,13) where 506 is sample size and it has 13 features.

这是我的代码:

city_data = datasets.load_boston()
housing_prices = city_data.target
housing_features = city_data.data

def shuffle_split_data(X, y):
    split = np.random.rand(X.shape[0]) < 0.7

    X_Train = X[split]
    y_Train = y[split]
    X_Test =  X[~split]
    y_Test = y[~split]

    print len(X_Train), len(y_Train), len(X_Test), len(y_Test)
    return X_Train, y_Train, X_Test, y_Test

try:
    X_train, y_train, X_test, y_test = shuffle_split_data(housing_features, housing_prices)
    print "Successful"
except:
    print "Fail"

我得到的打印输出是:

362 362 144 144
"Successful"

但我知道这并不成功,因为当我再次运行它时,我得到了不同的长度数字,而不是仅使用 SKlearn 的训练测试功能,X_train 的长度总是得到 354。

#correct output
from sklearn.cross_validation import train_test_split
X_train, X_test, y_train, y_test = train_test_split(housing_features, housing_prices, test_size=0.3, random_state=42)
print len(X_train) 
#354 

我缺少什么我的功能?

【问题讨论】:

    标签: python numpy scikit-learn


    【解决方案1】:

    因为您使用的是np.random.rand,它会为您提供随机数,对于非常大的数字,0.7 限制将接近 70%。您可以使用 np.percentile 来获得 70% 的值,然后像您所做的那样与该值进行比较:

    def shuffle_split_data(X, y):
        arr_rand = np.random.rand(X.shape[0])
        split = arr_rand < np.percentile(arr_rand, 70)
    
        X_train = X[split]
        y_train = y[split]
        X_test =  X[~split]
        y_test = y[~split]
    
        print len(X_Train), len(y_Train), len(X_Test), len(y_Test)
        return X_train, y_train, X_test, y_test
    

    编辑

    或者,您可以使用np.random.choice 选择具有所需数量的索引。对于您的情况:

    np.random.choice(range(X.shape[0]), int(0.7*X.shape[0]))
    

    【讨论】:

    • 另一方面,我应该使用随机数吗?因为 X_train 不应该对应于 y_train 值吗?或者即使在使用随机时也能保持这种结构?
    • @jxn 你应该使用随机,因为在原始train_test_split 你有random_state 这意味着随机输出。当然X_train 对应于y_train,因为您为它们使用了相同的掩码。
    • @jxn 你也可以使用np.random.choice
    猜你喜欢
    • 2020-09-16
    • 1970-01-01
    • 2022-08-17
    • 1970-01-01
    • 2018-12-30
    • 2021-01-10
    • 2018-10-06
    • 1970-01-01
    • 2019-12-20
    相关资源
    最近更新 更多