From a8b1f247aa9ccca09a48720386eefc467e0ce84a Mon Sep 17 00:00:00 2001 From: Sandy Maguire Date: Sun, 17 Mar 2019 14:09:09 -0400 Subject: [PATCH] typesafe union --- bench/countDown.hs | 4 +- package.yaml | 2 + src/Control/Monad/Discount.hs | 20 ++-- src/Data/OpenUnion.hs | 169 ++++++++++++++++++++++------------ src/TRYAGAIN.hs | 3 + src/Wtf.hs | 3 +- test/Spec.hs | 26 +++++- 7 files changed, 156 insertions(+), 71 deletions(-) diff --git a/bench/countDown.hs b/bench/countDown.hs index 2a49957..d1db876 100644 --- a/bench/countDown.hs +++ b/bench/countDown.hs @@ -126,8 +126,8 @@ main :: IO () main = defaultMain [ bgroup "Countdown Bench" [ - bench "faster" $ whnf TFTF.countDownFast 10000 - , bench "discount" $ whnf TFTF.countDown 10000 + -- bench "faster" $ whnf TFTF.countDownFast 10000 + bench "discount" $ whnf TFTF.countDown 10000 , bench "freer-simple" $ whnf countDown 10000 , bench "mtl" $ whnf countDownMTL 10000 ] diff --git a/package.yaml b/package.yaml index 97132ab..1f9d21f 100644 --- a/package.yaml +++ b/package.yaml @@ -50,6 +50,8 @@ tests: - -with-rtsopts=-N dependencies: - too-fast-too-free + - inspection-testing + - hspec benchmarks: too-fast-too-free-bench: diff --git a/src/Control/Monad/Discount.hs b/src/Control/Monad/Discount.hs index 1624c53..75e9975 100644 --- a/src/Control/Monad/Discount.hs +++ b/src/Control/Monad/Discount.hs @@ -1,9 +1,12 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE MonoLocalBinds #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UnicodeSyntax #-} +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE MonoLocalBinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnicodeSyntax #-} module Control.Monad.Discount ( module Control.Monad.Discount @@ -60,6 +63,7 @@ liftEff :: Union r (Eff r) a -> Eff r a liftEff u = F $ \kp kf -> kf $ fmap kp u {-# INLINE liftEff #-} + raise :: Eff r a -> Eff (e ': r) a raise = runEff pure $ join . liftEff . hoist raise . weaken {-# INLINE raise #-} @@ -129,11 +133,11 @@ runM e = runF e pure $ join . unLift . extract run :: Eff '[] a -> a -run = runEff id absurdU +run = runEff id $ error "lol" {-# INLINE run #-} -send :: Member eff r => eff (Eff r) a -> Eff r a +send :: Member e r => e (Eff r) a -> Eff r a send = liftEff . inj {-# INLINE send #-} diff --git a/src/Data/OpenUnion.hs b/src/Data/OpenUnion.hs index 089fd67..761bfb4 100644 --- a/src/Data/OpenUnion.hs +++ b/src/Data/OpenUnion.hs @@ -1,91 +1,142 @@ {-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -Wall #-} module Data.OpenUnion where import Control.Monad.Discount.Effect -import Unsafe.Coerce +import Data.Typeable + +data Dict c where Dict :: c => Dict c + + +data Nat = Z | S Nat + deriving Typeable + +data SNat :: Nat -> * where + SZ :: SNat 'Z + SS :: Typeable n => SNat n -> SNat ('S n) + deriving Typeable + + +type family IndexOf (ts :: [k]) (n :: Nat) :: k where + IndexOf (k ': ks) 'Z = k + IndexOf (k ': ks) ('S n) = IndexOf ks n + + +type family Found (ts :: [k]) (t :: k) :: Nat where + Found (t ': ts) t = 'Z + Found (u ': ts) t = 'S (Found ts t) + + +class Typeable (Found r t) => Find (r :: [k]) (t :: k) where + finder :: SNat (Found r t) + +instance {-# OVERLAPPING #-} Find (t ': z) t where + finder = SZ + {-# INLINE finder #-} + +instance ( Find z t + , Found (_1 ': z) t ~ 'S (Found z t) + ) => Find (_1 ': z) t where + finder = SS $ finder @_ @z @t + {-# INLINE finder #-} + + data Union (r :: [(* -> *) -> * -> *]) (m :: * -> *) a where - Union :: Effect e => Word -> e m a -> Union r m a + Union + :: Effect (IndexOf r n) + => SNat n + -> IndexOf r n m a + -> Union r m a + + +decomp :: Union (e ': r) m a -> Either (Union r m a) (e m a) +decomp (Union p a) = + case p of + SZ -> Right a + SS n -> Left $ Union n a +{-# INLINE decomp #-} extract :: Union '[e] m a -> e m a -extract (Union _ a) = unsafeCoerce a +extract (Union SZ a) = a +extract _ = error "impossible" {-# INLINE extract #-} -absurdU :: Union '[] m a -> b -absurdU = error "absurd, empty union" - - -unsafeInj :: Effect e => Word -> e m a -> Union r m a -unsafeInj w = Union w -{-# INLINE unsafeInj #-} - - -unsafePrj :: Word -> Union r m a -> Maybe (t m a) -unsafePrj n (Union n' x) - | n == n' = Just (unsafeCoerce x) - | otherwise = Nothing -{-# INLINE unsafePrj #-} - -newtype P t r = P {unP :: Word} - - -class FindElem (t :: k) (r :: [k]) where - elemNo :: P t r - - -instance FindElem t (t ': r) where - elemNo = P 0 - {-# INLINE elemNo #-} - - -instance {-# OVERLAPPABLE #-} FindElem t r => FindElem t (t' ': r) where - elemNo = P $ 1 + unP (elemNo :: P t r) - {-# INLINE elemNo #-} - - -class (FindElem e r, Effect e) - => Member (e :: (* -> *) -> * -> *) - (r :: [(* -> *) -> * -> *]) where - inj :: Monad m => e m a -> Union r m a - prj :: Union r m a -> Maybe (e m a) - - -instance (Effect t, FindElem t r) => Member t r where - inj = unsafeInj $ unP (elemNo :: P t r) - {-# INLINE inj #-} - - prj = unsafePrj $ unP (elemNo :: P t r) - {-# INLINE prj #-} - - -decomp :: Union (t ': r) m a -> Either (Union r m a) (e m a) -decomp (Union 0 a) = Right $ unsafeCoerce a -decomp (Union n a) = Left $ Union (n - 1) a -{-# INLINE [2] decomp #-} - - weaken :: Union r m a -> Union (e ': r) m a -weaken (Union n a) = Union (n + 1) a +weaken (Union n a) = + case induceTypeable n of + Dict -> Union (SS n) a +{-# INLINE weaken #-} + + +inj + :: forall r e a m + . ( Functor m + , Find r e + , Effect e + , e ~ IndexOf r (Found r e) + ) + => e m a + -> Union r m a +inj e = Union (finder @_ @r @e) e +{-# INLINE inj #-} + + +induceTypeable :: SNat n -> Dict (Typeable n) +induceTypeable SZ = Dict +induceTypeable (SS _) = Dict +{-# INLINE induceTypeable #-} + + +type Member e r = (Find r e, e ~ IndexOf r (Found r e), Effect e) + + + +prj + :: forall r e a m + . ( Find r e + , e ~ IndexOf r (Found r e) + ) + => Union r m a + -> Maybe (e m a) +prj (Union (s :: SNat n) a) = + case induceTypeable s of + Dict -> + case eqT @n @(Found r e) of + Just Refl -> Just a + Nothing -> Nothing +{-# INLINE prj #-} instance Effect (Union r) where weave s f (Union w e) = Union w $ weave s f e {-# INLINE weave #-} -instance Functor m => Functor (Union r m) where - fmap f (Union w t) = Union w $ fmap f t + +instance (Functor m) => Functor (Union r m) where + fmap f (Union w t) = Union w $ fmap' f t + where + fmap' :: (Functor m, Effect f) => (a -> b) -> f m a -> f m b + fmap' = fmap + {-# INLINE fmap' #-} {-# INLINE fmap #-} diff --git a/src/TRYAGAIN.hs b/src/TRYAGAIN.hs index 4a30b27..a065eef 100644 --- a/src/TRYAGAIN.hs +++ b/src/TRYAGAIN.hs @@ -31,10 +31,12 @@ data State s m a get :: Member (State s) r => Eff r s get = send $ Get id +{-# INLINE get #-} put :: Member (State s) r => s -> Eff r () put s = send $ Put s () +{-# INLINE put #-} data Error e m a @@ -85,6 +87,7 @@ runRelayS pure' bind' = flip go $ fmap (uncurry (flip id)) $ weave (s', ()) (uncurry $ flip go) x Right eff -> bind' eff +{-# INLINE runRelayS #-} runError :: Eff (Error e ': r) a -> Eff r (Either e a) diff --git a/src/Wtf.hs b/src/Wtf.hs index fd32b85..3769fb6 100644 --- a/src/Wtf.hs +++ b/src/Wtf.hs @@ -3,10 +3,11 @@ module Wtf where +import Control.Monad.Discount import TRYAGAIN import Data.Functor.Identity -go :: Eff '[State Int, Lift (Identity)] Int +go :: Eff '[State Int] Int go = do n <- send (Get id) if n <= 0 diff --git a/test/Spec.hs b/test/Spec.hs index cd4753f..bdd08ed 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -1,2 +1,26 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TemplateHaskell #-} + +import Test.Inspection +import Control.Monad.Discount +import Data.OpenUnion +import TRYAGAIN hiding (main) + + main :: IO () -main = putStrLn "Test suite not yet implemented" +main = pure () + +go :: Eff '[State Int] Int +go = do + n <- send (Get id) + if n <= 0 + then pure n + else do + send $ Put (n-1) () + go + +countDown :: Int -> Int +countDown start = fst $ run $ runState start go + +inspect $ 'countDown `hasNoType` ''SNat +