1.介绍

gcForest v1.1.1是gcForest的一个官方托管在GitHub上的版本,是由Ji Feng(Deep Forest的paper的作者之一)维护和开发,该版本支持Python3.5,且有类似于Scikit-Learn的API接口风格,在该项目中提供了一些调用例子,目前支持的基分类器有RandomForestClassifier,XGBClassifer,ExtraTreesClassifier,LogisticRegression,SGDClassifier如果采用XGBoost的基分类器还可以使用GPU,如果想增加其他基分类器,可以在模块中的lib/gcforest/estimators/__init__.py中添加,使用该模块需要依赖安装如下模块:

  • argparse
  • joblib
  • keras
  • psutil
  • scikit-learn>=0.18.1
  • scipy
  • simplejson
  • tensorflow
  • xgboost

2.API调用样例

这里先列出gcForest提供的API接口:

  • fit_tranform(X_train,y_train) 是gcForest模型最后一层每个估计器预测的概率concatenated的结果

  • fit_transform(X_train,y_train,X_test=x_test,y_test=y_test) 测试数据的准确率在训练的过程中也会被记录下来

  • set_keep_model_mem(False) 如果你的缓存不够,把该参数设置成False(默认为True),如果设置成False,你需要使用fit_transform(X_train,y_train,X_test=x_test,y_test=y_test)来评估你的模型

  • predict(X_test) # 模型预测

  • transform(X_test)

最简单的调用gcForest的方式如下:


# 导入必要的模块
from gcforest.gcforest import GCForest

# 初始化一个gcForest对象
gc = GCForest(config) # config是一个字典结构

# gcForest模型最后一层每个估计器预测的概率concatenated的结果
X_train_enc = gc.fit_transform(X_train,y_train)

# 测试集的预测
y_pred = gc.predict(X_test)

下面我们使用MNIST数据集来演示gcForest的使用及代码的详细说明:

# 导入必要的模块

import argparse # 命令行参数调用模块
import numpy as np 
import sys
from keras.datasets import mnist # MNIST数据集
import pickle
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
sys.path.insert(0, "lib")

from gcforest.gcforest import GCForest
from gcforest.utils.config_utils import load_json


def parse_args():
	'''
	解析终端命令行参数(model)
	'''
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", dest="model", type=str, default=None, 
	help="gcfoest Net Model File")
    args = parser.parse_args()
    return args


def get_toy_config():
	'''
	生成级联结构的相关结构
	'''
    config = {}
    ca_config = {}
    ca_config["random_state"] = 0
    ca_config["max_layers"] = 100
    ca_config["early_stopping_rounds"] = 3
    ca_config["n_classes"] = 10
    ca_config["estimators"] = []
    ca_config["estimators"].append(
            {"n_folds": 5, "type": "XGBClassifier", "n_estimators": 10, 
		"max_depth": 5,"objective": "multi:softprob", "silent": 
		True, "nthread": -1, "learning_rate": 0.1} )
    ca_config["estimators"].append({"n_folds": 5, "type": "RandomForestClassifier", 
	"n_estimators": 10, "max_depth": None, "n_jobs": -1})
    ca_config["estimators"].append({"n_folds": 5, "type": "ExtraTreesClassifier",
	 "n_estimators": 10, "max_depth": None, "n_jobs": -1})
    ca_config["estimators"].append({"n_folds": 5, "type": "LogisticRegression"})
    config["cascade"] = ca_config
    return config

相关文章:

  • 2021-06-16
  • 2021-04-17
  • 2021-06-25
  • 2021-10-03
  • 2022-03-07
  • 2022-12-23
  • 2022-12-23
猜你喜欢
  • 2021-04-11
  • 2021-06-21
  • 2022-12-23
  • 2021-12-12
  • 2021-08-24
  • 2021-05-11
  • 2022-12-23
相关资源
相似解决方案