Distributed Tensorflowに入門してみました。バージョンはr1.4です。 実際のコードを描いた上で基本的に使うであろうAPIについて書いていきたいと思います。

Distributed Tensorflow

分散学習をざっくり説明すると、各Workerに学習データを与え、それぞれで目的関数の勾配を求めます。それらの値を用いて共通のモデルパラメータを更新する感じ(Data Parallelize)です。

Example

以下のサンプルコードを上から辿るように説明していきます。モデルとしては2値分類のロジスティック回帰をasynchronous(Downpour)のAdaGradで最適化させています。

def main(args):
    _X, _y = make_classification(n_samples=10000, n_features=20, n_informative=5,
                                 n_redundant=2, n_classes=2, n_clusters_per_class=10,
                                 random_state=1)
    _y = np.reshape(_y, [_y.shape[0], 1])
    X_train, X_test, y_train, y_test = train_test_split(_X, _y, test_size=0.2)

    # Attempting to connect all nodes in `tf.train.ClusterSpec`.
    cluster_spec = tf.train.ClusterSpec({
        'ps': ['ps1:2222'],
        'worker': [
            'worker1:2222',
            'worker2:2222',
            'worker3:2222'
        ]
    })

    server = tf.train.Server(cluster_spec,
                             job_name=args.job_name,
                             task_index=args.task_index)

    if args.job_name == "ps":
        # `server.join()` means it's NEVER killed
        server.join()
    else:
        # Store the variables into `ps`.
        with tf.device(tf.train.replica_device_setter(
                cluster=cluster_spec)):

            global_step = tf.train.create_global_step()
            get_global_step = tf.train.get_global_step()

            # Hyper-parameters
            learning_rate = 0.01
            epoch = 1000

            global_step = tf.train.create_global_step()
            get_global_step = tf.train.get_global_step()

            # Binary Logistic Regression
            x = tf.placeholder(tf.float32, [None, X_train.shape[1]])
            y = tf.placeholder(tf.float32, [None, 1])

            w = tf.Variable(tf.zeros([X_train.shape[1], 1]), name='w')
            b = tf.Variable(tf.zeros([1]), name='b')

            logits = tf.matmul(x, w) + b
            pred = tf.sigmoid(logits)

            l = tf.constant(0.001)
            l2_norm = l * tf.nn.l2_loss(w)

            # cross_entropy = -1. * y * tf.log(pred) - (1. - y) * tf.log(1. - pred)
            loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=y)) + l2_norm
            optimizer = tf.train.AdagradOptimizer(learning_rate)
            train_op = optimizer.minimize(loss, global_step=global_step)

            is_chief = (args.task_index == 0)
            
            hooks = [tf.train.StopAtStepHook(last_step=epoch),
                     tf.train.CheckpointSaverHook('./example-save',
                                                  save_steps=epoch,
                                                  saver=tf.train.Saver(max_to_keep=1))]

            # Initialize the variables, if `is_chief`.
            with tf.train.MonitoredTrainingSession(
                    is_chief=is_chief,
                    master=server.target,
                    hooks=hooks) as sess:

                while not sess.should_stop():
                    step, _, train_loss, train_accuracy = sess.run([get_global_step, train_op, loss, accuracy],
                                                                   feed_dict={x: X_train, y: y_train})
                    print('In {step} step: loss = {loss}, accuracy = {accuracy}'
                          .format(step=step, loss=train_loss, accuracy=train_accuracy))

        # It's able to fetch variables in another session.
        if is_chief:
            with tf.Session(server.target) as sess:
                w, b, acc = sess.run([w, b, accuracy],
                                     feed_dict={x: X_test, y: y_test})
                print('accuracy in validation = ', acc)
                print('bias = ', b)
                print('weight = ', w)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--job-name', dest='job_name', type=str,
                        choices=['ps', 'worker'])
    parser.add_argument('--task-index', dest='task_index', type=int,
                        default=0)
    main(parser.parse_args())

Clusterのセットアップ

cluster_spec = tf.train.ClusterSpec({
    'ps': ['ps1:2222'],
    'worker': [
        'worker1:2222',
        'worker2:2222',
        'worker3:2222'
    ]
})

server = tf.train.Server(cluster_spec,
                         job_name=args.job_name,
                         task_index=args.task_index)

ClusterSpecでclusterの設定を行います。各ノードの役割としては以下の2通りがあります。

  • ps: Parameter Server。Variableの管理とWorkerで求めた勾配を用いてモデルパラメータを更新する役割を持ちます。
  • worker: Workerはlossにおける勾配の計算を行います。

Serverで自身の役割(job_name = ps or worker)とtask_indexを与えることでセットアップは完了します。 task_indexps, workerについてそれぞれ0からインクリメントしたユニークな値をノードに与えます。

つまりは、ノードを立ち上げる際にはクラスター全体の構成要素(ホストとポート)と自身がどれを司るかを与えてやる必要があります。

Tips

Cluster間のやりとりには基本的にgRPC (protobuf)が用いられていますが、Serverの引数としてprotocolなるものが存在しており(デフォルトはもちろんgrpc)、grpc+mpi, grpc+gdr, grpc+verbsが選択できるようです。

また、ここに記載したノードに対しては一定間隔でHealthcheckを行なっており、それぞれがきちんと疎通できていないと以下のOperationは一切動作しません。

疎通を保証したいホストを制限したい場合はtf.ConfigProtodevice_filtersを与える必要があります。このconfigはSessionを立ち上げる際の引数で与えます。

config = tf.ConfigProto(
        # Ignore the nodes which do not match filters, on each worker node.
        # Raise `InvalidArgumentError`, if not set `ps`.
        device_filters=["/job:ps", "/job:worker/task:%d" % args.task_index]
)

Parameter Server

if args.job_name == "ps":
        server.join()

Parameter Serverはjoinしておけばokですが、学習が終わっても自動的に終了しません。Issueにもあげられています。

Worker

VariablesをParameter Serverへ配置

with tf.device(tf.train.replica_device_setter(cluster=cluster_spec)):

Distributed Tensorflowにはdeviceという概念が存在します。tf.device('/job:worker/replica:0/task:0/cpu:0')のように指定することで、各Worker, Parameter Serverに対してどういったOperationを行わせるのかを記述できます。ただ、今回は1つのWorkerで特有なOperationは出てこない(どのWorkerも勾配を計算してParameter Serverに渡す)ので、深掘りはしません。

では、ここでは何をしているのかというと、このwith構文以下のVariableがParameter Serverに配置されるようにしています。配置の仕方はデフォルトでラウンドロビンに配置されます。配置の方法(ps_strategy)はカスタマイズできます。もちろん初期化はされません。

global_step

global_step = tf.train.create_global_step()
get_global_step = tf.train.get_global_step()

loss =...
optimizer = tf.train.AdagradOptimizer(learning_rate)
train_op = optimizer.minimize(loss, global_step=global_step)

asynchronousとは言えど、epochは同期的に共通のglobal_stepで管理されています。Worker毎にoptimizerminimizeをcallする毎にインクリメントされます。各Workerで勾配が求まった時点で、モデルパラメータが非同期に更新されていきます。

ちなみに、全て(一部)のWorkerの計算結果について同期(勾配の平均をとってからモデルパラメータの更新)したい場合はSyncReplicasOptimizerをさらに利用する必要があります。

Sessionの立ち上げ

is_chief = (args.task_index == 0)

hooks = [tf.train.StopAtStepHook(last_step=epoch),
         tf.train.CheckpointSaverHook('./example-save',
                                      save_steps=epoch,
                                      saver=tf.train.Saver(max_to_keep=1))]

# Initialize the variables, if `is_chief`.
with tf.train.MonitoredTrainingSession(
        is_chief=is_chief,
        master=server.target,
        hooks=hooks) as sess:

Distributed Tensorflowも同様にSessionを立ち上げます。SupervisorはそろそろdeprecatedになるようでMonitoredSessionを利用した記述にしました。

MonitoredTrainingSessionでSessionを立ち上げています。このメソッドを利用してSessionを立ち上げた場合は、学習前の下準備を内部で行います。

is_chiefが:

  • Trueの場合はreplica_device_setterで配置したVariableとglobal_stepの初期化を行います。
  • Falseの場合はchiefが初期化作業を終えるまで規定時間(ハードコーディングされている)待ったあと、疎通の確認を行います。

内部実装的にはChiefSessionCreatorWorkerSessionCreatorが生成され、それぞれのcreate_sessionで上記の役割を満たすように記述されているようです。また、ノード間の疎通の管理はどちらもSessionManagerが担っているようです。

hooksはepoch毎に追加で行いたいOperationやepochそのものの値を設定できます。自作のHookも作成できます。ベースにはSessionRunHookというクラスがあるので、満たしたい箇所のみ書いておけば適応されます。

今回利用したHookは以下です:

  • StopAtStepHook: last_step分だけ学習を回します。何回学習したかの記録は先述の通りglobal_stepの値を利用します。
  • CheckpointSaverHook: save_step毎に第一引数のディレクトリにモデルパラメータを保存します。

学習

while not sess.should_stop():
    step, _, train_loss, train_accuracy = sess.run([get_global_step, train_op, loss, accuracy],
                                                   feed_dict={x: X_train, y: y_train})
    print('In {step} step: loss = {loss}, accuracy = {accuracy}'
          .format(step=step, loss=train_loss, accuracy=train_accuracy))

should_stopはSessionが終わる状態になるまで、loopを回すことを意味します。この制御は基本的にHookで行います。今回はStopAtStepHookでそれを行なっています。

後始末

if is_chief:
    with tf.Session(server.target) as sess:
        w, b, acc = sess.run([w, b, accuracy],
                             feed_dict={x: X_test, y: y_test})
        print('accuracy in validation = ', acc)
        print('bias = ', b)
        print('weight = ', w)

should_stopを抜けたあとはそのSession内でいかなるOperationを受け付けないので、事後処理的なことをしたい場合は、別のSessionを立ち上げた上で実行します。今回はchiefの場合に学習後のモデルパラメータを取得するように書きました。

まとめ

Distributed Tensorflowむずい。

References

分散学習

Distributed Tensorflowで参考になったStackOverFlow