Quantcast
Channel: InfoQ - 促进软件开发领域知识与创新的传播
Viewing all articles
Browse latest Browse all 1638

XGBoost缺失值引发的问题及其深度分析

$
0
0

背景

XGBoost模型作为机器学习中的一大“杀器”,被广泛应用于数据科学竞赛和工业领域,XGBoost官方也提供了可运行于各种平台和环境的对应代码,如适用于Spark分布式训练的XGBoost on Spark。然而,在XGBoost on Spark的官方实现中,却存在一个因XGBoost缺失值和Spark稀疏表示机制而带来的不稳定问题。

事情起源于美团内部某机器学习平台使用方同学的反馈,在该平台上训练出的XGBoost模型,使用同一个模型、同一份测试数据,在本地调用(Java引擎)与平台(Spark引擎)计算的结果不一致。但是该同学在本地运行两种引擎(Python引擎和Java引擎)进行测试,两者的执行结果是一致的。因此质疑平台的XGBoost预测结果会不会有问题?

该平台对XGBoost模型进行过多次定向优化,在XGBoost模型测试时,并没有出现过本地调用(Java引擎)与平台(Spark引擎)计算结果不一致的情形。而且平台上运行的版本,和该同学本地使用的版本,都来源于Dmlc的官方版本,JNI底层调用的应该是同一份代码,理论上,结果应该是完全一致的,但实际中却不同。

从该同学给出的测试代码上,并没有发现什么问题:

//测试结果中的一行,41列
double[] input = new double[]{1, 2, 5, 0, 0, 6.666666666666667, 31.14, 29.28, 0, 1.303333, 2.8555, 2.37, 701, 463, 3.989, 3.85, 14400.5, 15.79, 11.45, 0.915, 7.05, 5.5, 0.023333, 0.0365, 0.0275, 0.123333, 0.4645, 0.12, 15.082, 14.48, 0, 31.8425, 29.1, 7.7325, 3, 5.88, 1.08, 0, 0, 0, 32];
//转化为float[]
float[] testInput = new float[input.length];
for(int i = 0, total = input.length; i < total; i++){
  testInput[i] = new Double(input[i]).floatValue();
}
//加载模型
Booster booster = XGBoost.loadModel("${model}");
//转为DMatrix,一行,41列
DMatrix testMat = new DMatrix(testInput, 1, 41);
//调用模型
float[][] predicts = booster.predict(testMat);

上述代码在本地执行的结果是333.67892,而平台上执行的结果却是328.1694030761719。


Viewing all articles
Browse latest Browse all 1638

Trending Articles