diff --git a/bench/bench.hs b/bench/bench.hs index 9f5d06b..6d1a5e1 100644 --- a/bench/bench.hs +++ b/bench/bench.hs @@ -1,26 +1,55 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} import Criterion.Main +import Grenade + import Grenade.Layers.Internal.Convolution import Grenade.Layers.Internal.Pooling import Numeric.LinearAlgebra main :: IO () -main = defaultMain [ - bgroup "im2col" [ bench "im2col 3x4" $ whnf (im2col 2 2 1 1) ((3><4) [1..]) - , bench "im2col 28x28" $ whnf (im2col 5 5 1 1) ((28><28) [1..]) - , bench "im2col 100x100" $ whnf (im2col 10 10 1 1) ((100><100) [1..]) - ] - , bgroup "col2im" [ bench "col2im 3x4" $ whnf (col2im 2 2 1 1 3 4) ((6><4) [1..]) - , bench "col2im 28x28" $ whnf (col2im 5 5 1 1 28 28) ((576><25) [1..]) - , bench "col2im 100x100" $ whnf (col2im 10 10 1 1 100 100) ((8281><100) [1..]) - ] - , bgroup "poolfw" [ bench "poolforwards 3x4" $ whnf (poolForward 1 3 4 2 2 1 1) ((3><4) [1..]) - , bench "poolforwards 28x28" $ whnf (poolForward 1 28 28 5 5 1 1) ((28><28) [1..]) - , bench "poolforwards 100x100" $ whnf (poolForward 1 100 100 10 10 1 1) ((100><100) [1..]) - ] - , bgroup "poolbw" [ bench "poolbackwards 3x4" $ whnf (poolBackward 1 3 4 2 2 1 1 ((3><4) [1..])) ((2><3) [1..]) - , bench "poolbackwards 28x28" $ whnf (poolBackward 1 28 28 5 5 1 1 ((28><28) [1..])) ((24><24) [1..]) - , bench "poolbackwards 100x100" $ whnf (poolBackward 1 100 100 10 10 1 1 ((100><100) [1..])) ((91><91) [1..]) - ] - ] +main = do + x :: S ('D2 60 60 ) <- randomOfShape + y :: S ('D3 60 60 1) <- randomOfShape + + defaultMain [ + bgroup "im2col" [ bench "im2col 3x4" $ whnf (im2col 2 2 1 1) ((3><4) [1..]) + , bench "im2col 28x28" $ whnf (im2col 5 5 1 1) ((28><28) [1..]) + , bench "im2col 100x100" $ whnf (im2col 10 10 1 1) ((100><100) [1..]) + ] + , bgroup "col2im" [ bench "col2im 3x4" $ whnf (col2im 2 2 1 1 3 4) ((6><4) [1..]) + , bench "col2im 28x28" $ whnf (col2im 5 5 1 1 28 28) ((576><25) [1..]) + , bench "col2im 100x100" $ whnf (col2im 10 10 1 1 100 100) ((8281><100) [1..]) + ] + , bgroup "poolfw" [ bench "poolforwards 3x4" $ whnf (poolForward 1 3 4 2 2 1 1) ((3><4) [1..]) + , bench "poolforwards 28x28" $ whnf (poolForward 1 28 28 5 5 1 1) ((28><28) [1..]) + , bench "poolforwards 100x100" $ whnf (poolForward 1 100 100 10 10 1 1) ((100><100) [1..]) + ] + , bgroup "poolbw" [ bench "poolbackwards 3x4" $ whnf (poolBackward 1 3 4 2 2 1 1 ((3><4) [1..])) ((2><3) [1..]) + , bench "poolbackwards 28x28" $ whnf (poolBackward 1 28 28 5 5 1 1 ((28><28) [1..])) ((24><24) [1..]) + , bench "poolbackwards 100x100" $ whnf (poolBackward 1 100 100 10 10 1 1 ((100><100) [1..])) ((91><91) [1..]) + ] + , bgroup "padcrop" [ bench "pad 2D 60x60" $ whnf (testRun2D Pad) x + , bench "pad 3D 60x60" $ whnf (testRun3D Pad) y + , bench "crop 2D 60x60" $ whnf (testRun2D' Crop) x + , bench "crop 3D 60x60" $ whnf (testRun3D' Crop) y + ] + ] + + +testRun2D :: Pad 1 1 1 1 -> S ('D2 60 60) -> S ('D2 62 62) +testRun2D = snd ... runForwards + +testRun3D :: Pad 1 1 1 1 -> S ('D3 60 60 1) -> S ('D3 62 62 1) +testRun3D = snd ... runForwards + +testRun2D' :: Crop 1 1 1 1 -> S ('D2 60 60) -> S ('D2 58 58) +testRun2D' = snd ... runForwards + +testRun3D' :: Crop 1 1 1 1 -> S ('D3 60 60 1) -> S ('D3 58 58 1) +testRun3D' = snd ... runForwards + +(...) :: (a -> b) -> (c -> d -> a) -> c -> d -> b +(...) = (.) . (.) diff --git a/cbits/pad.c b/cbits/pad.c index 0dde7cc..3fc7b1a 100644 --- a/cbits/pad.c +++ b/cbits/pad.c @@ -1,23 +1,21 @@ #include "pad.h" -void pad_cpu(const double* data, const int channels, +void pad_cpu(double* data, const int channels, const int height, const int width, const int pad_left, const int pad_top, const int pad_right, const int pad_bottom, double* data_padded) { const int pad_width = width + pad_left + pad_right; const int pad_height = height + pad_top + pad_bottom; - const int channel_size = height * width; memset(data_padded, 0, pad_height * pad_width * channels * sizeof(double)); for (int channel = 0; channel < channels; channel++) { double* px = data_padded + (pad_width * pad_top + pad_left) + channel * (pad_width * pad_height); for (int y = 0; y < height; y++) { - for (int x = 0; x < width; x++) { - *(px++) = data[y * width + x + channel_size * channel]; - } - px += pad_left + pad_right; + memcpy(px, data, sizeof(double) * width); + px += pad_width; + data += width; } } } @@ -30,15 +28,12 @@ void crop_cpu(double* data, const int channels, const int crop_width = width + crop_left + crop_right; const int crop_height = height + crop_top + crop_bottom; - const int channel_size = height * width; - for (int channel = 0; channel < channels; channel++) { double* px = data + (crop_width * crop_top + crop_left) + channel * (crop_width * crop_height); for (int y = 0; y < height; y++) { - for (int x = 0; x < width; x++) { - data_cropped[y * width + x + channel_size * channel] = *(px++); - } - px += crop_left + crop_right; + memcpy(data_cropped, px, sizeof(double) * width); + px += crop_width; + data_cropped += width; } } } diff --git a/cbits/pad.h b/cbits/pad.h index 434bf50..c16aa09 100644 --- a/cbits/pad.h +++ b/cbits/pad.h @@ -2,7 +2,7 @@ #include #include -void pad_cpu(const double* data_im, const int channels, +void pad_cpu(double* data_im, const int channels, const int height, const int width, const int pad_left, const int pad_top, const int pad_right, const int pad_bottom, double* data_col); diff --git a/grenade.cabal b/grenade.cabal index 12689f8..e89fe71 100644 --- a/grenade.cabal +++ b/grenade.cabal @@ -22,22 +22,19 @@ library build-depends: base >= 4.8 && < 5 , bytestring == 0.10.* - , async , containers , deepseq , either == 4.4.* , cereal , exceptions == 0.8.* - , hmatrix + , hmatrix == 0.18.* , MonadRandom , mtl >= 2.2.1 && < 2.3 - , parallel == 3.2.* , primitive - , reflection , text == 1.2.* , transformers - , singletons - , vector + , singletons >= 2.1 && < 2.3 + , vector == 0.11.* ghc-options: -Wall @@ -100,7 +97,7 @@ executable feedforward , bytestring , cereal , either - , optparse-applicative == 0.12.* + , optparse-applicative == 0.13.* , text == 1.2.* , mtl >= 2.2.1 && < 2.3 , hmatrix @@ -115,7 +112,7 @@ executable mnist , grenade , attoparsec , either - , optparse-applicative == 0.12.* + , optparse-applicative == 0.13.* , text == 1.2.* , mtl >= 2.2.1 && < 2.3 , hmatrix >= 0.18 && < 0.19 @@ -131,7 +128,7 @@ executable recurrent , grenade , attoparsec , either - , optparse-applicative == 0.12.* + , optparse-applicative == 0.13.* , text == 1.2.* , mtl >= 2.2.1 && < 2.3 , hmatrix >= 0.18 && < 0.19 @@ -149,7 +146,7 @@ executable shakespeare , bytestring , cereal , either - , optparse-applicative == 0.12.* + , optparse-applicative == 0.13.* , text == 1.2.* , mtl >= 2.2.1 && < 2.3 , hmatrix >= 0.18 && < 0.19 diff --git a/main/feedforward.hs b/main/feedforward.hs index 9b5a59d..eb93694 100644 --- a/main/feedforward.hs +++ b/main/feedforward.hs @@ -10,6 +10,7 @@ import Data.List ( foldl' ) import qualified Data.ByteString as B import Data.Serialize +import Data.Semigroup ( (<>) ) import GHC.TypeLits diff --git a/main/mnist.hs b/main/mnist.hs index 2815518..cbb2844 100644 --- a/main/mnist.hs +++ b/main/mnist.hs @@ -12,6 +12,7 @@ import Control.Monad.Trans.Except import qualified Data.Attoparsec.Text as A import Data.List ( foldl' ) +import Data.Semigroup ( (<>) ) import qualified Data.Text as T import qualified Data.Text.IO as T import qualified Data.Vector.Storable as V @@ -24,7 +25,6 @@ import Options.Applicative import Grenade import Grenade.Utils.OneHot - -- The definition of our convolutional neural network. -- In the type signature, we have a type level list of shapes which are passed between the layers. -- One can see that the images we are inputing are two dimensional with 28 * 28 pixels. diff --git a/main/recurrent.hs b/main/recurrent.hs index bdc86f4..57a0f3a 100644 --- a/main/recurrent.hs +++ b/main/recurrent.hs @@ -14,6 +14,7 @@ import Data.List ( unfoldr ) #else import Data.List ( cycle, unfoldr ) #endif +import Data.Semigroup ( (<>) ) import qualified Numeric.LinearAlgebra.Static as SA diff --git a/main/shakespeare.hs b/main/shakespeare.hs index 83b1a0d..8be214d 100644 --- a/main/shakespeare.hs +++ b/main/shakespeare.hs @@ -13,6 +13,7 @@ import Control.Monad.Trans.Except import Data.Char ( isUpper, toUpper, toLower ) import Data.List ( foldl' ) import Data.Maybe ( fromMaybe ) +import Data.Semigroup ( (<>) ) import qualified Data.Vector as V import Data.Vector ( Vector ) diff --git a/src/Grenade/Core/Layer.hs b/src/Grenade/Core/Layer.hs index 58b6e36..c43cb9a 100644 --- a/src/Grenade/Core/Layer.hs +++ b/src/Grenade/Core/Layer.hs @@ -18,7 +18,7 @@ module Grenade.Core.Layer ( , UpdateLayer (..) ) where -import Control.Monad.Random (MonadRandom) +import Control.Monad.Random ( MonadRandom ) import Data.List ( foldl' ) diff --git a/src/Grenade/Layers/Crop.hs b/src/Grenade/Layers/Crop.hs index a2abe4e..837b8a9 100644 --- a/src/Grenade/Layers/Crop.hs +++ b/src/Grenade/Layers/Crop.hs @@ -82,18 +82,30 @@ instance ( KnownNat cropLeft , (outputColumns + cropLeft + cropRight) ~ inputColumns ) => Layer (Crop cropLeft cropTop cropRight cropBottom) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) where type Tape (Crop cropLeft cropTop cropRight cropBottom) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) = () - runForwards Crop input = - let cropl = Proxy :: Proxy cropLeft - cropt = Proxy :: Proxy cropTop - cropr = Proxy :: Proxy cropRight - cropb = Proxy :: Proxy cropBottom - cropped = crop cropl cropt cropr cropb input - in ((), cropped) + runForwards Crop (S3D input) = + let padl = fromIntegral $ natVal (Proxy :: Proxy cropLeft) + padt = fromIntegral $ natVal (Proxy :: Proxy cropTop) + padr = fromIntegral $ natVal (Proxy :: Proxy cropRight) + padb = fromIntegral $ natVal (Proxy :: Proxy cropBottom) + inr = fromIntegral $ natVal (Proxy :: Proxy inputRows) + inc = fromIntegral $ natVal (Proxy :: Proxy inputColumns) + outr = fromIntegral $ natVal (Proxy :: Proxy outputRows) + outc = fromIntegral $ natVal (Proxy :: Proxy outputColumns) + ch = fromIntegral $ natVal (Proxy :: Proxy channels) + m = extract input + cropped = crop ch padl padt padr padb outr outc inr inc m + in ((), S3D . fromJust . create $ cropped) - runBackwards Crop () gradient = - let cropl = Proxy :: Proxy cropLeft - cropt = Proxy :: Proxy cropTop - cropr = Proxy :: Proxy cropRight - cropb = Proxy :: Proxy cropBottom - padded = pad cropl cropt cropr cropb gradient - in ((), padded) + runBackwards Crop () (S3D gradient) = + let padl = fromIntegral $ natVal (Proxy :: Proxy cropLeft) + padt = fromIntegral $ natVal (Proxy :: Proxy cropTop) + padr = fromIntegral $ natVal (Proxy :: Proxy cropRight) + padb = fromIntegral $ natVal (Proxy :: Proxy cropBottom) + inr = fromIntegral $ natVal (Proxy :: Proxy inputRows) + inc = fromIntegral $ natVal (Proxy :: Proxy inputColumns) + outr = fromIntegral $ natVal (Proxy :: Proxy outputRows) + outc = fromIntegral $ natVal (Proxy :: Proxy outputColumns) + ch = fromIntegral $ natVal (Proxy :: Proxy channels) + m = extract gradient + padded = pad ch padl padt padr padb outr outc inr inc m + in ((), S3D . fromJust . create $ padded) diff --git a/src/Grenade/Layers/Internal/Pad.hs b/src/Grenade/Layers/Internal/Pad.hs index f6368c6..e611098 100644 --- a/src/Grenade/Layers/Internal/Pad.hs +++ b/src/Grenade/Layers/Internal/Pad.hs @@ -1,123 +1,53 @@ -{-# LANGUAGE CPP #-} {-# LANGUAGE ForeignFunctionInterface #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE FlexibleContexts #-} - -#if __GLASGOW_HASKELL__ == 800 -{-# OPTIONS_GHC -fno-warn-redundant-constraints #-} -#endif module Grenade.Layers.Internal.Pad ( pad , crop ) where -import Data.Maybe ( fromJust ) -import Data.Proxy import qualified Data.Vector.Storable as U ( unsafeToForeignPtr0, unsafeFromForeignPtr0 ) -import GHC.TypeLits - -import Grenade.Core - import Foreign ( mallocForeignPtrArray, withForeignPtr ) import Foreign.Ptr ( Ptr ) -import Numeric.LinearAlgebra ( flatten ) -import Numeric.LinearAlgebra.Static ( extract ) +import Numeric.LinearAlgebra ( flatten, Matrix ) +import qualified Numeric.LinearAlgebra.Devel as U import System.IO.Unsafe ( unsafePerformIO ) -pad :: forall padLeft padTop padRight padBottom rows rows' cols cols' channels. - ( KnownNat padLeft - , KnownNat padTop - , KnownNat padRight - , KnownNat padBottom - , KnownNat rows - , KnownNat rows' - , KnownNat cols - , KnownNat cols' - , KnownNat channels - , rows' ~ (rows + padTop + padBottom) - , cols' ~ (cols + padLeft + padRight) - , KnownNat (rows' * channels) - ) => Proxy padLeft - -> Proxy padTop - -> Proxy padRight - -> Proxy padBottom - -> S ('D3 rows cols channels) - -> S ('D3 rows' cols' channels) -pad _ _ _ _ (S3D m) = - let channels = fromIntegral $ natVal (Proxy :: Proxy channels) - padLeft = fromIntegral $ natVal (Proxy :: Proxy padLeft) - padTop = fromIntegral $ natVal (Proxy :: Proxy padTop) - padRight = fromIntegral $ natVal (Proxy :: Proxy padRight) - padBottom = fromIntegral $ natVal (Proxy :: Proxy padBottom) - rows = fromIntegral $ natVal (Proxy :: Proxy rows) - cols = fromIntegral $ natVal (Proxy :: Proxy cols) - rows' = fromIntegral $ natVal (Proxy :: Proxy rows') - cols' = fromIntegral $ natVal (Proxy :: Proxy cols') - outMatSize = rows' * cols' * channels +pad :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double +pad channels padLeft padTop padRight padBottom rows cols rows' cols' m + = let outMatSize = rows' * cols' * channels + vec = flatten m + in unsafePerformIO $ do + outPtr <- mallocForeignPtrArray outMatSize + let (inPtr, _) = U.unsafeToForeignPtr0 vec - vec = flatten (extract m) - in unsafePerformIO $ do - outPtr <- mallocForeignPtrArray outMatSize - let (inPtr, _) = U.unsafeToForeignPtr0 vec + withForeignPtr inPtr $ \inPtr' -> + withForeignPtr outPtr $ \outPtr' -> + pad_cpu inPtr' channels rows cols padLeft padTop padRight padBottom outPtr' - withForeignPtr inPtr $ \inPtr' -> - withForeignPtr outPtr $ \outPtr' -> - pad_cpu inPtr' channels rows cols padLeft padTop padRight padBottom outPtr' - - let matVec = U.unsafeFromForeignPtr0 outPtr outMatSize - return (fromJust $ fromStorable matVec) + let matVec = U.unsafeFromForeignPtr0 outPtr outMatSize + return (U.matrixFromVector U.RowMajor (rows' * channels) cols' matVec) +{-# INLINE pad #-} foreign import ccall unsafe pad_cpu :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO () +crop :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double +crop channels padLeft padTop padRight padBottom rows cols _ _ m + = let outMatSize = rows * cols * channels + vec = flatten m + in unsafePerformIO $ do + outPtr <- mallocForeignPtrArray outMatSize + let (inPtr, _) = U.unsafeToForeignPtr0 vec -crop :: forall padLeft padTop padRight padBottom rows rows' cols cols' channels. - ( KnownNat padLeft - , KnownNat padTop - , KnownNat padRight - , KnownNat padBottom - , KnownNat rows - , KnownNat cols - , KnownNat cols' - , KnownNat channels - , rows' ~ (rows + padTop + padBottom) - , cols' ~ (cols + padLeft + padRight) - , KnownNat (rows * channels) - ) => Proxy padLeft - -> Proxy padTop - -> Proxy padRight - -> Proxy padBottom - -> S ('D3 rows' cols' channels) - -> S ('D3 rows cols channels) -crop _ _ _ _ (S3D m) = - let channels = fromIntegral $ natVal (Proxy :: Proxy channels) - padLeft = fromIntegral $ natVal (Proxy :: Proxy padLeft) - padTop = fromIntegral $ natVal (Proxy :: Proxy padTop) - padRight = fromIntegral $ natVal (Proxy :: Proxy padRight) - padBottom = fromIntegral $ natVal (Proxy :: Proxy padBottom) - rows = fromIntegral $ natVal (Proxy :: Proxy rows) - cols = fromIntegral $ natVal (Proxy :: Proxy cols) - outMatSize = rows * cols * channels + withForeignPtr inPtr $ \inPtr' -> + withForeignPtr outPtr $ \outPtr' -> + crop_cpu inPtr' channels rows cols padLeft padTop padRight padBottom outPtr' - vec = flatten (extract m) - in unsafePerformIO $ do - outPtr <- mallocForeignPtrArray outMatSize - let (inPtr, _) = U.unsafeToForeignPtr0 vec - - withForeignPtr inPtr $ \inPtr' -> - withForeignPtr outPtr $ \outPtr' -> - crop_cpu inPtr' channels rows cols padLeft padTop padRight padBottom outPtr' - - let matVec = U.unsafeFromForeignPtr0 outPtr outMatSize - return (fromJust $ fromStorable matVec) + let matVec = U.unsafeFromForeignPtr0 outPtr outMatSize + return (U.matrixFromVector U.RowMajor (rows * channels) cols matVec) foreign import ccall unsafe crop_cpu diff --git a/src/Grenade/Layers/Pad.hs b/src/Grenade/Layers/Pad.hs index bf5ad4c..5908ee9 100644 --- a/src/Grenade/Layers/Pad.hs +++ b/src/Grenade/Layers/Pad.hs @@ -70,7 +70,6 @@ instance ( KnownNat padLeft vs = subMatrix (padt, padl) (nrows, ncols) m in ((), S2D . fromJust . create $ vs) - -- | A two dimentional image can be padped. instance ( KnownNat padLeft , KnownNat padTop @@ -87,18 +86,30 @@ instance ( KnownNat padLeft , (inputColumns + padLeft + padRight) ~ outputColumns ) => Layer (Pad padLeft padTop padRight padBottom) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) where type Tape (Pad padLeft padTop padRight padBottom) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) = () - runForwards Pad input = - let padl = Proxy :: Proxy padLeft - padt = Proxy :: Proxy padTop - padr = Proxy :: Proxy padRight - padb = Proxy :: Proxy padBottom - padded = pad padl padt padr padb input - in ((), padded) + runForwards Pad (S3D input) = + let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft) + padt = fromIntegral $ natVal (Proxy :: Proxy padTop) + padr = fromIntegral $ natVal (Proxy :: Proxy padRight) + padb = fromIntegral $ natVal (Proxy :: Proxy padBottom) + outr = fromIntegral $ natVal (Proxy :: Proxy outputRows) + outc = fromIntegral $ natVal (Proxy :: Proxy outputColumns) + inr = fromIntegral $ natVal (Proxy :: Proxy inputRows) + inc = fromIntegral $ natVal (Proxy :: Proxy inputColumns) + ch = fromIntegral $ natVal (Proxy :: Proxy channels) + m = extract input + padded = pad ch padl padt padr padb inr inc outr outc m + in ((), S3D . fromJust . create $ padded) - runBackwards Pad () gradient = - let padl = Proxy :: Proxy padLeft - padt = Proxy :: Proxy padTop - padr = Proxy :: Proxy padRight - padb = Proxy :: Proxy padBottom - cropped = crop padl padt padr padb gradient - in ((), cropped) + runBackwards Pad () (S3D gradient) = + let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft) + padt = fromIntegral $ natVal (Proxy :: Proxy padTop) + padr = fromIntegral $ natVal (Proxy :: Proxy padRight) + padb = fromIntegral $ natVal (Proxy :: Proxy padBottom) + outr = fromIntegral $ natVal (Proxy :: Proxy outputRows) + outc = fromIntegral $ natVal (Proxy :: Proxy outputColumns) + inr = fromIntegral $ natVal (Proxy :: Proxy inputRows) + inc = fromIntegral $ natVal (Proxy :: Proxy inputColumns) + ch = fromIntegral $ natVal (Proxy :: Proxy channels) + m = extract gradient + cropped = crop ch padl padt padr padb inr inc outr outc m + in ((), S3D . fromJust . create $ cropped) diff --git a/test/Test/Jack/TypeLits.hs b/test/Test/Jack/TypeLits.hs index 385949b..73f08b6 100644 --- a/test/Test/Jack/TypeLits.hs +++ b/test/Test/Jack/TypeLits.hs @@ -29,23 +29,26 @@ genShape , genD2 , genD3 ] - where - genD1 = do - n <- genNat - return $ case n of - SomeNat (_ :: Proxy x) -> SomeSing (sing :: Sing ('D1 x)) - genD2 = do - n <- genNat - m <- genNat - return $ case (n, m) of - (SomeNat (_ :: Proxy x), SomeNat (_ :: Proxy y)) -> SomeSing (sing :: Sing ('D2 x y)) +genD1 :: Jack (SomeSing Shape) +genD1 = do + n <- genNat + return $ case n of + SomeNat (_ :: Proxy x) -> SomeSing (sing :: Sing ('D1 x)) - genD3 = do - n <- genNat - m <- genNat - o <- genNat - return $ case (n, m, o) of - (SomeNat (px :: Proxy x), SomeNat (_ :: Proxy y), SomeNat (pz :: Proxy z)) -> - case natDict px %* natDict pz of - Dict -> SomeSing (sing :: Sing ('D3 x y z)) +genD2 :: Jack (SomeSing Shape) +genD2 = do + n <- genNat + m <- genNat + return $ case (n, m) of + (SomeNat (_ :: Proxy x), SomeNat (_ :: Proxy y)) -> SomeSing (sing :: Sing ('D2 x y)) + +genD3 :: Jack (SomeSing Shape) +genD3 = do + n <- genNat + m <- genNat + o <- genNat + return $ case (n, m, o) of + (SomeNat (px :: Proxy x), SomeNat (_ :: Proxy y), SomeNat (pz :: Proxy z)) -> + case natDict px %* natDict pz of + Dict -> SomeSing (sing :: Sing ('D3 x y z))