2021-12-04

Clojureで行列計算を実装したときに考えたトリック

これはClojureのアドベントカレンダー2021、4日目の記事です。


趣味でディープニューラルネットワークのフレームワークをClojureでゼロから実装しています。

「ゼロから」とは行列計算から。

deltam/neuro: Deep Neural Network written in Clojure from scratch

そのときに行列計算についてあるトリックを思いついて実装したら早くなったのでそれを紹介してみます。関数的プログラミングっぽい発想で、ほかにも応用が効くかもしれない。Clojureのデータ構造の勘所もすこし分かるかも。

注意:完全なる趣味の行いなので単に行列計算したいだけならそれ用のライブラリを使うほうが100倍かしこい。
実用的な知識というよりちょっとしたトリックを面白がってくれればありがたい。

ニューラルネットの学習では行列の転置をよく使うんですが、プロファイルを取ってみたらそこが遅かったので何とかしたいといろいろやって思いついたトリックです。
思ったより応用が効きました。

最初の実装

行列の表現

$$ \begin{pmatrix} 1 & 2 & 3 \\\\ 4 & 5 & 6 \end{pmatrix} $$

上の行列を表現するには行・列のサイズ(shape)と中身の値を持っておく必要があります。

行列は素朴に考えると二次元配列だけど、入れ子の配列を操作するのは面倒なので一次元にしてます。

(defn mat [col row v]
  {:shape [col row]
   :vec (vec v)})

(def m1 (mat 2 3 [1 2 3 4 5 6]))

つぎに行列の要素を取り出す関数を書きたい。そのためには行・列から:vecのインデックスに変換する関数が必要。
両方をあわせてこのように書いておく。

(defn cr->i [v c r]
   (let [[_ row] (:shape v)]
      (+ (* c row) r)))

(defn mget [m c r]
  (let [i (cr->i m c r)]
    (nth (:vec m) i)))

コードを書いてると行と列がごっちゃになって混乱するので表示する関数も書いておくと便利です。

(defn show [m]
  (let [[col row] (:shape m)]
    (doseq [c (range col), r (range row)]
      (printf " %d " (mget m c r))
      (if (= r (dec row))
        (println)))))

;mat1=> (show m1)
; 1  2  3
; 4  5  6

転置とベンチマーク

さてこれで行列の表現と値を取り出す操作が用意できた。これを使って基本的な行列演算である転置を書いてみる。

転置はある行列を対角線に沿って反転させる操作です。

$$ \begin{eqnarray} A = \left( \begin{array}{ccc} 1 & 2 & 3 \\\\ 4 & 5 & 6 \end{array} \right) \end{eqnarray} $$ $$ \begin{eqnarray} A^{ \mathrm{ T } } = \left( \begin{array}{ccc} 1 & 4 \\\\ 2 & 5 \\\\ 3 & 6 \\\\ \end{array} \right) \end{eqnarray} $$

(defn transpose [m]
  (let [[col row] (:shape m)
        v (for [r (range row), c (range col)]
            (mget m c r))]
    (mat row col v)))

;mat1=> (def m1 (mat 2 3 [1 2 3 4 5 6]))
;#'user/m1
;mat1=> (show m1)
; 1  2  3
; 4  5  6
;nil
;mat1=> (show (mat1/transpose m1))
; 1  4
; 2  5
; 3  6

素朴なベンチマークとして指定回数だけ転置を繰り返す関数を作った。

(defn bench [lim]
  (let [m (mat 100 100 (range 10000))]
    (dotimes [n lim]
      (do (transpose m)))))

;mat1=> (time (bench 10000))
;"Elapsed time: 26856.948132 msecs"

10000回でも結構時間がかかっている。

トリック1: 転置フラグ

よく考えると転置は行・列と実際の配列インデックスの紐付けを変更するだけの処理です((i,j) -> (j,i))。「インデックスを計算するとき転置させたか判断してうまく変換してやれば:vecを書き換えなくてもOKなのでは?」と思いつきました。

そこで転置済みかのフラグ:transposed?を追加してみた。

(defn mat [col row v]
  {:shape [col row]
   :vec (vec v)
   :transposed? false})

(defn cr->i [v c r]
   (let [[col row] (:shape v)]
      (if (:transposed? m)
        (+ c (* r col))
        (+ (* c row) r))))

転置は以下のように:shape:transposed?を反転させるだけ。

(defn transpose [m]
  (-> m
      (update :shape reverse)
      (update :transposed? not)))

;mat2=> (def m1 (mat 2 3 [1 2 3 4 5 6]))
;#'user/m1
;mat2=> m1
;{:shape [2 3], :vec [1 2 3 4 5 6], :transposed? false}
;mat2=> (transpose m1)
;{:shape (3 2), :vec [1 2 3 4 5 6], :transposed? true}

;mat2=> (show m1)
; 1  2  3
; 4  5  6
;nil
;mat2=> (show (mat2/transpose m1))
; 1  4
; 2  5
; 3  6
;nil

:vecの内容は変わらないけど、転置行列として扱うことができます。

最初と同じbenchを使って簡易ベンチマークを取ってみた。

mat2=> (time (bench 10000))
"Elapsed time: 7.434236 msecs"

やった! クッソ早い!

ClojureのPersistentMap

なんでこんなに早くなったのか。それにはClojureのMapの実装が関係しています。

Clojureでは値は不変(Immutable)なので、Mapを書き換える関数は新たなMapを生成して返します。しかし本当に全部コピーして新たなMapを作っていると非効率なので変更されていないKey-Valueは元の値と共有されています。そのようなデータ構造を永続データ構造(Persistent Data Structure)1と言うらしい。

最初の実装では:vecを次々書き換えていくため、共有されるValueが少なく新たに作られるMapの負荷が大きかった。二番目の:transposed?を使う実装だと変更後との差分が少なくPersistentMapの特性とマッチしていて効率的にMap作成が行うことができた(:vecを新たに作る処理も省略できた)。直感に反するけども、サイズの大きいValueについては編集せずそのままコピーしたほうが効率的なようです。

Clojureの永続データ構造の振る舞いについては以下の解説記事が参考になります。
PersistentVectorについての解説だけどPersistentMapの理解にも役立つ。

Clojure設計者のRich Hickey自身が解説している講演のTranscriptも参考になる。

talk-transcripts/AreWeThereYet.md at master · matthiasn/talk-transcripts

トリック2: 関数リテラルをもたせる

DNNフレームワークでもう一つ重たい処理がありました。学習データを一つの巨大な行列として作って、それを行ごとに切り分けてミニバッチという単位にして学習に使うのだけど、その切り分けに時間がかかっていました。

上記のトリックと永続データ構造の特性を考えていて新たなトリックを閃いた。

「データ構造にインデックス変換の関数それ自体を持たせれば良いのでは?」

転置行列用のフラグを用意するのではなく、インデックス変換関数それ自体を持たせて転置のときはそれをWrapすればいい。切り分けるのもそれで対応できそう。

つまり行列の表現を次のように書き換えます。

(defn mat [col row v]
  {:shape [col row]
   :vec (vec v)
   :posf (fn [c r] (+ (* c row) r))})

(defn cr->i [m c r]
  ((:posf m) c r))

:posfがインデックス変換の関数。これを値として持たせてしまう。
転置はこのように書ける。

(defn transpose [m]
  (let [f (:posf m)]
    (-> m
        (update :shape reverse)
        (assoc :posf (fn [c r] (f r c)))))) ; 引数の順番を入れ替えるだけ

さらにこの方法だと行ごとに切り分ける関数も同じトリックでいける。

(defn slice [m start end]
  (let [[col row] (:shape m)
        size (- end start)
        f (:posf m)]
    (-> m
        (assoc :shape [size row])
        (assoc :posf (fn [c r]
                       (f (+ c start) r))))))

;mat3=> (def m3 (mat 3 2 [1 2 3 4 5 6]))
;#'user/m3
;mat3=> (show m3)
; 1  2
; 3  4
; 5  6
;nil
;mat3=> (show (mat3/slice m3 1 3))
; 3  4
; 5  6
;nil

同じように転置についてベンチマークを取ってみた。

mat3=> (time (bench 10000))
"Elapsed time: 12.941026 msecs"

:transposed?フラグを使う方法と同じくらい早い。だが応用範囲はこっちのほうが広い。

応用1: ランダムシャッフル

転置行列、行列の切り分けで説明したが、こんな処理も書ける。

行をランダムシャッフルする関数。行のインデックスをシャッフルしたシーケンスを関数リテラルの中に持たせている。

(defn shuffle-col [m]
  (let [[col _] (:shape m)
        rs (shuffle (range col))
        f (:posf m)]
    (assoc m :posf (fn [c r]
                     (f (nth rs c) r)))))

;mat3=> (show m3)
; 1  2
; 3  4
; 5  6
;nil
;mat3=> (show (shuffle-col m3))
; 5  6
; 1  2
; 3  4

応用2: One-Hot ベクトルの表現

機械学習ではあるデータがカテゴリに属していることを表すのにOne-Hotベクトルというものを使います。これは特定の列だけ1でそれ以外は0のベクトル。
このトリックを使うとほとんどメモリを使わずに表現できる。

(defn one-hot [size hot-idx]
  {:shape [1 size]
   :vec [0 1]
   :posf (fn [_ r]
           (if (= r hot-idx)
             1
             0))})

;mat3=> (show (one-hot 10 3))
; 0  0  0  1  0  0  0  0  0  0

応用3: ブロードキャスト

同じ行を指定数だけ増やす関数も書ける。numpyではこれを自動でやってくれてるブロードキャスト機能というものがある。

(defn broadcast-row [m-row col-size]
  (let [[_ row] (:shape m-row)
        f (:posf m-row)]
    (-> m-row
        (assoc :shape [col-size row])
        (assoc :posf (fn [_ r]
                       (f 0 r))))))

;mat3=> (def m4 (mat 1 4 [1 2 3 4]))
;#'user/m4
;mat3=> (show m4)
; 1  2  3  4
;nil
;mat3=> (show (broadcast-row m4 3))
; 1  2  3  4
; 1  2  3  4
; 1  2  3  4

応用4: 組み合わせ

関数をWrapしているだけなので組み合わせても問題ない。レキシカルクロージャなのでその時点でのrow,colを保存しておけるという利点が効いている。

mat3=> (def m3 (mat 3 2 [1 2 3 4 5 6]))
#'user/m3
mat3=> (show m3)
 1  2
 3  4
 5  6
nil

mat3=> (-> m3
           (slice 1 2) ; 2行目を切り出し
           (transpose) ; 転置する
           (show))
 3
 4

mat3=> (-> (one-hot 10 3)    ; OneHotベクトルを生成
           (broadcast-row 3) ; 3行に増やす
           (transpose)       ; 転置する
           (show))
 0  0  0
 0  0  0
 0  0  0
 1  1  1
 0  0  0
 0  0  0
 0  0  0
 0  0  0
 0  0  0
 0  0  0

このトリックはかなり応用が効く。

弱点

このトリックには一つ弱点があります。変更に弱いことです。

行列の一部を変更する関数をこのように定義してみる。

(defn mput [m c r v]
  (let [i (cr->i m c r)]
    (assoc-in m [:vec i] v)))

ふつうの行列については正しく動く。

mat3=> (def m1 (mat 2 3 [1 2 3 4 5 6]))
#'mat3/m1
mat3=> m1
{:shape [2 3], :vec [1 2 3 4 5 6], :posf #object[mat3$mat$fn__148 0x1021f6c9 "mat3$mat$fn__148@1021f6c9"]}
mat3=> (show m1)
 1  2  3
 4  5  6

mat3=> (show (mput m1 0 0 10)) ; 左上を10に変更
 10  2  3
 4  5  6

しかしトリックを組み合わせた行列にそのまま適用するとおかしなことになる。表示される行列は見かけのものであり、実態は2つの要素[0 1]だけだから。

mat3=> (def m4 (-> (one-hot 10 3)
                   (broadcast-row 3)))
#'mat3/m4
mat3=> (show m4)
 0  0  0  1  0  0  0  0  0  0
 0  0  0  1  0  0  0  0  0  0
 0  0  0  1  0  0  0  0  0  0
nil
mat3=> (show (mput m4 0 0 99))  ; 左上を99に置き換える -> 失敗
 99  99  99  1  99  99  99  99  99  99
 99  99  99  1  99  99  99  99  99  99
 99  99  99  1  99  99  99  99  99  99
nil

これを回避するために見かけの行列を反映したコピーを作る関数を用意する。

(defn clone [m]
  (let [[col row] (:shape m)
        v (for [c (range col), r (range row)]
            (mget m c r))]
    (mat col row v)))

;mat3=> (show (mput (clone m4) 0 0 99)) ; 左上を99に置き換える -> 成功
; 99  0  0  1  0  0  0  0  0  0
; 0  0  0  1  0  0  0  0  0  0
; 0  0  0  1  0  0  0  0  0  0
;nil

行列をファイル出力するときなどもcloneは必要でしょう。

ニューラルネット学習では更新する必要があるのは全結合層の重み行列です。これは逆伝播のたびに全要素を更新していくのでこのトリックを使う余地はなく問題なく更新できる。One-hotベクトルのトリックは学習データに使っているがこれは入力データなので更新されない。

たまたまだがこのトリックは用途にベストマッチしていた。

最終的にやった実装がこれ2

neuro/vol.clj at master · deltam/neuro

まとめ

データ構造に関数リテラルをもたせると柔軟な表現が可能になります。
そのトリックはClojureのPersistentMapの特性とも合致していて効率化も果たせました。
関数型っぽい行列計算の実装ができ、趣味プロジェクトが進んで満足。

だけど趣味ではないコードで行列演算が必要な場合は自分で書かずにライブラリを探しましょう(令和の時代に行列掛け算のバグに泣かされるべきでない)。

おわり。


  1. 永続データ構造自体については英語版Wikipediaがわかりやすい。Clojureでどのように使われているかも触れられている。 ↩︎

  2. 余談:numpyも同じようなことをやっているのかとコードや記事を読んでみたが違うみたいだった。汎用的に効率のいいコードは大変だー。NumPyで使われる多次元配列のデータ構造「ndarray」とは?:CodeZine(コードジン) ↩︎

0 件のコメント: