Static indices for speedy array-based algorithms
GHC.TypeLits gives us some very powerful type-level functionality (with some
caveats that will hopefully be worked out in future) that has not yet seen very
widespread use. With some mildly ugly leg-work, I’ve made a small library that
provides an n-dimensional statically bounded index type, and associated functions
for applying it with as little overhead as possible.
Motivation
Making plissken, I was unhappy
with the state of linear algebra libraries on Hackage. The venerable
hmatrix is excellent - once your
input sizes outweigh the constant factors involved. For the matrices
that my game engine lived and survived on – 4x4 matrices and 4-vectors –
hmatrix was easily outperformed in simple 4x4 matrix multiplication by
linear, which is a naive Haskell
implementation of the sort of “small” linear algebra that I need here. I ended
up using hmatrix though, since its being built on Storable made for very easy
interaction with OpenGL.
A trick I found in order to inline away “static recursion” was to abstract the
use of the recursion parameter into a typeclass whose function is inlined. So
long as you remember the all-powerful -O2 switch, GHC will happily inline away
every recursive call, which will hopefully give way to further compile-time
evaluation.
Here’s a little demonstration of such a class:
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-} -- our use of this should be fine.
module Playaround where

import Control.Applicative
import Data.Proxy (Proxy (..))
import GHC.TypeLits

data Peano = Succ Peano | Zero

-- One instance per Peano numeral; each call to 'for' is resolved statically.
class For (n :: Peano) where
  for :: Applicative m => Proxy n -> (Integer -> m ()) -> m () -> m ()

instance For Zero where
  {-# INLINE for #-}
  for _ _ acc = acc

instance (KnownNat (1+FromPeano n), For n) => For (Succ n) where
  {-# INLINE for #-}
  -- Run the accumulated action, then the body at this index, then recurse.
  for p f acc = for (next p) f (acc *> f (peanoVal p))

next :: Proxy (Succ n) -> Proxy n
next _ = Proxy

type family FromPeano (n :: Peano) :: Nat where
  FromPeano Zero = 0
  FromPeano (Succ n) = 1 + FromPeano n

type family ToPeano (n :: Nat) :: Peano where
  ToPeano 0 = Zero
  ToPeano n = Succ (ToPeano (n-1))

fromPeano :: Proxy (n :: Peano) -> Proxy (FromPeano n)
fromPeano _ = Proxy

toPeano :: Proxy (n :: Nat) -> Proxy (ToPeano n)
toPeano _ = Proxy

peanoVal :: KnownNat (FromPeano n) => Proxy (n :: Peano) -> Integer
peanoVal = fromInteger . natVal . fromPeano
That’s pretty gross. We can improve it by wrapping it in a small helper:
niceFor :: (For (ToPeano n), Applicative m)
        => Proxy (n :: Nat) -> (Integer -> m ()) -> m ()
niceFor p f = for (toPeano p) f (pure ())
Unfortunately, GHC doesn’t see what we can see intuitively: since ToPeano is
defined for every Nat, For always has an instance at For (ToPeano (n :: Nat)).
So we have to add that constraint ourselves.
Let’s try it out:
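{-# LANGUAGE DataKinds #-}
module Main where

import Data.Proxy (Proxy (..))
import Playaround
import qualified Data.Vector.Mutable as M

main :: IO ()
main = do
  vect <- M.new 16
  niceFor (Proxy :: Proxy 15) $ \ix ->
    -- niceFor hands us an Integer; the vector wants Int indices.
    let i = fromInteger ix :: Int in M.write vect i i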
$ ghc-core Main.hs \
    -dsuppress-idinfo \
    -dsuppress-coercions \
    -dsuppress-type-applications \
    -dsuppress-uniques \
    -dsuppress-module-prefixes
Some scrolling later…
main1 :: State# RealWorld -> (# State# RealWorld, () #)
main1 =
\ (eta :: State# RealWorld) ->
case newArray# 16 (uninitialised) (eta `cast` ...)
of _ { (# ipv, ipv1 #) ->
case writeArray# ipv1 15 (I# 15) ipv of s'# { __DEFAULT ->
case writeArray# ipv1 14 (I# 14) s'# of s'#1 { __DEFAULT ->
case writeArray# ipv1 13 (I# 13) s'#1 of s'#2 { __DEFAULT ->
case writeArray# ipv1 12 (I# 12) s'#2 of s'#3 { __DEFAULT ->
case writeArray# ipv1 11 (I# 11) s'#3 of s'#4 { __DEFAULT ->
case writeArray# ipv1 10 (I# 10) s'#4 of s'#5 { __DEFAULT ->
case writeArray# ipv1 9 (I# 9) s'#5 of s'#6 { __DEFAULT ->
case writeArray# ipv1 8 (I# 8) s'#6 of s'#7 { __DEFAULT ->
case writeArray# ipv1 7 (I# 7) s'#7 of s'#8 { __DEFAULT ->
case writeArray# ipv1 6 (I# 6) s'#8 of s'#9 { __DEFAULT ->
case writeArray# ipv1 5 (I# 5) s'#9 of s'#10 { __DEFAULT ->
case writeArray# ipv1 4 (I# 4) s'#10 of s'#11 { __DEFAULT ->
case writeArray# ipv1 3 (I# 3) s'#11 of s'#12 { __DEFAULT ->
case writeArray# ipv1 2 (I# 2) s'#12 of s'#13 { __DEFAULT ->
case writeArray# ipv1 1 (I# 1) s'#13 of s'#14 { __DEFAULT ->
(# s'#14, () #) `cast` ...
} } } } } } } } } } } } } } } }
Great! The loop is unrolled rather nicely.
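The same class-based trick is not tied to Applicative effects. As a rough sketch (FoldN, count, foldN and prev are hypothetical names made up for this illustration, not part of any library), here is a pure left fold that visits the indices n, n-1, …, 1 in the same order as the loop above. I have not checked the Core for this one, so whether it unrolls as cleanly is an assumption worth verifying with -O2.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}

import Data.Proxy (Proxy (..))

data Peano = Succ Peano | Zero

-- | Hypothetical pure counterpart to 'For': a left fold whose trip count
-- comes from the type. Indices are visited as n, n-1, .., 1.
class FoldN (n :: Peano) where
  -- | The numeral as an ordinary Int.
  count :: Proxy n -> Int
  -- | Fold the step function over every index, starting from the seed.
  foldN :: Proxy n -> (b -> Int -> b) -> b -> b

instance FoldN 'Zero where
  {-# INLINE count #-}
  count _ = 0
  {-# INLINE foldN #-}
  foldN _ _ acc = acc

instance FoldN n => FoldN ('Succ n) where
  {-# INLINE count #-}
  count p = 1 + count (prev p)
  {-# INLINE foldN #-}
  -- Apply the step at this index, then recurse on the predecessor.
  foldN p f acc = foldN (prev p) f (f acc (count p))

-- | Step a proxy down by one.
prev :: Proxy ('Succ n) -> Proxy n
prev _ = Proxy
```

Loaded into GHCi, foldN (Proxy :: Proxy ('Succ ('Succ ('Succ 'Zero)))) (+) 0 evaluates to 6, and folding with \acc i -> i : acc from [] gives [1,2,3].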
Enter indices
indices provides a multi-dimensional (Int-based) index type, with use for
array-heavy code in mind. I hope Soon® to release on Hackage the package
static (patches welcome and appreciated), which provides a ForeignPtr-based
array type that leverages indices’s, ahem, indices, just about everywhere.
Indices in indices are these two types:
data (a :: Nat) :. b = !Int :. b
data Z = Z
This is similar to the design seen in
repa, except the “bound” of an index
is its type. That’s crippling for code with arbitrarily-bounded arrays, but very
nice otherwise. For instance, an index into a 4x4 matrix is 0:.0:.Z :: 4:.4:.Z.
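To make “the bound is the type” concrete, here is a minimal sketch that flattens such an index to a row-major offset, reading every extent from the type. The Flat class and its extent and offset methods are hypothetical names invented for this example, and row-major order is an assumption; indices itself may expose this differently.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- Same shape as the two types above.
data (a :: Nat) :. b = !Int :. b
data Z = Z
infixr 9 :.

-- | Hypothetical class (not the indices API): everything about the shape
-- is recovered from the type, so no bounds are stored at runtime.
class Flat ix where
  -- | Total number of elements addressed by this index type.
  extent :: Proxy ix -> Int
  -- | Row-major position of an index.
  offset :: ix -> Int

instance Flat Z where
  extent _ = 1
  offset Z = 0

instance (KnownNat n, Flat b) => Flat (n :. b) where
  -- This dimension's size times the size of everything to its right.
  extent _ = fromInteger (natVal (Proxy :: Proxy n)) * extent (Proxy :: Proxy b)
  -- Skip i whole sub-blocks, then offset within the remaining dimensions.
  offset (i :. rest) = i * extent (Proxy :: Proxy b) + offset rest
```

With these definitions, offset (1:.2:.Z :: 4:.4:.Z) is 1*4 + 2 = 6 and extent (Proxy :: Proxy (4:.4:.Z)) is 16. Note that this sketch does no range checking on the Int components.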
For now, you can use indices for array-based code by leveraging its Ix
instance. This is a little confusing due to the design of Ix, but it’s fairly
simple: the type is always the upper bound, and zero is always the lower bound,
whatever values you pass in. That means that arrays constructed by
array (_, _ :: t) are bounded by [0,t). As an aside, I’m leaning towards
reworking static to simply use the array types found in
arrays, simplifying it greatly into
the small linear-algebra package it yearns to be. Input is appreciated here. (At
the moment I’m irked at having to store two superfluous lower/upper bounds –
linked lists of Int – in the array constructors.)
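As an illustration of what “the type is the upper bound” means for Ix, here is a one-dimensional sketch. Idx, size and fromList are hypothetical names for this example only, and this is not how indices implements its instance; it just shows an Ix instance that ignores the bounds pair and always covers [0, n).

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}

import Data.Array (Array, listArray, (!))
import Data.Ix (Ix (..))
import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- | Hypothetical 1-D index whose exclusive upper bound n lives in the type.
newtype Idx (n :: Nat) = Idx Int
  deriving (Eq, Ord, Show)

-- | The exclusive upper bound, read back from the type.
size :: KnownNat n => Proxy n -> Int
size = fromInteger . natVal

instance KnownNat n => Ix (Idx n) where
  -- The bounds pair is ignored throughout: the range is always [0, n).
  range _ = map Idx [0 .. size (Proxy :: Proxy n) - 1]
  inRange _ (Idx i) = 0 <= i && i < size (Proxy :: Proxy n)
  index b ix@(Idx i)
    | inRange b ix = i
    | otherwise    = error "Idx: index out of range"
  rangeSize _ = size (Proxy :: Proxy n)

-- | Build an array over the whole range; the stored bounds are redundant
-- but kept honest so that 'bounds' reports something sensible.
fromList :: forall n a. KnownNat n => [a] -> Array (Idx n) a
fromList = listArray (Idx 0, Idx (size (Proxy :: Proxy n) - 1))
```

For example, (fromList "abcd" :: Array (Idx 4) Char) ! Idx 2 is 'c', and inRange (Idx 0, Idx 0) (Idx 3 :: Idx 4) is True even though the pair passed in claims a one-element range.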
Here’s a demonstration, implementing vector dot product with a static, unrolled loop:
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
module MM where

import Data.Array.Unboxed
import Data.Index

type Vector m = UArray (m:.Z)

sizeV :: Vector m a -> Proxy (m:.Z)
sizeV _ = Proxy

vector :: (IArray UArray a, Dim (m:.Z)) => Proxy (m:.Z) -> [a] -> Vector m a
vector b = listArray (zero, maxBound `asProxyTypeOf` b)

dot a b =
  sfoldlRange
    (sizeV a `asTypeOf` sizeV b)
    (\sum ix -> sum + a!ix * b!ix)
    0

v4 x y z w = vector (Proxy :: Proxy (4:.Z)) [x,y,z,w]

main :: IO ()
main = do
  print (dot (v4 1 2 3 4) (v4 2 3 4 5) :: Double)
  print (sum (zipWith (*) [1,2,3,4] [2,3,4,5]) :: Double)
Note the slyness in me not writing the type signature for dot… Well, the
important thing is that GHC happily infers the type without needing any hints.
Right… guys?
You can contribute to indices on GitHub, find its documentation
on Hackage, and install it with cabal.