Optuna with Python multiprocessing

Previously, I posted about how to parallelize Optuna. Back then, I didn’t realize that it can cause hard-to-debug race conditions, such as concurrent database access. My method has changed since then, so I’m posting it again.

The basic structure of the approach is a loop that gets parameters, submits tasks, and collects results in batches. Note that it isn’t Optuna’s optimize() that gets parallelized; the training runs themselves are.

  • Use study.ask() to get parameters to test
  • Submit tasks to an executor that runs your own training algorithm
  • Collect the results and report them back with study.tell()
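Here’s a minimal sketch of that loop. This isn’t my exact production code; `objective`, `optimize_in_batches`, the batch size, and the single float parameter `x` are all placeholders:

```python
from concurrent.futures import ProcessPoolExecutor

def objective(params):
    """Stand-in for a real training run (hypothetical)."""
    return (params["x"] - 2.0) ** 2

def optimize_in_batches(study, n_trials, batch_size, executor):
    """The three bullet points above as code: ask, submit, tell."""
    done = 0
    while done < n_trials:
        batch = []
        for _ in range(min(batch_size, n_trials - done)):
            trial = study.ask()                                  # 1. get parameters
            params = {"x": trial.suggest_float("x", -10, 10)}
            batch.append((executor.submit(objective, params), trial))
        for future, trial in batch:
            study.tell(trial, future.result())                   # 3. report result
            done += 1

# Typical driver (sketch):
# study = optuna.create_study()
# with ProcessPoolExecutor() as executor:
#     optimize_in_batches(study, n_trials=40, batch_size=8, executor=executor)
```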

Looks straightforward, right? It does. But you’ll soon run into complications.

The suggest_xxx functions may propose parameters that were already proposed earlier. This can be handled as explained in the Optuna FAQ. The first step is to find the duplicate trial.

from optuna.trial import TrialState

def find_duplicate_trial(study, trial):
    """Return an earlier trial with identical params, or None."""
    trials_to_consider = study.get_trials(
        deepcopy=False,
        states=(TrialState.COMPLETE, TrialState.PRUNED, TrialState.FAIL),
    )
    for prev_trial in reversed(trials_to_consider):
        if trial.params == prev_trial.params:
            return prev_trial
    return None

If a duplicate is found, report its previous state. This must happen between ask() and submitting the task to the executor; a duplicate trial should never be submitted at all.

if dupe_trial.state == TrialState.COMPLETE:
    study.tell(trial, values=dupe_trial.values)
else:
    study.tell(trial, state=dupe_trial.state)

Another issue is that the suggest_xxx functions may propose parameters that were already submitted to the executor but haven’t finished yet, so the duplicate check above won’t detect them. This means you should not blindly submit algorithm + params to the executor. Instead, find the already-running future, store the trial alongside it, and call study.tell() for it when that future completes, rather than running the same params multiple times.
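One way to do this is to key in-flight futures by their frozen params. This is just a sketch; `freeze`, `submit_or_attach`, and `collect_finished` are my own helper names, not Optuna APIs:

```python
def freeze(params):
    """Hashable key for a params dict (helper, not an Optuna API)."""
    return tuple(sorted(params.items()))

def submit_or_attach(running, executor, objective, trial, params):
    """Submit each unique params set once; later trials with the same params just wait."""
    key = freeze(params)
    if key in running:
        running[key][1].append(trial)      # already in flight: queue for tell()
    else:
        running[key] = (executor.submit(objective, params), [trial])

def collect_finished(study, running):
    """Call tell() for every trial waiting on a future that has completed."""
    for key, (future, trials) in list(running.items()):
        if future.done():
            value = future.result()
            for trial in trials:
                study.tell(trial, value)   # every waiting trial gets the result
            del running[key]
```

The `running` dict lives in the optimization loop; call `collect_finished` each iteration to drain completed work.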

Finally, when using a timeout, I do not forcefully kill worker processes, as that causes one problem after another. Instead, use a ValueProxy as a shared latch and pass it to the training algorithm.

from multiprocessing import Manager

with Manager() as manager:
    timeout_latch = manager.Value("b", False)

When a timeout is detected in your optimization loop, set the latch to True. The training algorithm’s inner loop should check that latch and raise Timeout().
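Here’s a minimal sketch of the cooperative timeout. `Timeout` is a custom exception (name it whatever you like), and the training loop is a stand-in for a real epoch loop:

```python
from multiprocessing import Manager

class Timeout(Exception):
    """Raised cooperatively when the shared latch flips to True."""

def train(params, timeout_latch):
    loss = float("inf")
    for epoch in range(100):
        if timeout_latch.value:        # check the shared latch every iteration
            raise Timeout()
        loss = 1.0 / (epoch + 1)       # stand-in for one epoch of real work
    return loss

# In the optimization loop, on timeout:
# with Manager() as manager:
#     timeout_latch = manager.Value("b", False)
#     future = executor.submit(train, params, timeout_latch)
#     ...
#     timeout_latch.value = True       # workers notice on their next check
```

The worker exits cleanly via the exception instead of being killed, so no half-written state or orphaned resources are left behind.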