元論文:https://openreview.net/pdf?id=YicbFdNTTy
最近話題のCNNを使わない画像認識モデルです。
ぶっちゃけた結論
ほぼBERTだと思いました。
そのため、もし今まで画像系に取り組まれてきた方でViTをきちんと理解したいのによくわからない!という方は「Transformer(Self-Attention)→BERT→ViT」と進むのが近道だと思います。
以下はもう少し細かい説明です。随所にBERTっぽい文言が登場すると思います。
ViTの開発コンセプト
Transformerをコンピュータビジョンに適用する動きこれまでにもちょくちょくありました。しかしあまり精度が出なかったり、CNNが残っていたり、元々のTransformerを改造しすぎていて一般的に利用するには敷居が高いものばかりでした。
ViTは「一般的なTransformerを出来る限りそのままの形で画像に適用し、かつ高い精度を達成する」ことをコンセプトに発明されました。実際にCNNを完全に排除した状態で高い精度を達成することに成功しています。
ViTの新規点
- 大規模データセットで学習させた事前学習モデルを、小規模データセットを用いた下流タスクでfine tuningするような使い方をする
- 入力を一般的なTransformerと揃えるために工夫した。
- 事前学習時間が既存のSoTAモデルの1/4程度。たった(?)TPU v3で3日で済む
入力画像を一般的なTransformerと同じように「埋め込みベクトル×トークン数」という形に変形できてしまえば、あとは一般的なTransformerの枠組みが使えてしまうという理屈です。
今回は分類タスクを解きたいので、TransformerのEncoder層の頭にくっつけた[CLS]トークン部分にClassification層を繋げたモデルを作り、教師あり学習を行っています。
ちなみにViTをImageNetだけで学習させると、ResNetに数%負けます。Transformerの事前学習は経験則的に「データは大きければ大きいほどいい」という風潮がありますが、その経験から考えると、ImageNetではデータの規模が足りていなかったと考えられます。実際に大規模データセット(Googleが独自に持つ300Mの画像データセット)で事前訓練し、ImageNetでファインチューニングしたらSoTAに近い結果を得られました。ImageNetで88.36%、ImageNet-ReaLで90.77%、CIFAR-100で94.55%です。
ViTの入力について
どうやって画像を「埋め込みベクトル×トークン数」にするかという話です。
やり方は図1の通りです。画像をパッチに分割し、パッチをそれぞれ線形変換して埋め込んだものをTransformerの入力とします。各パッチは自然言語処理で言うトークン(単語)と同じ働きをします。なので「埋め込みベクトル×パッチ数」という表現になります。
元の画像のサイズを(縦×横×チャンネル数)で(H×W×C)とします。これを(P²×C)のサイズを持ったN個のパッチに分割することを考えます。Pは1パッチあたりの長さで、各パッチは縦横が等しい正方形である必要があります。上記からHWC=NP²Cなので、N=HW/P²と求められます。このNの大きさがNLPでのトークンの長さに対応するものとなります。例えば元の画像の縦横が224x224で、1パッチあたりの長さが16ならならN=196という具合です。一般的なTransformerだと入力最大長=512であるため、今回もN<512になるようにしましょう。そしてP²×Cである各パッチを線形変換して埋め込んだものをパッチEmbeddingとして入力に使用します。
あとは 位置埋め込みです。Transformerが進化するにつれ、
- 三角関数を利用(Transformer)
- 相対位置を線形変換で埋め込み(BERT)
- Attentionの重みを算出するSoftmaxの手前に足し合わせる(T5)
など変化していますが、今回は相対位置を線形変換で埋め込んだものをパッチEmbeddingに足す形で使います。BERTと同じ操作です。縦横位置を考慮した埋め込みや、Attention内への埋め込みも試しましたが、精度にはほとんど影響がありませんでした。
そしてモデルの先頭には[CLS]パッチを付けます。これは画像の入力長を問わず各パッチの影響を等しく受けるため、学習後の分類タスクに利用することができます。
ここまでをまとめると、
- パッチ埋め込み+位置埋め込みが入力で、
- TransformerのEncoderを利用していて、
- 事前学習させたモデルのCLSパッチに、タスクに合わせた全結合層をくっつけてfine tuneさせる。
となります。ほぼBERTなのが伝わったかと思います。
その他工夫点
- 画像の解像度が違う場合、パッチサイズはそのままに入力長を大きくしする。
→fine tune時には解像度の大きい画像を使ったほうがいいと先行研究でわかっています。
→解像度をあげたことで足りなくなった位置Embed部分については外挿して補完します。 - 一般的なTransformerではAttentionの計算をしてから正規化していますが、今回は正規化してからAttentionの計算をします。
- 各パッチを単純に埋め込むのではなく、 各パッチに対してResNetを適用してからFlattenしたものを埋め込みベクトルとするハイブリッドな手法もアリです。
結果
これまでのSoTAとの比較は図2の通りです。ViTはAttentionを32層積み重ねたHugeサイズのものとなっています。Natural系でこそBiTにやや負けているものの、これは乱数の範囲内です。つまりほぼ全てのタスクでViTは高い性能を発揮します。
感想
今まで使ってたCNNがいらないと言われると、びっくりするような寂しいような気持ちです。Attentionはパッチ位置同士の関係を掴むのが得意なため、遠くないうちにどこに何が写っているかについてもSoTAを達成するはずです。
事前学習が手の届かない世界に行ってしまったことから、これまでよりさらに原理をきちんと理解しないと「理屈は分からないけどすごく高い精度は出る」という沼にハマりやすくなってしまったかと思います。これからもキチンとキャッチアップしていきたいと思います。