【Question title】: Split a DataFrame by values in Python
【Posted】: 2021-09-01 00:34:48
【Question description】:

I have this DataFrame:

+-------+----------+---------------+-----------+----+
|  age  |  gender  | customer type | purchases | id |
+-------+----------+---------------+-----------+----+
|  38   |  female  |   type 1      |    90     |  1 |
|  35   |  female  |   type 2      |   100     |  2 |
|  71   |  male    |   type 2      |    66     |  3 |
|  68   |  female  |   type 3      |    12     |  4 |
|  26   |  male    |   type 4      |    900    |  5 |
|  55   |  male    |   type 5      |    71     |  6 |
|  27   |  male    |   type 1      |    55     |  7 |
|  ...  |   ...    |    ...        |    ...    | ...|
+-------+----------+---------------+-----------+----+

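I want to split each customer type into a train set and a test set, for example 20% test and 80% train, so that both parts have similar age and gender distributions. For example, if for type 1 one of the two parts ends up 80% female, that is not a good split.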

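I tried using the random module with a seed, but I couldn't make it work because I don't know how to take age and gender into account in the split.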
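One way to get "similar age and gender distributions" by construction, rather than by luck, is stratified sampling: take the same 20% from every combination of customer type, gender and age band. The sketch below is a minimal illustration under several assumptions: the data is a hypothetical random table with the question's columns, the helper column `age band` and its edges are arbitrary choices, and it relies on `DataFrameGroupBy.sample` (pandas 1.1 or newer). This is a different technique from the plain random mask used in the answer.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data with the same columns as the table in the question
rng = np.random.default_rng(0)
n = 10_000
df = pd.DataFrame({
    "age": rng.integers(18, 80, n),
    "gender": rng.choice(["female", "male"], n),
    "customer type": rng.choice([f"type {k}" for k in range(1, 6)], n),
    "purchases": rng.integers(0, 1000, n),
    "id": np.arange(1, n + 1),
})

# Coarse age bands (arbitrary edges); wide bands keep every stratum large enough to sample from
df["age band"] = pd.cut(df["age"], bins=[0, 30, 45, 60, 120],
                        labels=["<=30", "31-45", "46-60", ">60"]).astype(str)

# Draw 20% of every (customer type, gender, age band) stratum for the test set
strata = ["customer type", "gender", "age band"]
test = df.groupby(strata).sample(frac=0.2, random_state=42)
train = df.drop(index=test.index)

# Compare the female share per customer type in both parts
female_share = pd.DataFrame({
    "train": train["gender"].eq("female").groupby(train["customer type"]).mean(),
    "test": test["gender"].eq("female").groupby(test["customer type"]).mean(),
})
print(female_share)
```

Because pandas rounds `frac * group_size` per stratum, very small strata (a handful of rows) get 0 or 1 test rows; if that happens, merge age bands. scikit-learn's `train_test_split` with `stratify=` is an alternative, but it takes a single label per row (so you would pass a combined key of the three columns) and refuses strata with fewer than 2 rows.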

Thanks!

【Question comments】:

Tags: python split


【Answer 1】:

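If your database is large enough, I don't see why randomly taking 20% of it for testing and the other 80% for training would noticeably change the age and gender distributions. Here is a small example of how I would do it: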

#!/usr/bin/env python3
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(0)  # fix the seed so that the data and the split are reproducible

# Generate a synthetic database
N = 1000000  # size of the database
age = np.abs(np.random.randn(N)) * 30  # half-normal distribution (absolute value of a normal)
gender = np.random.randint(0, 100, N) < 42  # True (1) = female with ~42% probability, False (0) = male
customerType = np.random.randint(1, 6, N)  # 5 customer types, numbered 1 to 5
purchases = np.random.randint(0, 1000, N)

# Split the database into test and train: each row goes to test with probability 0.2 (~20% test, ~80% train).
# Goal: test and train should have the same gender and age distribution. The mask does not enforce this;
# it relies on the database being large enough for random sampling to preserve the distributions.
testMask = np.random.randint(0, 100, N) < 20
trainMask = np.logical_not(testMask)

# Check the gender distribution
print("All database: %0.2f %% female %0.2f %% male" % (gender.mean()*100., (1-gender.mean())*100.))
print("Test database: %0.2f %% female %0.2f %% male" % (gender[testMask].mean()*100., (1-gender[testMask].mean())*100.))
print("Train database: %0.2f %% female %0.2f %% male" % (gender[trainMask].mean()*100., (1-gender[trainMask].mean())*100.))

# Check the age distribution
fig, ax = plt.subplots(1, 1)
ax.hist(age, bins=30, density=True, label="All db", alpha=0.3)
ax.hist(age[testMask], bins=30, density=True, label="Test db", alpha=0.3)
ax.hist(age[trainMask], bins=30, density=True, label="Train db", alpha=0.3)
ax.legend()
plt.show()
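The "large enough" argument can be checked per customer type, which is what the question actually asks for. With a purely random mask, the female share measured on a subset of n rows fluctuates with a standard error of about sqrt(p(1 - p) / n). For one type's test set here (roughly 1,000,000 / 5 × 0.2 = 40,000 rows, p = 0.42) that is about 0.0025, a quarter of a percentage point; for a type with only 500 rows (about 100 in test) the same formula gives about 0.05, i.e. 5 percentage points, which is where stratified sampling starts to pay off. This sketch regenerates only the gender and customer type arrays, and its use of NumPy's `Generator` API and types numbered 1 to 5 are my assumptions rather than part of the answer.

```python
import numpy as np

# Same synthetic setup as the answer, using the Generator API and types numbered 1 to 5
rng = np.random.default_rng(1)
N = 1_000_000
gender = rng.integers(0, 100, N) < 42   # True = female (about 42% of rows)
customerType = rng.integers(1, 6, N)    # customer types 1..5
testMask = rng.random(N) < 0.20         # each row goes to test with probability 0.2
trainMask = ~testMask

# Female share per customer type in each part
for t in range(1, 6):
    inType = customerType == t
    testShare = gender[inType & testMask].mean()
    trainShare = gender[inType & trainMask].mean()
    print(f"type {t}: test {testShare:.2%} female, train {trainShare:.2%} female")

# Expected sampling noise of one type's test share: sqrt(p * (1 - p) / n)
nTest = np.count_nonzero((customerType == 1) & testMask)
print(f"standard error for the type 1 test set: {np.sqrt(0.42 * 0.58 / nTest):.4f}")
```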

【Answer comments】:
