typesafe union

This commit is contained in:
Sandy Maguire 2019-03-17 14:09:09 -04:00
parent 26ffcd4f9e
commit a8b1f247aa
7 changed files with 156 additions and 71 deletions

View File

@ -126,8 +126,8 @@ main :: IO ()
main = main =
defaultMain [ defaultMain [
bgroup "Countdown Bench" [ bgroup "Countdown Bench" [
bench "faster" $ whnf TFTF.countDownFast 10000 -- bench "faster" $ whnf TFTF.countDownFast 10000
, bench "discount" $ whnf TFTF.countDown 10000 bench "discount" $ whnf TFTF.countDown 10000
, bench "freer-simple" $ whnf countDown 10000 , bench "freer-simple" $ whnf countDown 10000
, bench "mtl" $ whnf countDownMTL 10000 , bench "mtl" $ whnf countDownMTL 10000
] ]

View File

@ -50,6 +50,8 @@ tests:
- -with-rtsopts=-N - -with-rtsopts=-N
dependencies: dependencies:
- too-fast-too-free - too-fast-too-free
- inspection-testing
- hspec
benchmarks: benchmarks:
too-fast-too-free-bench: too-fast-too-free-bench:

View File

@ -1,9 +1,12 @@
{-# LANGUAGE DataKinds #-} {-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE DataKinds #-}
{-# LANGUAGE MonoLocalBinds #-} {-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-} {-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-} {-# LANGUAGE MonoLocalBinds #-}
{-# LANGUAGE UnicodeSyntax #-} {-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnicodeSyntax #-}
module Control.Monad.Discount module Control.Monad.Discount
( module Control.Monad.Discount ( module Control.Monad.Discount
@ -60,6 +63,7 @@ liftEff :: Union r (Eff r) a -> Eff r a
liftEff u = F $ \kp kf -> kf $ fmap kp u liftEff u = F $ \kp kf -> kf $ fmap kp u
{-# INLINE liftEff #-} {-# INLINE liftEff #-}
raise :: Eff r a -> Eff (e ': r) a raise :: Eff r a -> Eff (e ': r) a
raise = runEff pure $ join . liftEff . hoist raise . weaken raise = runEff pure $ join . liftEff . hoist raise . weaken
{-# INLINE raise #-} {-# INLINE raise #-}
@ -129,11 +133,11 @@ runM e = runF e pure $ join . unLift . extract
run :: Eff '[] a -> a run :: Eff '[] a -> a
run = runEff id absurdU run = runEff id $ error "lol"
{-# INLINE run #-} {-# INLINE run #-}
send :: Member eff r => eff (Eff r) a -> Eff r a send :: Member e r => e (Eff r) a -> Eff r a
send = liftEff . inj send = liftEff . inj
{-# INLINE send #-} {-# INLINE send #-}

View File

@ -1,91 +1,142 @@
{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-} {-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-} {-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-} {-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-} {-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wall #-}
module Data.OpenUnion where module Data.OpenUnion where
import Control.Monad.Discount.Effect import Control.Monad.Discount.Effect
import Unsafe.Coerce import Data.Typeable
data Dict c where Dict :: c => Dict c
data Nat = Z | S Nat
deriving Typeable
data SNat :: Nat -> * where
SZ :: SNat 'Z
SS :: Typeable n => SNat n -> SNat ('S n)
deriving Typeable
type family IndexOf (ts :: [k]) (n :: Nat) :: k where
IndexOf (k ': ks) 'Z = k
IndexOf (k ': ks) ('S n) = IndexOf ks n
type family Found (ts :: [k]) (t :: k) :: Nat where
Found (t ': ts) t = 'Z
Found (u ': ts) t = 'S (Found ts t)
class Typeable (Found r t) => Find (r :: [k]) (t :: k) where
finder :: SNat (Found r t)
instance {-# OVERLAPPING #-} Find (t ': z) t where
finder = SZ
{-# INLINE finder #-}
instance ( Find z t
, Found (_1 ': z) t ~ 'S (Found z t)
) => Find (_1 ': z) t where
finder = SS $ finder @_ @z @t
{-# INLINE finder #-}
data Union (r :: [(* -> *) -> * -> *]) (m :: * -> *) a where data Union (r :: [(* -> *) -> * -> *]) (m :: * -> *) a where
Union :: Effect e => Word -> e m a -> Union r m a Union
:: Effect (IndexOf r n)
=> SNat n
-> IndexOf r n m a
-> Union r m a
decomp :: Union (e ': r) m a -> Either (Union r m a) (e m a)
decomp (Union p a) =
case p of
SZ -> Right a
SS n -> Left $ Union n a
{-# INLINE decomp #-}
extract :: Union '[e] m a -> e m a extract :: Union '[e] m a -> e m a
extract (Union _ a) = unsafeCoerce a extract (Union SZ a) = a
extract _ = error "impossible"
{-# INLINE extract #-} {-# INLINE extract #-}
absurdU :: Union '[] m a -> b
absurdU = error "absurd, empty union"
unsafeInj :: Effect e => Word -> e m a -> Union r m a
unsafeInj w = Union w
{-# INLINE unsafeInj #-}
unsafePrj :: Word -> Union r m a -> Maybe (t m a)
unsafePrj n (Union n' x)
| n == n' = Just (unsafeCoerce x)
| otherwise = Nothing
{-# INLINE unsafePrj #-}
newtype P t r = P {unP :: Word}
class FindElem (t :: k) (r :: [k]) where
elemNo :: P t r
instance FindElem t (t ': r) where
elemNo = P 0
{-# INLINE elemNo #-}
instance {-# OVERLAPPABLE #-} FindElem t r => FindElem t (t' ': r) where
elemNo = P $ 1 + unP (elemNo :: P t r)
{-# INLINE elemNo #-}
class (FindElem e r, Effect e)
=> Member (e :: (* -> *) -> * -> *)
(r :: [(* -> *) -> * -> *]) where
inj :: Monad m => e m a -> Union r m a
prj :: Union r m a -> Maybe (e m a)
instance (Effect t, FindElem t r) => Member t r where
inj = unsafeInj $ unP (elemNo :: P t r)
{-# INLINE inj #-}
prj = unsafePrj $ unP (elemNo :: P t r)
{-# INLINE prj #-}
decomp :: Union (t ': r) m a -> Either (Union r m a) (e m a)
decomp (Union 0 a) = Right $ unsafeCoerce a
decomp (Union n a) = Left $ Union (n - 1) a
{-# INLINE [2] decomp #-}
weaken :: Union r m a -> Union (e ': r) m a weaken :: Union r m a -> Union (e ': r) m a
weaken (Union n a) = Union (n + 1) a weaken (Union n a) =
case induceTypeable n of
Dict -> Union (SS n) a
{-# INLINE weaken #-}
inj
:: forall r e a m
. ( Functor m
, Find r e
, Effect e
, e ~ IndexOf r (Found r e)
)
=> e m a
-> Union r m a
inj e = Union (finder @_ @r @e) e
{-# INLINE inj #-}
induceTypeable :: SNat n -> Dict (Typeable n)
induceTypeable SZ = Dict
induceTypeable (SS _) = Dict
{-# INLINE induceTypeable #-}
type Member e r = (Find r e, e ~ IndexOf r (Found r e), Effect e)
prj
:: forall r e a m
. ( Find r e
, e ~ IndexOf r (Found r e)
)
=> Union r m a
-> Maybe (e m a)
prj (Union (s :: SNat n) a) =
case induceTypeable s of
Dict ->
case eqT @n @(Found r e) of
Just Refl -> Just a
Nothing -> Nothing
{-# INLINE prj #-}
instance Effect (Union r) where instance Effect (Union r) where
weave s f (Union w e) = Union w $ weave s f e weave s f (Union w e) = Union w $ weave s f e
{-# INLINE weave #-} {-# INLINE weave #-}
instance Functor m => Functor (Union r m) where
fmap f (Union w t) = Union w $ fmap f t instance (Functor m) => Functor (Union r m) where
fmap f (Union w t) = Union w $ fmap' f t
where
fmap' :: (Functor m, Effect f) => (a -> b) -> f m a -> f m b
fmap' = fmap
{-# INLINE fmap' #-}
{-# INLINE fmap #-} {-# INLINE fmap #-}

View File

@ -31,10 +31,12 @@ data State s m a
get :: Member (State s) r => Eff r s get :: Member (State s) r => Eff r s
get = send $ Get id get = send $ Get id
{-# INLINE get #-}
put :: Member (State s) r => s -> Eff r () put :: Member (State s) r => s -> Eff r ()
put s = send $ Put s () put s = send $ Put s ()
{-# INLINE put #-}
data Error e m a data Error e m a
@ -85,6 +87,7 @@ runRelayS pure' bind' = flip go
$ fmap (uncurry (flip id)) $ fmap (uncurry (flip id))
$ weave (s', ()) (uncurry $ flip go) x $ weave (s', ()) (uncurry $ flip go) x
Right eff -> bind' eff Right eff -> bind' eff
{-# INLINE runRelayS #-}
runError :: Eff (Error e ': r) a -> Eff r (Either e a) runError :: Eff (Error e ': r) a -> Eff r (Either e a)

View File

@ -3,10 +3,11 @@
module Wtf where module Wtf where
import Control.Monad.Discount
import TRYAGAIN import TRYAGAIN
import Data.Functor.Identity import Data.Functor.Identity
go :: Eff '[State Int, Lift (Identity)] Int go :: Eff '[State Int] Int
go = do go = do
n <- send (Get id) n <- send (Get id)
if n <= 0 if n <= 0

View File

@ -1,2 +1,26 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TemplateHaskell #-}
import Test.Inspection
import Control.Monad.Discount
import Data.OpenUnion
import TRYAGAIN hiding (main)
main :: IO () main :: IO ()
main = putStrLn "Test suite not yet implemented" main = pure ()
go :: Eff '[State Int] Int
go = do
n <- send (Get id)
if n <= 0
then pure n
else do
send $ Put (n-1) ()
go
countDown :: Int -> Int
countDown start = fst $ run $ runState start go
inspect $ 'countDown `hasNoType` ''SNat