【问题标题】:How can I ensure that the users and items appear in both train and test data set with train_test_split in sklearn?如何使用 sklearn 中的 train_test_split 确保用户和项目同时出现在训练和测试数据集中?
【发布时间】:2016-05-25 10:37:40
【问题描述】:

我有一个数据集,包括user IDitem IDrating,如下所示:

user ID     item ID    rating
 1233        1011       4
 1220        0999       3
 2011        0702       1
 ...

当我将它们分成traintest 集时:

from sklearn import cross_validation

train, test = cross_validation.train_test_split(df, test_size = 0.2)

测试集中的用户是否已经出现在训练集中,是否有项目?如果没有,我该怎么做?我在document 中找不到答案。你能告诉我吗?

【问题讨论】:

  • 我不明白这个问题。你到底想做什么?
  • @kazemakase 该模型是在测试集中预测ratinguseritem。为此,我们必须测量训练集中useritem 的潜在因子。那么如何确保测试集中的用户也在训练集中。因为,同样的事情应该发生在项目中。好点了吗?
  • 我也不太明白你在问什么。您想对用户、项目或用户和项目的独特组合进行分层吗?例如,您是否允许您的训练和测试分区都包含用户 X 对不同项目的排名,或者两者都包含不同用户对项目 Y 的排名?是否可以同时包含用户 X 和项目 Y 的示例,只要它们都不包含用户 X 对项目 Y 的评分?
  • @ali_m 这就是我的意思:allow your training and test partitions to both contain rankings of different items by user X, or for both to contain rankings of item Y by different usersOK for them both to contain examples of user X and item Y as long as they don't both contain a rating of item Y by user X

标签: python pandas machine-learning scikit-learn cross-validation


【解决方案1】:

如果您想确保您的训练和测试分区不包含相同的用户和项目配对,那么您可以将每个唯一的(用户、项目)组合替换为整数标签,然后将这些标签传递给 LabelKFold .要为每个唯一配对分配整数标签,您可以使用this trick

import numpy as np
import pandas as pd
from sklearn.cross_validation import LabelKFold

df = pd.DataFrame({'users':[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2],
                   'items':[0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
                   'ratings':[2, 4, 3, 1, 4, 3, 0, 0, 0, 1, 0, 1]})

users_items = df[['users', 'items']].values
d = np.dtype((np.void, users_items.dtype.itemsize * users_items.shape[1]))
_, uidx = np.unique(np.ascontiguousarray(users_items).view(d), return_inverse=True)

for train, test in LabelKFold(uidx):

    # train your classifier using df.loc[train, ['users', 'items']] and
    # df.loc[train, 'ratings']...

    # cross-validate on df.loc[test, ['users', 'items']] and
    # df.loc[test, 'ratings']...

我仍然很难理解您的问题。如果您想保证您的训练集和测试集确实包含同一用户的示例,那么您可以使用StratifiedKFold:

for train, test in StratifiedKFold(df['users']):
    # ...

【讨论】:

  • 对不起,我不担心训练集和测试集包含相同的用户和项目对。我担心用户出现在测试集中没有在训练集中被测量。
  • 查看我的编辑。我发现很难推断出你想要什么,因为你的问题措辞非常含糊。
【解决方案2】:
def train_test_split(self, ratings, train_rate=0.8):
        """
        Split ratings into Training set and Test set

        """
        grps = ratings.groupby('user_id').groups
        test_df_index = list()
        train_df_index = list()

        test_iid = list()
        train_iid = list()

        for key in grps:
            count = 0
            local_index = list()
            grp = np.array(list(grps[key]))

            n_test = int(len(grp) * (1 - train_rate))
            for i, index in enumerate(grp):
                if count >= n_test:
                    break
                if ratings.iloc[index]['movie_id'] in test_iid:
                    continue
                test_iid.append(ratings.iloc[index]['movie_id'])
                test_df_index.append(index)
                local_index.append(i)
                count += 1

            grp = np.delete(grp, local_index)

            if count < n_test:
                local_index = list()
                for i, index in enumerate(grp):
                    if count >= n_test:
                        break
                    test_iid.append(ratings.iloc[index]['movie_id'])
                    test_df_index.append(index)
                    local_index.append(i)
                    count += 1

                grp = np.delete(grp, local_index)

            train_df_index.append(grp)

        test_df_index = np.hstack(np.array(test_df_index))
        train_df_index = np.hstack(np.array(train_df_index))

        np.random.shuffle(test_df_index)
        np.random.shuffle(train_df_index)

        return ratings.iloc[train_df_index], ratings.iloc[test_df_index]

你可以使用这种方法进行拆分,我已经努力确保训练集和测试集具有相同的用户id和电影id。

【讨论】:

    猜你喜欢
    • 2022-12-10
    • 2020-03-09
    • 2019-08-01
    • 2021-01-15
    • 2019-03-21
    • 2018-12-26
    • 2018-05-21
    • 2023-03-12
    • 2013-01-04
    相关资源
    最近更新 更多