Template Haskellで遊ぶ

Ross Patersonの論文に面白いデータ構造「homogeneous functions」があって:

type Pair a = (a, a)
data BalTree a = Zero a | Succ (BalTree (Pair a)) deriving (Show)

-- the homogeneous functions
data Hom a b = ( a -> b ) :&: Hom (Pair a) (Pair b)

instance Cat.Category (Hom) where
  id = Cat.id
  (g :&: gs) . (f :&: fs) = (g . f) :&: ((Cat..) gs fs)

instance Arrow Hom where
  arr f = f :&: arr (f *** f)
  first (f :&: fs) = (f *** id) :&: (arr transpose >>> first fs >>> arr transpose)
    where transpose ((a, b), (c, d)) = ((a, c), (b, d))

apply :: Hom a b -> BalTree a -> BalTree b
apply (f :&: fs) (Zero x) = Zero (f x)
apply (f :&: fs) (Succ t) = Succ (apply fs t)

BalTreeは平衡二分木。しかしここの例では直接BalTreeいじるんではなく、Hom、つまりアルゴリズム自体を作成してからBalTreeに適用するって感じ:

butterfly :: (Pair b -> Pair b) -> Hom b b
butterfly f = id :&: proc(o, e) -> do
                           o' <- butterfly f -< o
                           e' <- butterfly f -< e
                           returnA -< f (o', e')

rev :: Hom a a
rev = butterfly swap

bisort :: Ord a => Hom a a
bisort = butterfly cmp 
  where cmp (x, y) = (min x y, max x y)

main = do print $ apply rev (Succ (Succ (Zero ((1,2),(3,4)))))
          print $ apply bisort (Succ (Succ (Zero ((1,4),(3,2)))))

いや〜高階関数を自在に操る人って凄いよね〜憧れるよね〜
しかしBalTree自身は使いにくいだけど…なぜかというとBalTreeはData.Treeのような再帰構造ではない、BalTree中身のを参照したい時に唯一な方法はパターンマッチ、例えばリストに変換したい場合:

toList :: BalTree a -> [a]
toList (Zero a0) = [a0]
toList (Succ (Zero (a0, a1))) = [a0, a1]
toList (Succ (Succ (Zero ((a0, a1), (a2, a3))))) = [a0, a1, a2, a3]
...

逆も同じ、ガードで全てのパターンを手で書くしかない:

toBTree :: [a] -> BalTree a
toBTree as = case length as of x | x == 2^0 -> Zero (as!!0)
                                 | x == 2^1 -> Succ (Zero (as!!0, as!!1)) 
...

「全パターン手で書くなんてありえねぇ!」のあなたにTemplate Haskellというc++ templateやboost.preprocessorみたいなメタプログラミングフレームワークがある。これを使ってリストとBalTreeの変換関数を試して書いた:

---------------------------
-- Main.hs
---------------------------

import BalanceTree

bTree2List' = $(bTree2List 5) -- 2^0 2^1 ... 2^5までのパターンを列挙する
list2bTree' = $(list2bTree 5)

main = print $ list2bTree' [2,1,6,3,8,5,4,7] >>= mapply bisort >>= bTree2List'
  where mapply f d = return (apply f d)

---------------------------
-- BalanceTree.hs
---------------------------

{-# LANGUAGE TemplateHaskell #-}

module BalanceTree ( 
  Pair(..),
  BalTree(Zero, Succ),
  bTree2List,
  list2bTree
) where

import Language.Haskell.TH
import Language.Haskell.TH.Syntax

type Pair a = (a, a)
data BalTree a = Zero a | Succ (BalTree (Pair a)) deriving (Show)

-- $(bTree2List Int) -> ExpQ : BalTree a -> Maybe [a]
bTree2List :: Integer -> ExpQ
bTree2List n = [| \tree -> $(caseE [| tree |] alts) |]
  where alts = map (\x -> match (pat x) (normalB $ rhs x) [] ) [0..n] ++ [match wildP (normalB $ conE $ mkName "Nothing") []]
        pat n' = btreecon n' n'
        btreecon gn n' | n' == 0 = conP (mkName "Zero") [ let list = buildPairN (map varP (names gn)) gn in if gn == 0 then head list else tupP list ]
                       | otherwise = conP (mkName "Succ") [ btreecon gn (n'-1) ]
        buildPairN as n' | n' == 0 =  as
                         | otherwise = buildPairN (buildPair as) (n'-1)
        buildPair [] = []
        buildPair (a:b:as) = tupP [a,b] : buildPair as
        rhs n' = appE (conE $ mkName "Just") (listE $ map varE $ names n')
        names n' = [ mkName ("a" ++ show i) | i <- [0 .. 2^n'-1] ]

-- $(list2bTree Int) -> ExpQ : [a] -> Maybe (BalTree a)
list2bTree :: Integer -> ExpQ
list2bTree n = [| \list -> $(caseE (appE (varE $ mkName "length") [| list |]) [alt [| list |]]) |]
  where alt list = match (varP $ mkName "x") 
                         (guardedB $ guardExpr list ++ [normalGE (varE $ mkName "otherwise") (conE $ mkName "Nothing")])
                         []
        guardExpr list = map ( \x -> normalGE (infixE (Just $ varE $ mkName "x") (varE $ mkName "==") (Just $ litE $ IntegerL (2^x)))
                                     (appE (conE $ mkName "Just") (btreecon list x x))
                             ) [0..n]
        btreecon list gn n' | n' == 0 = appE (conE $ mkName "Zero") (let list' = buildPairN (vary list gn) gn in if gn == 0 then head list' else tupE list')
                            | otherwise = appE (conE $ mkName "Succ") (btreecon list gn (n'-1))
        buildPairN as n' | n' == 0 =  as
                         | otherwise = buildPairN (buildPair as) (n'-1)
        buildPair [] = []
        buildPair (a:b:as) = tupE [a,b] : buildPair as  
        vary list n' = [ infixE (Just list) (varE $ mkName "!!") (Just $ litE $ IntegerL i) | i <- [0 .. 2^n'-1] ]

メタプログラミング共通の問題点なんだけど、まずは読めない、と、その前に読みたくない。まぁこの辺本来ならコンパイラの仕事だからしょうがないかもしれないね。