如果您不打算进行任何花哨的交叉验证,并且基本上希望将数据分成训练集和测试集,以下是一种方法。
注意事项:我正在随机生成数据,因此我可能正在解决您不会遇到的问题,例如数据类型转换。
要点:事先手动打乱您的数据,使用train_test_split(来自 sklearn)并设置shuffle=False,然后对 numpy 数据集进行切片以切出字符串列并将数字传递给分类器。切片让这个例子看起来很难看,这就是我个人使用 pandas 的原因。
替代方法:稍微高级一些,但您可以使用管道和自定义转换器从整体数据中选择数字数据(未在下面显示)。
import numpy as np
from scipy.stats import bernoulli
import string, random
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split
### making up data, don't need to understand this part ###
size = 10000
fnames = np.array([''.join(random.choices(string.ascii_letters, k=5)) for i in range(size)])
lnames = np.array([''.join(random.choices(string.ascii_letters, k=5)) for i in range(size)])
identifier = np.array([''.join(random.choices(string.ascii_letters, k=2)) for i in range(size)])
rand_ints = np.random.randint(0, 20, (size,10))
data = np.column_stack((fnames, lnames, identifier, rand_ints))
answers = bernoulli.rvs(0.3, size=size)
###---------------------------end----------------------###
np.random.shuffle(data) # an in-place operation
X_train, X_test, y_train, y_test = train_test_split(
data, answers, test_size=0.10, random_state=42, shuffle=False)
clf = MLPClassifier(solver='lbfgs', alpha=1e-5)
clf.fit(X_train[:, 3:].astype('int16'), y_train)
score = clf.score(X_test[:, 3:].astype('int16'), y_test)
print('Test score: {}'.format(score))
pred = clf.predict(X_test[:, 3:].astype('int16'))
X_train_result = np.column_stack((X_train[:, :3], y_train, clf.predict(X_train[:, 3:].astype('int16'))))
X_test_result = np.column_stack((X_test[:, :3], y_test, clf.predict(X_test[:, 3:].astype('int16'))))
上面例子的输出:
Test score: 0.694
>>> X_train_result
array([['TGOxK', 'dOKWj', 'Bn', '1', '0'],
['GmqwM', 'iucDx', 'qX', '1', '0'],
['VXdJG', 'SJRVg', 'Nl', '1', '0'],
...,
['jClSD', 'ABkrp', 'zZ', '0', '0'],
['IoLrh', 'HiHLI', 'oU', '1', '0'],
['zzyGR', 'UCpRT', 'xg', '0', '0']], dtype='<U11')
>>> X_test_result
array([['zOLmZ', 'OMVrx', 'AS', '1', '0'],
['wfIsi', 'zEMEE', 'PU', '1', '0'],
['wHVtq', 'fbtMK', 'UD', '1', '0'],
...,
['paBoM', 'HVjpF', 'Ez', '0', '0'],
['ZivWN', 'VrHhm', 'FL', '0', '0'],
['WnHLw', 'hakoK', 'Qv', '0', '0']], dtype='<U11')
- 第一列:名字
- 第二列:姓氏
- 第三列:标识符
- 第四栏:正确答案(y true)
- 第五列:预测答案(y pred)