Skip to content

Commit 8ad1361

Browse files
committed
Set early stopping tolerance
1 parent 757ac34 commit 8ad1361

4 files changed

Lines changed: 34 additions & 6 deletions

File tree

py-forust/forust/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ def __init__(
339339
grow_policy: str = "DepthWise",
340340
evaluation_metric: str | None = None,
341341
early_stopping_rounds: int | None = None,
342+
early_stopping_delta: float = 1e-7,
342343
initialize_base_score: bool = True,
343344
terminate_missing_features: Iterable[Any] | None = None,
344345
missing_node_treatment: str = "None",
@@ -422,6 +423,8 @@ def __init__(
422423
early_stopping_rounds (int | None, optional): If this is specified, and an `evaluation_dataset` is passed
423424
during fit, then an improvement in the `evaluation_metric` must be seen after at least this many
424425
iterations of training, otherwise training will be cut short.
426+
early_stopping_delta (float, optional): Minimum improvement in the evaluation metric
427+
required to count as an improvement for early stopping. Defaults to 1e-7. Set to 0.0 to count any strict improvement.
425428
initialize_base_score (bool, optional): If this is specified, the `base_score` will be calculated at fit time using the `sample_weight` and y data in accordance with the requested `objective_type`. This will result in the passed `base_score` value being overridden.
426429
terminate_missing_features (set[Any], optional): An optional iterable of features (either strings, or integer values specifying the feature indices if numpy arrays are used for fitting), for which the missing node will always be terminated, even if `allow_missing_splits` is set to true. This value is only valid if `create_missing_branch` is also True.
427430
missing_node_treatment (str, optional): Method for selecting the `weight` for the missing node, if `create_missing_branch` is set to `True`. Defaults to "None". Valid options are:
@@ -516,6 +519,7 @@ def __init__(
516519
grow_policy=grow_policy,
517520
evaluation_metric=evaluation_metric,
518521
early_stopping_rounds=early_stopping_rounds,
522+
early_stopping_delta=early_stopping_delta,
519523
initialize_base_score=initialize_base_score,
520524
terminate_missing_features=set(),
521525
missing_node_treatment=missing_node_treatment,
@@ -556,6 +560,7 @@ def __init__(
556560
self.other_rate = other_rate
557561
self.evaluation_metric = evaluation_metric
558562
self.early_stopping_rounds = early_stopping_rounds
563+
self.early_stopping_delta = early_stopping_delta
559564
self.initialize_base_score = initialize_base_score
560565
self.terminate_missing_features = terminate_missing_features_
561566
self.missing_node_treatment = missing_node_treatment

py-forust/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ impl GradientBooster {
7878
grow_policy,
7979
evaluation_metric,
8080
early_stopping_rounds,
81+
early_stopping_delta,
8182
initialize_base_score,
8283
terminate_missing_features,
8384
missing_node_treatment,
@@ -111,6 +112,7 @@ impl GradientBooster {
111112
grow_policy: &str,
112113
evaluation_metric: Option<&str>,
113114
early_stopping_rounds: Option<usize>,
115+
early_stopping_delta: f64,
114116
initialize_base_score: bool,
115117
terminate_missing_features: HashSet<usize>,
116118
missing_node_treatment: &str,
@@ -157,6 +159,7 @@ impl GradientBooster {
157159
grow_policy_,
158160
evaluation_metric_,
159161
early_stopping_rounds,
162+
early_stopping_delta,
160163
initialize_base_score,
161164
terminate_missing_features,
162165
missing_node_treatment_,
@@ -420,6 +423,7 @@ impl GradientBooster {
420423
dict.set_item("grow_policy", grow_policy_)?;
421424
dict.set_item("evaluation_metric", evaluation_metric_)?;
422425
dict.set_item("early_stopping_rounds", self.booster.early_stopping_rounds)?;
426+
dict.set_item("early_stopping_delta", self.booster.early_stopping_delta)?;
423427
dict.set_item("initialize_base_score", self.booster.initialize_base_score)?;
424428
dict.set_item(
425429
"terminate_missing_features",

src/gradientbooster.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,11 @@ pub struct GradientBooster {
159159
/// to keep training.
160160
#[serde(default = "default_early_stopping_rounds")]
161161
pub early_stopping_rounds: Option<usize>,
162+
/// Minimum improvement in the evaluation metric required to count as
163+
/// an improvement for early stopping purposes. Defaults to 1e-7 to
164+
/// match XGBoost's behavior.
165+
#[serde(default = "default_early_stopping_delta")]
166+
pub early_stopping_delta: f64,
162167
/// If this is specified, the base_score will be calculated using the sample_weight and y data in accordance with the requested objective_type.
163168
#[serde(default = "default_initialize_base_score")]
164169
pub initialize_base_score: bool,
@@ -221,6 +226,9 @@ fn default_evaluation_metric() -> Option<Metric> {
221226
fn default_early_stopping_rounds() -> Option<usize> {
222227
None
223228
}
229+
fn default_early_stopping_delta() -> f64 {
230+
1e-7
231+
}
224232
fn default_evaluation_history() -> Option<RowMajorMatrix<f64>> {
225233
None
226234
}
@@ -283,6 +291,7 @@ impl Default for GradientBooster {
283291
GrowPolicy::DepthWise,
284292
None,
285293
None,
294+
1e-7,
286295
true,
287296
HashSet::new(),
288297
MissingNodeTreatment::AssignToParent,
@@ -334,6 +343,7 @@ impl GradientBooster {
334343
/// * `sample_method` - Specify the method by which records should be sampled when training.
335344
/// * `evaluation_metric` - Define the evaluation metric to record at each iteration.
336345
/// * `early_stopping_rounds` - Number of rounds that must pass without an improvement in the evaluation metric before training is cut short.
346+
/// * `early_stopping_delta` - Minimum improvement required to reset the early stopping counter.
337347
/// * `initialize_base_score` - If this is specified, the base_score will be calculated using the sample_weight and y data in accordance with the requested objective_type.
338348
/// * `missing_node_treatment` - specify how missing nodes should be handled during training.
339349
/// * `log_iterations` - Setting to a value (N) other than zero will result in information being logged every N iterations.
@@ -365,6 +375,7 @@ impl GradientBooster {
365375
grow_policy: GrowPolicy,
366376
evaluation_metric: Option<Metric>,
367377
early_stopping_rounds: Option<usize>,
378+
early_stopping_delta: f64,
368379
initialize_base_score: bool,
369380
terminate_missing_features: HashSet<usize>,
370381
missing_node_treatment: MissingNodeTreatment,
@@ -398,6 +409,7 @@ impl GradientBooster {
398409
grow_policy,
399410
evaluation_metric,
400411
early_stopping_rounds,
412+
early_stopping_delta,
401413
initialize_base_score,
402414
terminate_missing_features,
403415
evaluation_history: None,
@@ -662,7 +674,7 @@ impl GradientBooster {
662674
// Otherwise the best could be farther back.
663675
Some(v) => {
664676
// We have reached a new best value...
665-
if is_comparison_better(v, m, maximize) {
677+
if is_comparison_better(v, m, maximize, self.early_stopping_delta) {
666678
self.update_best_iteration(i);
667679
Some(m)
668680
} else {
@@ -1264,6 +1276,13 @@ impl GradientBooster {
12641276
self
12651277
}
12661278

1279+
/// Set the minimum improvement delta for early stopping.
1280+
/// * `early_stopping_delta` - Minimum improvement required.
1281+
pub fn set_early_stopping_delta(mut self, early_stopping_delta: f64) -> Self {
1282+
self.early_stopping_delta = early_stopping_delta;
1283+
self
1284+
}
1285+
12671286
/// Set prediction iterations.
12681287
/// * `prediction_iteration` - The prediction iteration to set.
12691288
pub fn set_prediction_iteration(mut self, prediction_iteration: Option<usize>) -> Self {

src/metric.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ pub type MetricFn = fn(&[f64], &[f64], &[f64]) -> f64;
99
/// Compare two metric values, determining if `comparison` is better than `value`.
1010
/// If one of them is NaN favor the non NaN value.
1111
/// If both are NaN, consider the first value to be better.
12-
pub fn is_comparison_better(value: f64, comparison: f64, maximize: bool) -> bool {
12+
pub fn is_comparison_better(value: f64, comparison: f64, maximize: bool, delta: f64) -> bool {
1313
match (value.is_nan(), comparison.is_nan()) {
1414
// Both nan, comparison is not better,
1515
// Or comparison is nan, also not better
@@ -19,13 +19,13 @@ pub fn is_comparison_better(value: f64, comparison: f64, maximize: bool) -> bool
1919
// Perform numerical comparison.
2020
(false, false) => {
2121
// If we are maximizing is the comparison
22-
// greater, than the current value
22+
// greater than the current value by more than delta
2323
if maximize {
24-
value < comparison
24+
comparison > value + delta
2525
// If we are minimizing is the comparison
26-
// less than the current value.
26+
// less than the current value by more than delta
2727
} else {
28-
value > comparison
28+
comparison < value - delta
2929
}
3030
}
3131
}

0 commit comments

Comments
 (0)