Early stopping and timeout in Optuna

I have previously explained how to run Optuna in parallel. Here, I’d like to explain how to add a timeout and early stopping.

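The first thing to understand is the use of Manager from the multiprocessing module. A plain Python variable is copied into each worker process, so an update made in one process is invisible to the others. A Manager instead hosts the values in a separate server process and hands out proxies that every process can read and update. In the code below, counters for successful, pruned, and failed trials are created this way so that the parallel optimization can share them.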

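    import time
    from multiprocessing import Manager

    with Manager() as manager:
        self._start_time = time.time()
        # Shared integer counters ("i" typecode) that every process can update
        self._n_success = manager.Value("i", 0)
        self._n_pruned = manager.Value("i", 0)
        self._n_failure = manager.Value("i", 0)
        # ... run study.optimize(...) here, while the manager is still alive ...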
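Here is a standalone sketch that does not need Optuna. It shows several processes updating one manager.Value counter. Note that counter.value += 1 on a proxy is a separate read and write, so two processes can overwrite each other's update. A manager.Lock() makes each increment atomic. The function names are hypothetical, and the sketch assumes the "fork" start method, which is available on Linux and macOS.

```python
import multiprocessing as mp


def _increment(counter, lock, n_times):
    # "+= 1" on a Value proxy is a get followed by a set; the lock makes
    # the whole read-modify-write atomic across processes.
    for _ in range(n_times):
        with lock:
            counter.value += 1


def count_in_processes(n_workers: int, n_times: int) -> int:
    # Assumption: the "fork" start method (Linux/macOS).
    ctx = mp.get_context("fork")
    with ctx.Manager() as manager:
        counter = manager.Value("i", 0)  # shared int, like _n_success
        lock = manager.Lock()
        workers = [
            ctx.Process(target=_increment, args=(counter, lock, n_times))
            for _ in range(n_workers)
        ]
        for w in workers:
            w.start()
        for w in workers:
            w.join()
        # Read the result before the manager (and its proxies) shuts down.
        return counter.value
```

With the lock, count_in_processes(4, 50) always returns 200. Without it, the total can come out lower.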

These counters are actually used in a callback, which you register when calling optimize:

    study.optimize(..., callbacks=[YourMonitoringClass(n_success=self._n_success, ...)])
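Optuna documents two behaviours this approach relies on. First, each callback is called with (study, frozen_trial) after every trial. Second, study.stop() ends optimize once the running trials finish. The toy loop below only mimics that contract, so the control flow can be run without Optuna. ToyStudy, ToyTrial and stop_after_three are hypothetical names and are not Optuna classes or Optuna's implementation.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class ToyTrial:
    number: int
    value: float


@dataclass
class ToyStudy:
    # Hypothetical stand-in for optuna.Study, NOT Optuna's implementation.
    trials: List[ToyTrial] = field(default_factory=list)
    _stop_requested: bool = field(default=False, init=False)

    def stop(self) -> None:
        # Mimics Study.stop(): the current trial finishes, then the loop exits.
        self._stop_requested = True

    def optimize(
        self,
        objective: Callable[[int], float],
        n_trials: int,
        callbacks: List[Callable[["ToyStudy", ToyTrial], None]],
    ) -> None:
        self._stop_requested = False
        for number in range(n_trials):
            trial = ToyTrial(number, objective(number))
            self.trials.append(trial)
            # Each callback sees (study, finished trial), in list order.
            for callback in callbacks:
                callback(self, trial)
            if self._stop_requested:
                break


def stop_after_three(study: ToyStudy, trial: ToyTrial) -> None:
    # Example callback: request a stop once three trials have finished.
    if len(study.trials) >= 3:
        study.stop()
```

Calling ToyStudy().optimize(lambda n: n * 0.1, n_trials=10, callbacks=[stop_after_three]) runs exactly three trials. The callback only requests the stop, and the loop honours it after the current trial.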

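The monitoring class now receives those shared counters. Optuna calls each callback with the study and the just-finished trial (a FrozenTrial) after every trial. So in the class’s __call__, we update the counters, which you can also log, and then call the helper methods below to decide whether to time out or stop early.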

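    import logging
    import time

    import numpy as np
    import optuna

    class YourMonitoringClass:
        # __init__ (not shown) stores the shared counters, the start time,
        # the timeout in seconds and an optional early-stopping config.

        def __call__(self, study: optuna.Study, trial: optuna.trial.FrozenTrial):
            # Optuna invokes each callback once per finished trial.
            # Note: "+= 1" on a manager proxy is a read then a write, not atomic.
            match trial.state:
                case optuna.trial.TrialState.FAIL:
                    self._n_failure.value += 1
                    status_to_log = Status.FAILED  # Status is your own enum
                    log_level = logging.ERROR
                case optuna.trial.TrialState.PRUNED:
                    self._n_pruned.value += 1
                    status_to_log = Status.PRUNED
                    log_level = logging.INFO
                case optuna.trial.TrialState.COMPLETE:
                    self._n_success.value += 1
                    status_to_log = Status.FINISHED
                    log_level = logging.INFO
                case _:
                    return  # no other state is expected for a finished trial
            n_finished = (
                self._n_success.value + self._n_pruned.value + self._n_failure.value
            )
            # ... log progress (status_to_log, log_level, n_finished) ...
            self._may_timeout(study)
            self._may_early_stop(study)

        def _may_timeout(self, study: optuna.Study):
            # Wall-clock seconds since the shared start time.
            if time.time() - self._start_time > self._timeout:
                # ... log that this is a timeout ...
                study.stop()  # running trials finish, then optimize() returns

        def _may_early_stop(self, study: optuna.Study):
            if self._early_stopping_config is None:
                return
            # Objective values of completed trials, in trial order.
            values = [t.value for t in study.trials if t.value is not None]
            patience = self._early_stopping_config.patience
            if len(values) < patience:
                return
            # Stop when the last `patience` values have (almost) stopped changing.
            recent_std = np.std(values[-patience:])
            if recent_std < self._early_stopping_config.threshold:
                # ... log that this is an early stop ...
                study.stop()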
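To check the stopping rules without running a study, here are the same two tests as pure functions. This is a sketch that assumes a single-objective study. It uses statistics.pstdev, which matches the default population standard deviation of np.std (ddof=0). The function names and sample numbers are illustrative only. For example, with patience 3 and threshold 0.01, the values 0.9, 0.5, 0.31, 0.30, 0.30 trigger a stop, because the last three have a standard deviation of about 0.0047. With patience 4 they do not.

```python
import statistics
from typing import Sequence


def should_early_stop(values: Sequence[float], patience: int, threshold: float) -> bool:
    # Not enough finished trials yet to judge a plateau (assumes patience >= 1).
    if len(values) < patience:
        return False
    # Population std (ddof=0), the same as np.std's default.
    recent_std = statistics.pstdev(values[-patience:])
    return recent_std < threshold


def should_time_out(start_time: float, now: float, timeout: float) -> bool:
    # Same comparison as _may_timeout, with "now" passed in for testability.
    return now - start_time > timeout
```

This rule looks at raw trial values, not at the best value so far. With an exploratory sampler, the values can stay noisy even after the best value has stopped improving, so the threshold may need tuning.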

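Finally, keep in mind that early stopping must be based on validation data. The values the callback compares are whatever your objective returns, so the objective should return a score computed on held-out validation data, not on the training set.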
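Because the callback only sees trial.value, the validation split has to live inside the objective. Below is a minimal sketch with hypothetical synthetic data and a hypothetical hyperparameter k for a k-nearest-neighbour average: fit on the training split, then score on the held-out split. In a real study, k would come from trial.suggest_int.

```python
import random


def make_data(n: int, seed: int = 0):
    # Hypothetical synthetic data: y = 2x + Gaussian noise.
    rng = random.Random(seed)
    xs = [rng.uniform(0, 10) for _ in range(n)]
    return [(x, 2 * x + rng.gauss(0, 1)) for x in xs]


def knn_predict(train, x, k):
    # Average the targets of the k training points closest to x.
    nearest = sorted(train, key=lambda p: abs(p[0] - x))[:k]
    return sum(y for _, y in nearest) / len(nearest)


def objective(k: int, data) -> float:
    # Hold out the last 20% as validation data; never score on `train`.
    split = int(len(data) * 0.8)
    train, valid = data[:split], data[split:]
    errors = [(knn_predict(train, x, k) - y) ** 2 for x, y in valid]
    return sum(errors) / len(errors)  # validation MSE, i.e. trial.value
```

Scoring k = 1 on the training split would give an error of exactly 0, because each point is its own nearest neighbour. An early-stopping rule fed such values would be judging memorization, not generalization.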