Self-Attentionを利用したテキスト分類

Page content

Self-Attentionを利用したテキスト分類

TL;DR

テキスト分類問題を対象に、LSTMのみの場合とSelf-Attentionを利用する場合で精度にどのような差がでるのかを比較しました。 結果、テキスト分類問題においても、Self-Attentionを利用することで、LSTMのみを利用するよりも高い精度を得られることが確認できました。

Self-Attentionの実装としてはkeras-self-attentionを利用しました。

ベンチマーク用データ

京都大学情報学研究科–NTTコミュニケーション科学基礎研究所 共同研究ユニットが提供するブログの記事に関するデータセットを利用しました。 このデータセットでは、ブログの記事に対して以下の4つの分類がされています。

  • グルメ
  • 携帯電話
  • 京都
  • スポーツ

LSTMのみ利用した場合のモデルと結果

text-vectorianを利用してベクトル表現に変換したテキスト入力し、LSTMにより学習、 分類先となる4ラベルの何れに該当するかを推論するモデルです。

モデル

クラシフィケーションレポート

F1値は0.76でした。

              precision    recall  f1-score   support

          京都       0.71      0.75      0.73       137
        携帯電話       0.80      0.81      0.80       145
        スポーツ       0.68      0.72      0.70        47
         グルメ       0.84      0.72      0.78        90

   micro avg       0.76      0.76      0.76       419
   macro avg       0.76      0.75      0.75       419
weighted avg       0.77      0.76      0.76       419

Self-Attentionを利用した場合のモデルと結果

LSTMのみを利用した場合のモデルSelf-Attentionの層を追加したものです。 Self-Attentionの出力は入力と同じ(sample, time, dim)の3階テンソルであるため、GlobalMaxPooling1DによりShapeを変換しています。

モデル

クラシフィケーションレポート

F1値が0.79でした。

              precision    recall  f1-score   support

          京都       0.71      0.84      0.77       137
        携帯電話       0.88      0.78      0.83       145
        スポーツ       0.66      0.74      0.70        47
         グルメ       0.88      0.74      0.81        90

   micro avg       0.79      0.79      0.79       419
   macro avg       0.78      0.78      0.78       419
weighted avg       0.80      0.79      0.79       419

総括

LSTMのみでは0.76であったF1値がSelf-Attentionを追加することで3%増加し0.79に向上しました。 それぞれ複数回試しても±1-2%程度の誤差範囲でしたので、Self-Attentionを追加することは有意であると考えられます。

AttentionNMT(Neural Machine Translation)のようなシーケンスからシーケンスを推論する問題(seq2seq)への適用が注目されますが、 単純な分類問題にも入出力の形を変えずにそのまま適用出来るため、とりあえず使って見るのは悪く無さそうです。

参考文献