Support lists in specifications.

This commit is contained in:
Maciej Bendkowski 2022-02-24 19:47:40 +01:00
parent a66aa915cd
commit 5eb9efea3e
8 changed files with 202 additions and 42 deletions

View File

@ -110,6 +110,33 @@ executable lambdaProfile
, vector >=0.12.3.1
default-language: Haskell2010
executable treeProfile
main-is: TreeProfile.hs
other-modules:
Tree
Paths_generic_boltzmann_brain
hs-source-dirs:
profile/Tree
default-extensions:
NumericUnderscores LambdaCase BangPatterns DerivingVia FlexibleInstances UndecidableInstances TypeApplications ScopedTypeVariables Rank2Types
ghc-options: -O2 -Wall -Wcompat -Wmissing-export-lists -Wincomplete-record-updates -Wincomplete-uni-patterns -Wredundant-constraints -Wno-name-shadowing -fwarn-missing-signatures -ddump-splices
build-depends:
QuickCheck >=2.14.2
, base >=4.7 && <5
, containers >=0.6.4
, generic-boltzmann-brain
, mtl >=2.2.2
, paganini-hs >=0.3.0.0
, random >=1.2.0
, splitmix >=0.1.0.4
, template-haskell >=2.17.0.0
, th-abstraction >=0.4.3.0
, th-lift >=0.8.2
, th-lift-instances >=0.1.18
, transformers >=0.5.6
, vector >=0.12.3.1
default-language: Haskell2010
test-suite generic-boltzmann-brain-test
type: exitcode-stdio-1.0
main-is: Spec.hs

View File

@ -82,6 +82,32 @@ executables:
TypeApplications
ScopedTypeVariables
Rank2Types
treeProfile:
main: TreeProfile.hs
source-dirs: profile/Tree
ghc-options:
- -O2
- -Wall
- -Wcompat
- -Wmissing-export-lists
- -Wincomplete-record-updates
- -Wincomplete-uni-patterns
- -Wredundant-constraints
- -Wno-name-shadowing
- -fwarn-missing-signatures
- -ddump-splices
dependencies:
- generic-boltzmann-brain
default-extensions:
NumericUnderscores
LambdaCase
BangPatterns
DerivingVia
FlexibleInstances
UndecidableInstances
TypeApplications
ScopedTypeVariables
Rank2Types
lambdaProfile:
main: LambdaProfile.hs
source-dirs: profile/Lambda

28
profile/Tree/Tree.hs Normal file
View File

@ -0,0 +1,28 @@
{-# LANGUAGE TemplateHaskell #-}
module Tree (Tree (..), randomTreeListIO) where
import Control.Monad (replicateM)
import Data.Boltzmann.Samplable (Distribution (..), Samplable (..))
import Data.Boltzmann.Sampler (BoltzmannSampler (..), rejectionSampler')
import Data.Boltzmann.System (System (..))
import Data.Boltzmann.System.TH (mkSystemBoltzmannSampler)
import Data.BuffonMachine (evalIO)
import System.Random.SplitMix (SMGen)
data Tree = T [Tree]
deriving (Show)
mkSystemBoltzmannSampler
System
{ targetType = ''Tree
, meanSize = 1000
, frequencies = []
, weights =
[ ('T, 1)
]
}
randomTreeListIO :: Int -> IO [Tree]
randomTreeListIO n =
evalIO $ replicateM n (rejectionSampler' @SMGen 1000 0.2)

View File

@ -0,0 +1,4 @@
import Tree (randomTreeListIO)
main :: IO ()
main = randomTreeListIO 100 >>= print

View File

@ -4,20 +4,20 @@ module Data.Boltzmann.Samplable.TH (mkSamplable) where
import Control.Monad (forM, void)
import Data.Boltzmann.System (
Distributions (..),
System (..),
Types (..),
collectTypes,
paganiniSpecIO,
)
import qualified Data.Map as Map
import Data.Map.Strict (Map)
import Data.Set (Set)
import qualified Data.Set as Set
import Control.Monad (forM_, unless)
import Data.Boltzmann.Samplable (Distribution)
import Language.Haskell.TH (Q, runIO)
import Language.Haskell.TH.Datatype (DatatypeInfo)
import Language.Haskell.TH.Syntax (
Body (NormalB),
Clause (Clause),
@ -27,7 +27,7 @@ import Language.Haskell.TH.Syntax (
Phases (AllPhases),
Pragma (InlineP),
RuleMatch (FunLike),
Type (AppT, ConT),
Type (AppT, ConT, ListT),
mkName,
)
@ -39,11 +39,11 @@ import Language.Haskell.TH.Datatype (
sysDistributions ::
System ->
Map Name DatatypeInfo ->
IO (Map Name (Distribution a))
Types ->
IO (Distributions a)
sysDistributions sys types = do
spec <- paganiniSpecIO sys types
return $ case spec of
pure $ case spec of
Left err -> error (show err)
Right x -> x
@ -56,8 +56,8 @@ hasAdmissibleFrequencies sys = do
constructorNames :: System -> Q (Set Name)
constructorNames sys = do
types <- collectTypes sys
foldMap constructorNames' (Map.keysSet types)
Types regTypes _ <- collectTypes sys
foldMap constructorNames' (Map.keysSet regTypes)
constructorNames' :: Name -> Q (Set Name)
constructorNames' typ = do
@ -70,9 +70,9 @@ mkSamplable sys = do
void $ hasAdmissibleFrequencies sys
types <- collectTypes sys
distrMap <- runIO $ sysDistributions sys types
Distributions regTypeDdgs listTypeDdgs <- runIO $ sysDistributions sys types
forM (Map.toList distrMap) $ \(typ, d) -> do
ts <- forM (Map.toList regTypeDdgs) $ \(typ, d) -> do
distribution <- [|d|]
let cls = AppT (ConT $ mkName "Samplable") (ConT typ)
constrName = mkName "constrDistribution"
@ -83,3 +83,17 @@ mkSamplable sys = do
[Clause [] (NormalB distribution) []]
, PragmaD $ InlineP constrName Inline FunLike AllPhases
]
ls <- forM (Map.toList listTypeDdgs) $ \(typ, d) -> do
distribution <- [|d|]
let cls = AppT (ConT $ mkName "Samplable") (AppT ListT $ ConT typ)
constrName = mkName "constrDistribution"
pure $
InstanceD Nothing [] cls $
[ FunD
constrName
[Clause [] (NormalB distribution) []]
, PragmaD $ InlineP constrName Inline FunLike AllPhases
]
pure $ ts <> ls

View File

@ -12,17 +12,32 @@ module Data.Boltzmann.Sampler (
hoistBoltzmannSampler,
) where
import Control.Monad (guard)
import Control.Monad.Trans.Maybe (MaybeT, runMaybeT)
import Control.Monad.Trans (lift)
import Data.BuffonMachine (BuffonMachine, eval)
import System.Random (RandomGen)
import Test.QuickCheck (Gen)
import Test.QuickCheck.Gen (Gen (MkGen))
import Test.QuickCheck.Random (QCGen (QCGen))
import Data.Boltzmann.Samplable (Samplable (..), Distribution, choice)
-- | Multiparametric Boltzmann samplers.
class BoltzmannSampler a where
sample :: RandomGen g => Int -> MaybeT (BuffonMachine g) (a, Int)
instance (Samplable [a], BoltzmannSampler a) => BoltzmannSampler [a] where
sample !ub = do
guard (ub > 0)
(lift $ choice (constrDistribution :: Distribution [a]))
>>= ( \case
0 -> pure ([], 0)
_ -> do
(x, w) <- sample ub
(xs, ws) <- sample (ub - w)
pure (x : xs, w + ws))
rejectionSampler ::
(RandomGen g, BoltzmannSampler a) => Int -> Int -> BuffonMachine g a
rejectionSampler lb ub = do

View File

@ -15,6 +15,7 @@ import Data.Boltzmann.Samplable (Samplable (constrDistribution), choice)
import Data.Boltzmann.Sampler (sample)
import Data.Boltzmann.System (
System,
Types (..),
collectTypes,
getWeight,
)
@ -196,8 +197,8 @@ mkBoltzmannSampler' sys typ = do
mkBoltzmannSampler :: System -> Q [Dec]
mkBoltzmannSampler sys = do
types <- collectTypes sys
decls <- forM (Map.toList types) $ \(typ, _) -> do
Types regTypes _ <- collectTypes sys
decls <- forM (Map.toList regTypes) $ \(typ, _) -> do
mkBoltzmannSampler' sys typ
pure $ concat decls

View File

@ -1,18 +1,21 @@
module Data.Boltzmann.System (
Types (..),
Distributions (..),
collectTypes,
System (..),
getWeight,
paganiniSpecIO,
) where
import Language.Haskell.TH.Syntax (Name, Type (ConT))
import Language.Haskell.TH.Syntax (Name, Type (AppT, ConT, ListT))
import Control.Monad (foldM, replicateM)
import Control.Monad (foldM, forM, replicateM)
import Data.Boltzmann.Samplable (Distribution (Distribution))
import qualified Data.Map as Map
import Data.Map.Strict (Map)
import Data.Maybe (fromJust, fromMaybe)
import Data.Paganini (
Def (Def),
Expr,
FromVariable,
Let (Let),
@ -20,6 +23,7 @@ import Data.Paganini (
Spec,
ddg,
debugPaganini,
seq,
tune,
variable,
variable',
@ -35,6 +39,8 @@ import Language.Haskell.TH.Datatype (
reifyDatatype,
)
import Prelude hiding (seq)
data System = System
{ targetType :: Name
, meanSize :: Int
@ -47,34 +53,50 @@ getWeight :: System -> Name -> Int
getWeight sys name =
fromMaybe 1 $ lookup name (weights sys)
collectTypes :: System -> Q (Map Name DatatypeInfo)
data Types = Types
{ regTypes :: Map Name DatatypeInfo
, listTypes :: Set Name
}
initTypes :: Types
initTypes = Types Map.empty Set.empty
collectTypes :: System -> Q Types
collectTypes sys = do
info <- reifyDatatype $ targetType sys
collectFromDataTypeInfo Map.empty info
collectFromDataTypeInfo initTypes info
collectFromDataTypeInfo ::
Map Name DatatypeInfo ->
Types ->
DatatypeInfo ->
Q (Map Name DatatypeInfo)
Q Types
collectFromDataTypeInfo types info =
case name `Map.lookup` types of
case name `Map.lookup` (regTypes types) of
Just _ -> pure types
Nothing -> foldM collectTypesFromCons types' (datatypeCons info)
where
types' = Map.insert name info types
types' = types {regTypes = regTypes'}
regTypes' = Map.insert name info (regTypes types)
name = datatypeName info
collectTypesFromCons ::
Map Name DatatypeInfo ->
Types ->
ConstructorInfo ->
Q (Map Name DatatypeInfo)
Q Types
collectTypesFromCons types consInfo =
foldM collectFromType types (constructorFields consInfo)
collectFromType :: Map Name DatatypeInfo -> Type -> Q (Map Name DatatypeInfo)
collectFromType :: Types -> Type -> Q Types
collectFromType types typ =
case typ of
ConT t -> reifyDatatype t >>= collectFromDataTypeInfo types
AppT ListT (ConT t) -> do
info <- reifyDatatype t
let types' = types {listTypes = listTypes'}
listTypes' = Set.insert (datatypeName info) (listTypes types)
collectFromDataTypeInfo types' info
_ -> fail $ "Unsupported type " ++ show typ
mkVariables :: Set Name -> Spec (Map Name Let)
@ -84,6 +106,15 @@ mkVariables sys = do
let sys' = Set.toList sys
pure (Map.fromList $ sys' `zip` xs)
mkListVariables :: Set Name -> Map Name Let -> Spec (Map Name Def)
mkListVariables listTypes varDefs = do
ps <- forM (Set.toList listTypes) $ \lt -> do
let (Let v) = varDefs Map.! lt
s <- seq v
pure (lt, s)
pure $ Map.fromList ps
mkMarkingVariables :: System -> Spec (Map Name Let)
mkMarkingVariables sys = do
xs <-
@ -99,6 +130,7 @@ mkMarkingVariables sys = do
data Params = Params
{ sizeVar :: forall a. FromVariable a => a
, typeVariable :: Map Name Let
, listVariable :: Map Name Def
, markingVariable :: Map Name Let
, system :: System
}
@ -131,44 +163,57 @@ argExpr :: Params -> Type -> Expr
argExpr params typ =
case typ of
ConT t -> let Let x = typeVariable params Map.! t in x
AppT ListT (ConT t) -> let Def x = listVariable params Map.! t in x
_ -> error $ "Absurd type " ++ show typ
defaults :: (Num p, FromVariable p) => Maybe Let -> p
defaults Nothing = 1
defaults (Just (Let x)) = x
mkDidtributions :: Params -> Spec (Map Name (Distribution a))
mkDidtributions params = do
let typeList = Map.toList $ typeVariable params
ddgs <-
mapM
( \(n, x) -> do
ddgTree <- ddg x
return (n, Distribution $ fromList $ fromJust ddgTree)
)
typeList
data Distributions a = Distributions
{ regTypeDdgs :: Map Name (Distribution a)
, listTypeDdgs :: Map Name (Distribution a)
}
return $ Map.fromList ddgs
mkDidtributions :: Params -> Spec (Distributions a)
mkDidtributions params = do
let mkDistribution = Distribution . fromList . fromJust
regDdgs <- forM (Map.toList $ typeVariable params) $ \(name, v) -> do
ddgTree <- ddg v
pure (name, mkDistribution ddgTree)
listDdgs <- forM (Map.toList $ listVariable params) $ \(name, v) -> do
ddgTree <- ddg v
pure (name, mkDistribution ddgTree)
pure $
Distributions
{ regTypeDdgs = Map.fromList regDdgs
, listTypeDdgs = Map.fromList listDdgs
}
paganiniSpec ::
System ->
Map Name DatatypeInfo ->
Spec (Map Name (Distribution a))
paganiniSpec sys types = do
Types ->
Spec (Distributions a)
paganiniSpec sys (Types regTypes listTypes) = do
let n = meanSize sys
Let z <- variable' $ fromIntegral n
varDefs <- mkVariables (Map.keysSet types)
varDefs <- mkVariables (Map.keysSet regTypes)
listDefs <- mkListVariables listTypes varDefs
markDefs <- mkMarkingVariables sys
let params =
Params
{ sizeVar = z
, typeVariable = varDefs
, listVariable = listDefs
, markingVariable = markDefs
, system = sys
}
mkTypeVariables params types
mkTypeVariables params regTypes
let (Let t) = varDefs Map.! targetType sys
tune t -- tune for target variable.
@ -176,6 +221,6 @@ paganiniSpec sys types = do
paganiniSpecIO ::
System ->
Map Name DatatypeInfo ->
IO (Either PaganiniError (Map Name (Distribution a)))
Types ->
IO (Either PaganiniError (Distributions a))
paganiniSpecIO sys types = debugPaganini $ paganiniSpec sys types