Template Haskellで遊ぶ

Ross Patersonの論文に、「homogeneous functions」という面白いデータ構造が出てくる:

-- | Two values of the same type.
type Pair a = (a, a)

-- | A perfectly balanced binary tree encoded by nesting: 'Zero' holds a
-- single leaf, and every 'Succ' doubles the leaf type via 'Pair', so a tree
-- with k 'Succ' layers has exactly 2^k leaves.
data BalTree a
  = Zero a
  | Succ (BalTree (Pair a))
  deriving Show

-- | Homogeneous functions: a stream of functions, one per tree depth.
-- The head acts on single values (depth 0) and the tail acts on 'Pair's of
-- values, i.e. on the next level of a 'BalTree'. The type has no base case,
-- so every (total) value is an infinite stream.
data Hom a b = (a -> b) :&: Hom (Pair a) (Pair b)

-- | Identity and composition are defined level by level: at every depth the
-- stored functions are combined with ordinary function identity/composition.
instance Cat.Category Hom where
  -- The head must be the plain function identity. Writing @id = Cat.id@
  -- here refers back to this very method and never terminates; that would
  -- also break 'returnA', which is defined as 'Cat.id'.
  id = id :&: Cat.id
  (g :&: gs) . (f :&: fs) = (g . f) :&: ((Cat..) gs fs)

-- | 'arr' lifts a function to every depth by applying it to both halves of
-- each pair. 'first' works as follows:
--
-- * At the head it acts on the first component as usual.
-- * At the next level the values are pairs of @(a, c)@ tuples, so it
--   regroups them into a pair of @a@s and a pair of @c@s.
-- * It then runs the next-level function on the @a@ part and regroups back.
instance Arrow Hom where
  arr f = f :&: arr pairwise
    where
      pairwise = f *** f
  first (f :&: fs) = first f :&: (arr regroup <<< first fs <<< arr regroup)
    where
      regroup ((p, q), (r, s)) = ((p, r), (q, s))

-- | Run a homogeneous function on a balanced tree. Each 'Succ' layer moves
-- one step down the stream of functions. The function reached at the 'Zero'
-- layer is applied to the value stored there.
apply :: Hom a b -> BalTree a -> BalTree b
apply (h :&: hs) tree =
  case tree of
    Zero leaf -> Zero (h leaf)
    Succ inner -> Succ (apply hs inner)


-- | @butterfly f@ is the identity on single values. On a pair @(x, y)@ it
-- first processes both components recursively with @butterfly f@, one level
-- down, and then combines the two results with @f@.
butterfly :: (Pair b -> Pair b) -> Hom b b
butterfly f = id :&: combine
  where
    combine = proc (x, y) -> do
      x' <- butterfly f -< x
      y' <- butterfly f -< y
      returnA -< f (x', y')

-- | Butterfly whose combining step exchanges the two halves of each pair
-- with 'swap'. With both halves handled recursively first, this reverses
-- the leaf order (e.g. ((1,2),(3,4)) becomes ((4,3),(2,1))).
-- NOTE(review): assumes 'swap' is Data.Tuple.swap — confirm the import.
rev :: Hom a a
rev = butterfly swap

-- | Butterfly whose combining step is a compare-exchange: every pair is put
-- into (min, max) order after both halves have been processed.
-- NOTE(review): like the classic bitonic merger this presumably sorts only
-- bitonic inputs — verify before relying on it for arbitrary data.
bisort :: Ord a => Hom a a
bisort = butterfly (\(x, y) -> (min x y, max x y))

-- | Demo: reverse the leaves [1,2,3,4], then run 'bisort' on the leaves
-- [1,4,3,2], printing each resulting tree.
main :: IO ()
main =
  mapM_ print
    [ apply rev (Succ (Succ (Zero ((1, 2), (3, 4)))))
    , apply bisort (Succ (Succ (Zero ((1, 4), (3, 2)))))
    ]


-- | Flatten a balanced tree into its leaves, left to right.
--
-- Works for every depth. Under a 'Succ', the inner tree has 'Pair's as
-- leaves, so we flatten it one level down (polymorphic recursion, enabled by
-- the type signature) and split every pair into its two elements. No
-- per-depth patterns are needed.
toList :: BalTree a -> [a]
toList (Zero x) = [x]
toList (Succ t) = concatMap (\(l, r) -> [l, r]) (toList t)


-- | Build a balanced tree from a list whose length is a power of two
-- (1, 2, 4, 8, ...). Adjacent elements are paired level by level, so the list
-- order becomes the left-to-right leaf order ('toList' is the inverse).
--
-- Calls 'error' for any other length, including the empty list.
toBTree :: [a] -> BalTree a
toBTree [x] = Zero x
toBTree xs
  | n >= 2 && even n = Succ (toBTree (pairUp xs))
  | otherwise =
      error ("toBTree: list length " ++ show n ++ " is not a power of two")
  where
    n = length xs
    -- [a0, a1, a2, a3, ...] -> [(a0, a1), (a2, a3), ...]
    pairUp (a : b : rest) = (a, b) : pairUp rest
    pairUp _ = []

「全パターンを手で書くなんてありえねぇ!」というあなたには Template Haskell がある。C++ のテンプレートや Boost.Preprocessor のような、メタプログラミングのための仕組みだ。これを使って、リストと BalTree の相互変換関数を試しに書いてみた:

-- Main.hs

-- The $(...) splices below do not compile without this extension.
{-# LANGUAGE TemplateHaskell #-}

import BalanceTree

-- Template Haskell splices: generate one case per tree depth 2^0, 2^1, ..., 2^5.
-- NOTE(review): bTree2List / list2bTree must be exported by BalanceTree.
bTree2List' :: BalTree a -> Maybe [a]
bTree2List' = $(bTree2List 5)

list2bTree' :: [a] -> Maybe (BalTree a)
list2bTree' = $(list2bTree 5)

-- | Convert the list to a tree, run 'bisort' on it, and flatten it again.
-- NOTE(review): 'apply' and 'bisort' are neither defined here nor exported by
-- BalanceTree as shown. They must come from a module that works on the same
-- 'BalTree' type — confirm.
main :: IO ()
main = print (list2bTree' [2, 1, 6, 3, 8, 5, 4, 7] >>= bTree2List' . apply bisort)

-- BalanceTree.hs

{-# LANGUAGE TemplateHaskell #-}

-- | Balanced binary trees, plus Template Haskell generators for functions
-- that convert between lists and trees. The generators must be exported so
-- that other modules can splice them.
module BalanceTree
  ( BalTree (Zero, Succ)
  , bTree2List
  , list2bTree
  ) where

import Language.Haskell.TH
import Language.Haskell.TH.Syntax

-- | Two values of the same type.
type Pair a = (a, a)

-- | Perfectly balanced binary tree: 'Zero' is a single leaf, and each 'Succ'
-- doubles the leaf type via 'Pair' (k 'Succ' layers give 2^k leaves).
data BalTree a
  = Zero a
  | Succ (BalTree (Pair a))
  deriving Show

-- | @$(bTree2List n)@ splices in a function of type @BalTree a -> Maybe [a]@.
--
-- The generated @case@ has one alternative per depth @0 .. n@. A tree of
-- depth @k@ is matched down to its @2^k@ leaf variables, which are returned
-- left to right inside 'Just'. Deeper trees fall through to 'Nothing'.
--
-- Constructors are referenced with name quotes ('Zero, 'Just, ...). This
-- means the generated code does not depend on what is in scope at the splice
-- site.
bTree2List :: Integer -> ExpQ
bTree2List n = [| \tree -> $(caseE [| tree |] alts) |]
  where
    alts =
      map depthAlt [0 .. n] ++ [match wildP (normalB (conE 'Nothing)) []]
    -- One alternative: depth-k tree pattern -> Just [a0, a1, ...].
    depthAlt k = match (treePat k k) (normalB (rhs k)) []
    -- Peel off 'layers' Succ constructors, then match the leaf tuple under Zero.
    treePat depth layers
      | layers == 0 = conP 'Zero [leafPat depth]
      | otherwise = conP 'Succ [treePat depth (layers - 1)]
    -- Pair the 2^depth leaf variables 'depth' times. This leaves exactly one
    -- nested tuple pattern, so no 1-tuple is ever built.
    leafPat depth = case pairLevels (map varP (names depth)) depth of
      [p] -> p
      ps -> tupP ps
    pairLevels ps k
      | k == 0 = ps
      | otherwise = pairLevels (pairUp ps) (k - 1)
    pairUp (a : b : rest) = tupP [a, b] : pairUp rest
    pairUp rest = rest
    rhs k = appE (conE 'Just) (listE (map varE (names k)))
    -- Leaf variables a0 .. a(2^k - 1). The alternative that binds them is the
    -- only place they are used, so plain mkName names are not captured.
    names k = [mkName ("a" ++ show i) | i <- [0 .. 2 ^ k - 1]]

-- | @$(list2bTree n)@ splices in a function of type @[a] -> Maybe (BalTree a)@.
--
-- The generated function compares the list length with @2^0, 2^1 .. 2^n@.
-- On a match with @2^k@ it rebuilds a depth-@k@ tree, pairing elements left
-- to right level by level, and wraps it in 'Just'. Any other length yields
-- 'Nothing'.
--
-- Constructors and Prelude functions are referenced with name quotes
-- ('Zero, 'length, '(==), ...). This means the generated code does not depend
-- on what is in scope at the splice site.
list2bTree :: Integer -> ExpQ
list2bTree n = [| \list -> $(caseE [| length list |] [alt [| list |]]) |]
  where
    -- Binder for the list length inside the generated case alternative.
    lenVar = mkName "x"
    alt list =
      match (varP lenVar)
            (guardedB (map (depthGuard list) [0 .. n] ++ [fallback]))
            []
    fallback = normalGE (varE 'otherwise) (conE 'Nothing)
    -- Guard "x == 2^k" selecting the depth-k reconstruction.
    depthGuard list k =
      normalGE
        (infixE (Just (varE lenVar)) (varE '(==)) (Just (litE (IntegerL (2 ^ k)))))
        (appE (conE 'Just) (treeExp list k k))
    -- Wrap the nested leaf tuple in 'layers' Succ constructors around a Zero.
    treeExp list depth layers
      | layers == 0 = appE (conE 'Zero) (leafExp list depth)
      | otherwise = appE (conE 'Succ) (treeExp list depth (layers - 1))
    -- Pair the 2^depth element expressions 'depth' times. This leaves exactly
    -- one nested tuple expression, so no 1-tuple is ever built.
    leafExp list depth = case pairLevels (elemExps list depth) depth of
      [e] -> e
      es -> tupE es
    pairLevels es k
      | k == 0 = es
      | otherwise = pairLevels (pairUp es) (k - 1)
    pairUp (a : b : rest) = tupE [a, b] : pairUp rest
    pairUp rest = rest
    -- "list !! i" for every index. This is quadratic in the list length,
    -- which is acceptable for the small depths these splices are used with.
    elemExps list k =
      [ infixE (Just list) (varE '(!!)) (Just (litE (IntegerL i)))
      | i <- [0 .. 2 ^ k - 1]
      ]
