概述:基于pycaret进行简易数据分析
相关文档
- GitHub 地址:https://github.com/pycaret/pycaret
- 用户文档:https://www.pycaret.org/guide
- Notebook 教程:https://www.pycaret.org/tutorial
模块
分类、回归、聚类、异常检测、自然语言处理和关联规则挖掘
Install
1 | pip install pycaret |
示例数据
以pycaret自带的数据集中糖尿病数据为例(diabetes)
1 | from pycaret.datasets import get_data |
该问题属于分类问题,因此导入分类算法模块;明确数据集及标签名,并对数据集进行归一化开始部署过程
1 | from pycaret.classification import * |
上述代码完成后会生成属性-类型,若识别正确回车即可
分类器
使用库中分类器
1 | compare_models() |
使用具体分类器
以逻辑回归为例,针对分类问题中各个模型的简称个人总结
模型 | 函数 | 简称 |
---|---|---|
决策树 | decision tree | dt |
岭回归 | Ridge Classifier | ridge |
逻辑回归 | Logistic Regression | lr |
线性判别分析 | Linear Discriminant Analysis | lda |
Xgboost | Extreme Gradient Boosting | xgboost |
CatBoost | CatBoost Classifier | catboost |
极端随机树 | Extra Trees Classifier | et |
朴素贝叶斯 | Naive Bayes | nb |
随机森林 | Random Forest Classifier | rf |
梯度提升 | Gradient Boosting Classifier | gbc |
AdaBoost | Ada Boost Classifier | ada |
支持向量机 | SVM - Linear Kernel | svm |
LGBM | Light Gradient Boosting Machine | lightgbm |
K近邻 | K Neighbors Classifier | knn |
二次判别分析 | Quadratic Discriminant Analysis | qda |
总结完后发现有文档(哭晕在厕所):https://pycaret.org/create-model/
1 | lr = create_model('lr') |
对具体分类器调整优化,以逻辑回归为例
1 | tuned_lr = tune_model('lr') |
集成模型
以决策树为例
1 | dt = create_model('dt') |
(1)集成多个不同的模型(以逻辑回归、决策树、LGBM为例)
1 | lightgbm = create_model('lightgbm') |
(2)Stacking集成多个模型(以岭回归、线性判别分析、梯度提升、xgboost为例)
1 | # 创建单个模型,用于stacking |
绘制
绘制AUC曲线,以逻辑回归为例
1 | lr = create_model('lr') |
绘制决策分界线
1 | plot_model(lr, plot = 'boundary') |
绘制PR曲线
PR曲线(Precision Recall Curve)
1 | plot_model(lr, plot = 'pr') |
绘制验证曲线
Validation Curve
1 | plot_model(lr, plot = 'vc') |
模型解释
SHAP
1 | interpret_model(xgboost) |
correlation
1 | interpret_model(xgboost, plot = 'correlation') |
reason
1 | interpret_model(xgboost, plot = 'reason', observation = 0) |
预测样本
1 | rf_holdout_pred = predict_model(rf) |
完成模型构建
1 | final_rf = finalize_model(rf) |
保存模型
以二进制格式
1 | save_model(adaboost, model_name = 'ada_for_deployment') |
模型部署
1 | deploy_model(final_lr, model_name = 'lr_aws', platform = 'aws', authentication = { 'bucket' : 'pycaret-test' }) |