Restrict type-level nats to [0,37]. (#3139)

* Restrict type-level nats to [0,37]

* Update compiler/daml-lf-ast/src/DA/Daml/LF/Ast/TypeLevelNat.hs

Co-Authored-By: Martin Huschenbett <martin.huschenbett@posteo.me>
This commit is contained in:
associahedron 2019-10-10 13:20:24 +01:00 committed by mergify[bot]
parent a125860a9d
commit 172996e4db
12 changed files with 122 additions and 32 deletions

View File

@ -7,6 +7,7 @@ module DA.Daml.LF.Ast
) where
import DA.Daml.LF.Ast.Base as LF
import DA.Daml.LF.Ast.TypeLevelNat as LF
import DA.Daml.LF.Ast.Util as LF
import DA.Daml.LF.Ast.Version as LF
import DA.Daml.LF.Ast.World as LF

View File

@ -13,7 +13,6 @@ module DA.Daml.LF.Ast.Base(
import Data.Hashable
import Data.Data
import Numeric.Natural
import GHC.Generics(Generic)
import Data.Int
import Control.DeepSeq
@ -26,6 +25,7 @@ import qualified Control.Lens.TH as Lens.TH
import DA.Daml.LF.Ast.Version
import DA.Daml.LF.Ast.Numeric
import DA.Daml.LF.Ast.TypeLevelNat
infixr 1 `KArrow`
@ -179,7 +179,7 @@ data Type
-- fields and their types.
| TTuple ![(FieldName, Type)]
-- | Type-level natural numbers
| TNat !Natural
| TNat !TypeLevelNat
deriving (Eq, Data, Generic, NFData, Ord, Show)
-- | Fully applied qualified type constructor.

View File

@ -18,7 +18,6 @@ module DA.Daml.LF.Ast.Optics(
builtinType
) where
import Numeric.Natural
import Control.Lens
import Control.Lens.Ast
import Control.Lens.MonoTraversal
@ -26,6 +25,7 @@ import Data.Functor.Foldable (cata, embed)
import qualified Data.NameMap as NM
import DA.Daml.LF.Ast.Base
import DA.Daml.LF.Ast.TypeLevelNat
import DA.Daml.LF.Ast.Recursive
import DA.Daml.LF.Ast.Version (Version)
@ -140,7 +140,7 @@ instance MonoTraversable ModuleRef BuiltinExpr where monoTraverse _ = pure
-- discussion
instance MonoTraversable ModuleRef SourceLoc where monoTraverse _ = pure
instance MonoTraversable ModuleRef Natural where monoTraverse _ = pure
instance MonoTraversable ModuleRef TypeLevelNat where monoTraverse _ = pure
instance MonoTraversable ModuleRef TypeConApp
instance MonoTraversable ModuleRef Type

View File

@ -19,6 +19,7 @@ import qualified Data.Time.Format as Time.Format
import Data.Foldable (toList)
import DA.Daml.LF.Ast.Base hiding (dataCons)
import DA.Daml.LF.Ast.TypeLevelNat
import DA.Daml.LF.Ast.Util
import DA.Daml.LF.Ast.Optics
import DA.Pretty hiding (keyword_, type_)
@ -166,7 +167,7 @@ instance Pretty Type where
(prettyForall <-> hsep (map (prettyAndKind lvl) vs) <> "."
<-> pPrintPrec lvl precTForall t1)
TTuple fields -> prettyTuple lvl prettyHasType fields
TNat n -> integer (fromIntegral n)
TNat n -> integer (fromTypeLevelNat n)
precEApp, precEAbs :: Rational
precEApp = 2

View File

@ -0,0 +1,76 @@
-- Copyright (c) 2019 The DAML Authors. All rights reserved.
-- SPDX-License-Identifier: Apache-2.0
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE PatternSynonyms #-}
-- | Representation of DAML-LF type-level naturals.
module DA.Daml.LF.Ast.TypeLevelNat
( TypeLevelNat
, TypeLevelNatError (..)
, pattern TypeLevelNat10
, fromTypeLevelNat
, typeLevelNatE
, typeLevelNat
) where
import Control.DeepSeq
import Data.Data
import Data.Hashable
import Data.Maybe
import Numeric.Natural
import GHC.Generics (Generic)
-- | A type-level natural. For now these are restricted to being between
-- 0 and 37 (inclusive). We do not expose the constructor of this type
-- to prevent the construction of values outside of that bound.
newtype TypeLevelNat
= TypeLevelNat { unTypeLevelNat :: Int }
deriving newtype (Eq, NFData, Ord, Show, Hashable)
deriving (Data, Generic)
data TypeLevelNatError
= TLNEOutOfBounds
deriving (Eq, Ord, Show)
instance Bounded TypeLevelNat where
minBound = TypeLevelNat 0
maxBound = TypeLevelNat 37
fromTypeLevelNat :: Num b => TypeLevelNat -> b
fromTypeLevelNat = fromIntegral . unTypeLevelNat
-- | Construct a type-level natural in a safe way.
typeLevelNatE :: Integral a => a -> Either TypeLevelNatError TypeLevelNat
typeLevelNatE n'
| n < fromTypeLevelNat minBound || n > fromTypeLevelNat maxBound = Left TLNEOutOfBounds
| otherwise = Right $ TypeLevelNat (fromIntegral n)
where
n = fromIntegral n' :: Integer
-- | Construct a type-level natural. Raises an error if the number is out of bounds.
typeLevelNat :: Integral a => a -> TypeLevelNat
typeLevelNat m =
case typeLevelNatE m of
Left TLNEOutOfBounds -> error . concat $
[ "type-level nat is out of bounds: "
, show (fromIntegral m :: Integer)
, " not in [0, "
, show (maxBound @TypeLevelNat)
, "]"
]
Right n -> n
pattern TypeLevelNat10 :: TypeLevelNat
pattern TypeLevelNat10 = TypeLevelNat 10
instance Read TypeLevelNat where
readsPrec p = mapMaybe postProcess . readsPrec p
where
postProcess :: (Natural, String) -> Maybe (TypeLevelNat, String)
postProcess (m, xs) =
case typeLevelNatE m of
Left _ -> Nothing
Right n -> Just (n, xs)

View File

@ -15,6 +15,7 @@ import Data.List.Extra (nubSort)
import qualified Data.NameMap as NM
import DA.Daml.LF.Ast.Base
import DA.Daml.LF.Ast.TypeLevelNat
import DA.Daml.LF.Ast.Optics
import DA.Daml.LF.Ast.Recursive
@ -153,12 +154,13 @@ infixr 1 :->
pattern (:->) :: Type -> Type -> Type
pattern a :-> b = TArrow `TApp` a `TApp` b
pattern TUnit, TBool, TInt64, TDecimal, TText, TTimestamp, TParty, TDate, TArrow, TNumeric10, TAny :: Type
pattern TUnit, TBool, TInt64, TDecimal, TText, TTimestamp, TParty, TDate, TArrow, TNumeric10, TAny, TNat10 :: Type
pattern TUnit = TBuiltin BTUnit
pattern TBool = TBuiltin BTBool
pattern TInt64 = TBuiltin BTInt64
pattern TDecimal = TBuiltin BTDecimal -- legacy decimal (LF version <= 1.6)
pattern TNumeric10 = TNumeric (TNat 10) -- new decimal
pattern TNumeric10 = TNumeric TNat10 -- new decimal
pattern TNat10 = TNat TypeLevelNat10
pattern TText = TBuiltin BTText
pattern TTimestamp = TBuiltin BTTimestamp
pattern TParty = TBuiltin BTParty

View File

@ -606,6 +606,7 @@ decodeNumericLit (T.unpack -> str) = case readMaybe str of
Nothing -> throwError $ ParseError $ "bad Numeric literal: " ++ show str
Just n -> pure $ BENumeric n
decodeKind :: LF1.Kind -> Decode Kind
decodeKind LF1.Kind{..} = mayDecode "kindSum" kindSum $ \case
LF1.KindSumStar LF1.Unit -> pure KStar
@ -634,13 +635,19 @@ decodePrim = pure . \case
LF1.PrimTypeARROW -> BTArrow
LF1.PrimTypeANY -> BTAny
decodeTypeLevelNat :: Integer -> Decode TypeLevelNat
decodeTypeLevelNat m =
case typeLevelNatE m of
Left TLNEOutOfBounds ->
throwError $ ParseError $ "bad type-level nat: " <> show m <> " is out of bounds"
Right n ->
pure n
decodeType :: LF1.Type -> Decode Type
decodeType LF1.Type{..} = mayDecode "typeSum" typeSum $ \case
LF1.TypeSumVar (LF1.Type_Var var args) ->
decodeWithArgs args $ TVar <$> decodeName TypeVarName var
LF1.TypeSumNat n ->
pure $ TNat (fromIntegral n)
-- TODO (#2289): determine if some bounds check should be made here.
LF1.TypeSumNat n -> TNat <$> decodeTypeLevelNat (fromIntegral n)
LF1.TypeSumCon (LF1.Type_Con mbCon args) ->
decodeWithArgs args $ TCon <$> mayDecode "type_ConTycon" mbCon decodeTypeConName
LF1.TypeSumPrim (LF1.Type_Prim (Proto.Enumerated (Right prim)) args) -> do

View File

@ -223,8 +223,7 @@ encodeType' typ = fmap (P.Type . Just) $ case typ ^. _TApps of
pure $ P.TypeSumTuple P.Type_Tuple{..}
(TNat n, _) ->
pure $ P.TypeSumNat (fromIntegral n)
-- TODO (#2289): determine if some bounds check should be made here
pure $ P.TypeSumNat (fromTypeLevelNat n)
(TApp{}, _) -> error "TApp after unwinding TApp"
-- NOTE(MH): The following case is ill-kinded.

View File

@ -145,7 +145,7 @@ typeOfBuiltin :: MonadGamma m => BuiltinExpr -> m Type
typeOfBuiltin = \case
BEInt64 _ -> pure TInt64
BEDecimal _ -> pure TDecimal
BENumeric n -> pure (TNumeric (TNat (numericScale n)))
BENumeric n -> pure (TNumeric (TNat (typeLevelNat (numericScale n))))
BEText _ -> pure TText
BETimestamp _ -> pure TTimestamp
BEParty _ -> pure TParty

View File

@ -55,8 +55,8 @@ serializabilityConditionsType world0 _version mbModNameTpls vars = go
TOptional typ -> go typ
TMap typ -> go typ
TNumeric (TNat n)
| n <= numericMaxScale -> noConditions
| otherwise -> Left (URNumericOutOfRange n)
| fromTypeLevelNat n <= numericMaxScale -> noConditions
| otherwise -> Left (URNumericOutOfRange (fromTypeLevelNat n))
TNumeric _ -> Left URNumericNotFixed
-- We statically enforce bounds check for Numeric type,
-- requiring 0 <= n <= 'numericMaxScale' for the argument

View File

@ -1192,7 +1192,7 @@ convertTyCon env t
"Numeric" -> pure (TBuiltin BTNumeric)
"Decimal" ->
if envLfVersion env `supports` featureNumeric
then pure (TNumeric (TNat 10))
then pure TNumeric10
else pure TDecimal
_ -> defaultTyCon
-- TODO(DEL-6953): We need to add a condition on the package name as well.
@ -1246,8 +1246,12 @@ convertType env t | Just t' <- getTyVar_maybe t
= TVar . fst <$> convTypeVar t'
convertType env t | Just s <- isStrLitTy t
= pure TUnit
convertType env t | Just n <- isNumLitTy t, n >= 0
= pure (TNat (fromIntegral n))
convertType env t | Just m <- isNumLitTy t
= case typeLevelNatE m of
Left TLNEOutOfBounds ->
unsupported "type-level natural outside of supported range [0, 37]" m
Right n ->
pure (TNat n)
convertType env t | Just (a,b) <- splitAppTy_maybe t
= TApp <$> convertType env a <*> convertType env b
convertType env x

View File

@ -160,33 +160,33 @@ convertPrim _ "BECoerceContractId" (TContractId a :-> TContractId b) =
-- in the type) but Decimal primitives are still used (from the
-- stdlib). Eventually the Decimal primitives will be phased out.
convertPrim _ "BEAddDecimal" (TNumeric10 :-> TNumeric10 :-> TNumeric10) =
ETyApp (EBuiltin BEAddNumeric) (TNat 10)
ETyApp (EBuiltin BEAddNumeric) TNat10
convertPrim _ "BESubDecimal" (TNumeric10 :-> TNumeric10 :-> TNumeric10) =
ETyApp (EBuiltin BESubNumeric) (TNat 10)
ETyApp (EBuiltin BESubNumeric) TNat10
convertPrim _ "BEMulDecimal" (TNumeric10 :-> TNumeric10 :-> TNumeric10) =
ETyApp (ETyApp (ETyApp (EBuiltin BEMulNumeric) (TNat 10)) (TNat 10)) (TNat 10)
EBuiltin BEMulNumeric `ETyApp` TNat10 `ETyApp` TNat10 `ETyApp` TNat10
convertPrim _ "BEDivDecimal" (TNumeric10 :-> TNumeric10 :-> TNumeric10) =
ETyApp (ETyApp (ETyApp (EBuiltin BEDivNumeric) (TNat 10)) (TNat 10)) (TNat 10)
EBuiltin BEDivNumeric `ETyApp` TNat10 `ETyApp` TNat10 `ETyApp` TNat10
convertPrim _ "BERoundDecimal" (TInt64 :-> TNumeric10 :-> TNumeric10) =
ETyApp (EBuiltin BERoundNumeric) (TNat 10)
ETyApp (EBuiltin BERoundNumeric) TNat10
convertPrim _ "BEEqual" (TNumeric10 :-> TNumeric10 :-> TBool) =
ETyApp (EBuiltin BEEqualNumeric) (TNat 10)
ETyApp (EBuiltin BEEqualNumeric) TNat10
convertPrim _ "BELess" (TNumeric10 :-> TNumeric10 :-> TBool) =
ETyApp (EBuiltin BELessNumeric) (TNat 10)
ETyApp (EBuiltin BELessNumeric) TNat10
convertPrim _ "BELessEq" (TNumeric10 :-> TNumeric10 :-> TBool) =
ETyApp (EBuiltin BELessEqNumeric) (TNat 10)
ETyApp (EBuiltin BELessEqNumeric) TNat10
convertPrim _ "BEGreaterEq" (TNumeric10 :-> TNumeric10 :-> TBool) =
ETyApp (EBuiltin BEGreaterEqNumeric) (TNat 10)
ETyApp (EBuiltin BEGreaterEqNumeric) TNat10
convertPrim _ "BEGreater" (TNumeric10 :-> TNumeric10 :-> TBool) =
ETyApp (EBuiltin BEGreaterNumeric) (TNat 10)
ETyApp (EBuiltin BEGreaterNumeric) TNat10
convertPrim _ "BEInt64ToDecimal" (TInt64 :-> TNumeric10) =
ETyApp (EBuiltin BEInt64ToNumeric) (TNat 10)
ETyApp (EBuiltin BEInt64ToNumeric) TNat10
convertPrim _ "BEDecimalToInt64" (TNumeric10 :-> TInt64) =
ETyApp (EBuiltin BENumericToInt64) (TNat 10)
ETyApp (EBuiltin BENumericToInt64) TNat10
convertPrim _ "BEToText" (TNumeric10 :-> TText) =
ETyApp (EBuiltin BEToTextNumeric) (TNat 10)
ETyApp (EBuiltin BEToTextNumeric) TNat10
convertPrim _ "BEDecimalFromText" (TText :-> TOptional TNumeric10) =
ETyApp (EBuiltin BENumericFromText) (TNat 10)
ETyApp (EBuiltin BENumericFromText) TNat10
-- Numeric primitives. These are polymorphic in the scale.
convertPrim _ "BEAddNumeric" (TNumeric n1 :-> TNumeric n2 :-> TNumeric n3) | n1 == n2, n1 == n3 =