Control.Monad.foldM_ ではなく Data.Vector.foldM_ を使いましょう

Mutable array の prefix sum を余計なメモリを消費せずに計算したい時など、「基本的には mapM_ なんだけど、map してる関数の挙動がそれまでに map し終わった部分の結果に依存して欲しい」ような計算の時、私は foldM_ を使ってたんだけど、これが時々、恐ろしく遅い。例えば

psum :: STUArray s Int Double -> Int -> ST s ()
psum a n = foldM_ f 0 [0 .. n]
    where f x i = do y <- readArray a i
                     let z = x+y
                     readArray a i z
                     return z

なんて書くと、配列の大きさが 100M 要素ぐらいになった時点で既に 11 秒とか掛かったりする。それに対して、手動で最適化した以下のようなコードでは 1.2 秒で済むので、オーダーが丸ひとつ違ってしまう。

opt_psum :: STUArray s Int Double -> Int -> ST s ()
opt_psum a n = go 0 0
    where go x i = when (i <= n) $ do
                     y <- readArray a i
                     let z = x+y
                     writeArray a i z
                     go z (i+1)

原因は foldM_ が left-fold な事による (っぽい)。foldr で定義できる forM_ なんかと違って、foldM_ は リストを絶対に左から右へ消費していかなくてはならないので、旧来の融合変換が適用できない。*1みたい。

で、勿論上記のような手動で最適化したやつを書くのはめんどくさいので、どうすればいいかというと単に Data.Vector を使えば良い。

import qualified Data.Vector as VU
vector_eft_psum :: STUArray s Int Double -> Int -> ST s ()
vector_eft_psum a n = VU.foldM_ f 0 (VU.enumFromTo 0 n)
    where f x i = do y <- readArray a i
                     let z = x+y
                     writeArray a i z
                     return z

これだけで opt_psum とほぼ同じ性能が出る。*2マニュアルによれば enumFromTo は遅いらしいけど、この場合は融合変換できるので enumFromN とあんまり変わらない。

まあ Data.Vector 使えば融合変換が凄いよ、というのは知れ渡ってる事実なんだけど、正直リストの [0..n] という構文の直観的なわかりやすさが捨てがたくて Data.Vector は個人的に敬遠してたんですよ。だけどこれからは foldM_ を使う時は Data.Vector 一択になりそう。

*1:呼び名は忘れたけど Wadler のやつ。"Wadler deforestation" とかでググりたも。

*2:若干のオーバーヘッドはあるが、気にならないレベル。