Skip to content

Commit d6ec7ec

Browse files
izeigermankrinart
authored andcommitted
Fix application of best_ntree_limit to the entire list of estimators. Instead the limit is applied to per-class estimators split (#83)
1 parent 8610a18 commit d6ec7ec

File tree

2 files changed

+94
-9
lines changed

2 files changed

+94
-9
lines changed

m2cgen/assemblers/boosting.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,17 @@ class BaseBoostingAssembler(ModelAssembler):
99

1010
classifier_name = None
1111

12-
def __init__(self, model, trees, base_score=0):
12+
def __init__(self, model, trees, base_score=0, tree_limit=None):
1313
super().__init__(model)
1414
self.all_trees = trees
1515
self._base_score = base_score
1616

1717
self._output_size = 1
1818
self._is_classification = False
1919

20+
assert tree_limit is None or tree_limit > 0, "Unexpected tree limit"
21+
self._tree_limit = tree_limit
22+
2023
model_class_name = type(model).__name__
2124
if model_class_name == self.classifier_name:
2225
self._is_classification = True
@@ -34,6 +37,9 @@ def assemble(self):
3437
self.all_trees, self._base_score)
3538

3639
def _assemble_single_output(self, trees, base_score=0):
40+
if self._tree_limit:
41+
trees = trees[:self._tree_limit]
42+
3743
trees_ast = [self._assemble_tree(t) for t in trees]
3844
result_ast = utils.apply_op_to_expressions(
3945
ast.BinNumOpType.ADD,
@@ -83,16 +89,14 @@ def __init__(self, model):
8389
}
8490

8591
model_dump = model.get_booster().get_dump(dump_format="json")
86-
87-
# Respect XGBoost ntree_limit
88-
ntree_limit = getattr(model, "best_ntree_limit", 0)
89-
90-
if ntree_limit > 0:
91-
model_dump = model_dump[:ntree_limit]
92-
9392
trees = [json.loads(d) for d in model_dump]
9493

95-
super().__init__(model, trees, base_score=model.base_score)
94+
# Limit the number of trees that should be used for
95+
# assembling (if applicable).
96+
best_ntree_limit = getattr(model, "best_ntree_limit", None)
97+
98+
super().__init__(model, trees, base_score=model.base_score,
99+
tree_limit=best_ntree_limit)
96100

97101
def _assemble_tree(self, tree):
98102
if "leaf" in tree:

tests/assemblers/test_xgboost.py

+81
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,84 @@ def test_regression_best_ntree_limit():
147147
ast.BinNumOpType.ADD))
148148

149149
assert utils.cmp_exprs(actual, expected)
150+
151+
152+
def test_multi_class_best_ntree_limit():
153+
base_score = 0.5
154+
estimator = xgboost.XGBClassifier(n_estimators=100, random_state=1,
155+
max_depth=1, base_score=base_score)
156+
157+
estimator.best_ntree_limit = 1
158+
159+
utils.train_model_classification(estimator)
160+
161+
assembler = assemblers.XGBoostModelAssembler(estimator)
162+
actual = assembler.assemble()
163+
164+
estimator_exp_class1 = ast.ExpExpr(
165+
ast.SubroutineExpr(
166+
ast.BinNumExpr(
167+
ast.NumVal(0.5),
168+
ast.IfExpr(
169+
ast.CompExpr(
170+
ast.FeatureRef(2),
171+
ast.NumVal(2.5999999),
172+
ast.CompOpType.GTE),
173+
ast.NumVal(-0.0731707439),
174+
ast.NumVal(0.142857149)),
175+
ast.BinNumOpType.ADD)),
176+
to_reuse=True)
177+
178+
estimator_exp_class2 = ast.ExpExpr(
179+
ast.SubroutineExpr(
180+
ast.BinNumExpr(
181+
ast.NumVal(0.5),
182+
ast.IfExpr(
183+
ast.CompExpr(
184+
ast.FeatureRef(2),
185+
ast.NumVal(2.5999999),
186+
ast.CompOpType.GTE),
187+
ast.NumVal(0.0341463387),
188+
ast.NumVal(-0.0714285821)),
189+
ast.BinNumOpType.ADD)),
190+
to_reuse=True)
191+
192+
estimator_exp_class3 = ast.ExpExpr(
193+
ast.SubroutineExpr(
194+
ast.BinNumExpr(
195+
ast.NumVal(0.5),
196+
ast.IfExpr(
197+
ast.CompExpr(
198+
ast.FeatureRef(2),
199+
ast.NumVal(4.85000038),
200+
ast.CompOpType.GTE),
201+
ast.NumVal(0.129441619),
202+
ast.NumVal(-0.0681440532)),
203+
ast.BinNumOpType.ADD)),
204+
to_reuse=True)
205+
206+
exp_sum = ast.BinNumExpr(
207+
ast.BinNumExpr(
208+
estimator_exp_class1,
209+
estimator_exp_class2,
210+
ast.BinNumOpType.ADD),
211+
estimator_exp_class3,
212+
ast.BinNumOpType.ADD,
213+
to_reuse=True)
214+
215+
expected = ast.VectorVal([
216+
ast.BinNumExpr(
217+
estimator_exp_class1,
218+
exp_sum,
219+
ast.BinNumOpType.DIV),
220+
ast.BinNumExpr(
221+
estimator_exp_class2,
222+
exp_sum,
223+
ast.BinNumOpType.DIV),
224+
ast.BinNumExpr(
225+
estimator_exp_class3,
226+
exp_sum,
227+
ast.BinNumOpType.DIV)
228+
])
229+
230+
assert utils.cmp_exprs(actual, expected)

0 commit comments

Comments
 (0)