Optunaの各trial毎に決まった処理を実行させる方法

はじめに

最近、Optunaというハイパラ調整ライブラリを使っています。ハイパラ調整はある種職人的技量が求められるため、この部分で困っている人にはおすすめです。このOptunaですが、デフォルトの場合、ハイパラ探索の結果を見れるのは全計算が終わった後になります。全体の計算が5分程度なら良いのですが、何時間もかかる場合はモヤモヤしながら待つことになります。そこで、各trial毎(すなわち一つの探索毎)に決まった処理を実行させる方法を紹介します。

使用したバージョン

  • python: 3.8.11
  • optuna: 3.2.0

コード

通常の使い方

まずは、通常の使い方を紹介します。以下のコードは、ハイパラ探索の結果を見るために、全てのtrialが終わるまで待つコードです。

import optuna
# 100通りのハイパラを探索する
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=100)
# ハイパラ探索の最良結果を表示
print(study.best_params)
print(study.best_value)
print(study.best_trial)

各trial毎に処理を実行させる使い方

各trial毎に処理を実行させるには、callbackを使います。以下のコードは、各trialが終わるたびに、そのtrialのハイパラ探索の結果を表示するコードです。

import optuna

class MyCallback():
    def __init__(self):
        super().__init__()
    def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial):
        # trial毎の処理を記述
        print(f'current_trial: {trial}')
        print(f'current_params: {trial.params}')
        print(f'current_value: {trial.value}')
        print('------------------')

# callbackを設定
my_callback=MyCallback()
# 100通りのハイパラを探索する
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=100, callbacks=[my_callback])

まとめ

Optunaの各trial毎に決まった処理を実行させる方法を紹介しました。今回はprint出力を例にしましたが、例えば、各trial毎にハイパラ探索の結果をファイルに保存したり、tensorboadにリアルタイムで表示させたりすることもできます。 修正点やご意見などありましたら、コメントお願いします。