1
1
mirror of https://github.com/anoma/juvix.git synced 2024-09-17 11:37:11 +03:00

JuvixTree validation (#2616)

* Validation (type checking) of JuvixTree. Similar to JuvixAsm
validation, will help with debugging.
* Depends on #2608
This commit is contained in:
Łukasz Czajka 2024-02-06 15:46:55 +01:00 committed by GitHub
parent 10e2a23239
commit 795212b092
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 412 additions and 112 deletions

View File

@ -15,10 +15,13 @@ runCommand opts = do
case Tree.runParser (toFilePath afile) s of
Left err -> exitJuvixError (JuvixError err)
Right tab -> do
tab' <- Tree.applyTransformations (project opts ^. treeReadTransformations) tab
unless (project opts ^. treeReadNoPrint) $
renderStdOut (Tree.ppOutDefault tab' tab')
doEval tab'
r <- runError @JuvixError (Tree.applyTransformations (project opts ^. treeReadTransformations) tab)
case r of
Left err -> exitJuvixError (JuvixError err)
Right tab' -> do
unless (project opts ^. treeReadNoPrint) $
renderStdOut (Tree.ppOutDefault tab' tab')
doEval tab'
where
file :: AppPath File
file = opts ^. treeReadInputFile

View File

@ -168,11 +168,11 @@ unifyMemory' loc tab mem1 mem2 = do
unless (length (mem1 ^. memoryValueStack) == length (mem2 ^. memoryValueStack)) $
throw $
AsmError loc "value stack height mismatch"
vs <- zipWithM (unifyTypes' loc tab) (toList (mem1 ^. memoryValueStack)) (toList (mem2 ^. memoryValueStack))
vs <- zipWithM (unifyTypes'' loc tab) (toList (mem1 ^. memoryValueStack)) (toList (mem2 ^. memoryValueStack))
unless (length (mem1 ^. memoryTempStack) == length (mem2 ^. memoryTempStack)) $
throw $
AsmError loc "temporary stack height mismatch"
ts <- zipWithM (unifyTypes' loc tab) (toList (mem1 ^. memoryTempStack)) (toList (mem2 ^. memoryTempStack))
ts <- zipWithM (unifyTypes'' loc tab) (toList (mem1 ^. memoryTempStack)) (toList (mem2 ^. memoryTempStack))
unless
( length (mem1 ^. memoryArgumentArea) == length (mem2 ^. memoryArgumentArea)
&& mem1 ^. memoryArgsNum == mem2 ^. memoryArgsNum
@ -183,7 +183,7 @@ unifyMemory' loc tab mem1 mem2 = do
args <-
mapM
( \off ->
unifyTypes'
unifyTypes''
loc
tab
(fromJust $ HashMap.lookup off (mem1 ^. memoryArgumentArea))

View File

@ -130,7 +130,7 @@ recurse' sig = go True
checkValueStack' loc (sig ^. recursorInfoTable) tyargs mem
tys <-
zipWithM
(\ty idx -> unifyTypes' loc (sig ^. recursorInfoTable) ty (topValueStack' idx mem))
(\ty idx -> unifyTypes'' loc (sig ^. recursorInfoTable) ty (topValueStack' idx mem))
tyargs
[0 ..]
return $
@ -226,7 +226,7 @@ recurse' sig = go True
checkValueStack' loc (sig ^. recursorInfoTable) (take argsNum (typeArgs ty)) mem'
let tyargs = topValuesFromValueStack' argsNum mem'
-- `typeArgs ty` may be shorter than `tyargs` only if `ty` is dynamic
zipWithM_ (unifyTypes' loc (sig ^. recursorInfoTable)) tyargs (typeArgs ty)
zipWithM_ (unifyTypes'' loc (sig ^. recursorInfoTable)) tyargs (typeArgs ty)
return $
pushValueStack (mkTypeFun (drop argsNum (typeArgs ty)) (typeTarget ty)) $
popValueStack argsNum mem'

View File

@ -4,84 +4,18 @@ module Juvix.Compiler.Asm.Extra.Type
)
where
import Data.List.NonEmpty qualified as NonEmpty
import Juvix.Compiler.Asm.Data.InfoTable
import Juvix.Compiler.Asm.Error
import Juvix.Compiler.Asm.Language
import Juvix.Compiler.Asm.Pretty
import Juvix.Compiler.Tree.Error
import Juvix.Compiler.Tree.Extra.Type
unifyTypes :: forall r. (Members '[Error AsmError, Reader (Maybe Location), Reader InfoTable] r) => Type -> Type -> Sem r Type
unifyTypes ty1 ty2 = case (ty1, ty2) of
(TyDynamic, x) -> return x
(x, TyDynamic) -> return x
(TyInductive TypeInductive {..}, TyConstr TypeConstr {..})
| _typeInductiveSymbol == _typeConstrInductive ->
return ty1
(TyConstr {}, TyInductive {}) -> unifyTypes ty2 ty1
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag -> do
flds <- zipWithM unifyTypes (c1 ^. typeConstrFields) (c2 ^. typeConstrFields)
return $ TyConstr (set typeConstrFields flds c1)
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive ->
return $ TyInductive (TypeInductive (c1 ^. typeConstrInductive))
(TyFun t1, TyFun t2)
| length (t1 ^. typeFunArgs) == length (t2 ^. typeFunArgs) -> do
let args1 = toList (t1 ^. typeFunArgs)
args2 = toList (t2 ^. typeFunArgs)
tgt1 = t1 ^. typeFunTarget
tgt2 = t2 ^. typeFunTarget
args <- zipWithM unifyTypes args1 args2
tgt <- unifyTypes tgt1 tgt2
return $ TyFun (TypeFun (NonEmpty.fromList args) tgt)
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
return $ TyInteger (TypeInteger (unifyBounds min l1 l2) (unifyBounds max u1 u2))
where
unifyBounds :: (Integer -> Integer -> Integer) -> Maybe Integer -> Maybe Integer -> Maybe Integer
unifyBounds _ Nothing _ = Nothing
unifyBounds _ _ Nothing = Nothing
unifyBounds f (Just x) (Just y) = Just (f x y)
(TyBool {}, TyBool {})
| ty1 == ty2 -> return ty1
(TyString, TyString) -> return TyString
(TyUnit, TyUnit) -> return TyUnit
(TyVoid, TyVoid) -> return TyVoid
(TyInductive {}, TyInductive {})
| ty1 == ty2 -> return ty1
(TyUnit, _) -> err
(_, TyUnit) -> err
(TyVoid, _) -> err
(_, TyVoid) -> err
(TyInteger {}, _) -> err
(_, TyInteger {}) -> err
(TyString, _) -> err
(_, TyString) -> err
(TyBool {}, _) -> err
(_, TyBool {}) -> err
(TyFun {}, _) -> err
(_, TyFun {}) -> err
(TyInductive {}, _) -> err
(_, TyConstr {}) -> err
unifyTypes'' :: forall t e r. (Member (Error AsmError) r) => Maybe Location -> InfoTable' t e -> Type -> Type -> Sem r Type
unifyTypes'' loc tab ty1 ty2 = mapError toAsmError $ unifyTypes' loc tab ty1 ty2
where
err :: Sem r a
err = do
loc <- ask
tab <- ask
throw $ AsmError loc ("not unifiable: " <> ppTrace tab ty1 <> ", " <> ppTrace tab ty2)
unifyTypes' :: (Member (Error AsmError) r) => Maybe Location -> InfoTable -> Type -> Type -> Sem r Type
unifyTypes' loc tab ty1 ty2 =
runReader loc $
runReader tab $
-- The `if` is to ensure correct behaviour with dynamic type targets. E.g.
-- `(A, B) -> *` should unify with `A -> B -> C -> D`.
if
| tgt1 == TyDynamic || tgt2 == TyDynamic ->
unifyTypes (curryType ty1) (curryType ty2)
| otherwise ->
unifyTypes ty1 ty2
where
tgt1 = typeTarget (uncurryType ty1)
tgt2 = typeTarget (uncurryType ty2)
toAsmError :: TreeError -> AsmError
toAsmError TreeError {..} =
AsmError
{ _asmErrorLoc = _treeErrorLoc,
_asmErrorMsg = _treeErrorMsg
}

View File

@ -176,7 +176,7 @@ coreToVampIR' = Core.toStored' >=> storedCoreToVampIR'
-- Other workflows
--------------------------------------------------------------------------------
treeToAsm :: Tree.InfoTable -> Sem r Asm.InfoTable
treeToAsm :: (Member (Error JuvixError) r) => Tree.InfoTable -> Sem r Asm.InfoTable
treeToAsm = Tree.toAsm >=> return . Asm.fromTree
treeToNockma :: (Members '[Error JuvixError, Reader EntryPoint] r) => Tree.InfoTable -> Sem r (Nockma.Cell Natural)

View File

@ -11,6 +11,7 @@ data TransformationId
| Apply
| TempHeight
| FilterUnreachable
| Validate
deriving stock (Data, Bounded, Enum, Show)
data PipelineId
@ -21,10 +22,10 @@ data PipelineId
type TransformationLikeId = TransformationLikeId' TransformationId PipelineId
toNockmaTransformations :: [TransformationId]
toNockmaTransformations = [Apply, FilterUnreachable, TempHeight]
toNockmaTransformations = [Validate, Apply, FilterUnreachable, TempHeight]
toAsmTransformations :: [TransformationId]
toAsmTransformations = []
toAsmTransformations = [Validate]
instance TransformationId' TransformationId where
transformationText :: TransformationId -> Text
@ -35,6 +36,7 @@ instance TransformationId' TransformationId where
Apply -> strApply
TempHeight -> strTempHeight
FilterUnreachable -> strFilterUnreachable
Validate -> strValidate
instance PipelineId' TransformationId PipelineId where
pipelineText :: PipelineId -> Text

View File

@ -25,3 +25,6 @@ strTempHeight = "temp-height"
strFilterUnreachable :: Text
strFilterUnreachable = "filter-unreachable"
strValidate :: Text
strValidate = "validate"

View File

@ -1,7 +1,15 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Avoid restricted extensions" #-}
{-# HLINT ignore "Avoid restricted flags" #-}
module Juvix.Compiler.Tree.Extra.Type where
import Juvix.Compiler.Tree.Data.InfoTable.Base
import Juvix.Compiler.Tree.Error
import Juvix.Compiler.Tree.Language.Base
import Juvix.Compiler.Tree.Language.Type
import Juvix.Compiler.Tree.Pretty
mkTypeInteger :: Type
mkTypeInteger = TyInteger (TypeInteger Nothing Nothing)
@ -98,3 +106,78 @@ isSubtype' ty1 ty2
tgt2 = typeTarget (uncurryType ty2)
isSubtype' ty1 ty2 =
isSubtype ty1 ty2
unifyTypes :: forall t e r. (Members '[Error TreeError, Reader (Maybe Location), Reader (InfoTable' t e)] r) => Type -> Type -> Sem r Type
unifyTypes ty1 ty2 = case (ty1, ty2) of
(TyDynamic, x) -> return x
(x, TyDynamic) -> return x
(TyInductive TypeInductive {..}, TyConstr TypeConstr {..})
| _typeInductiveSymbol == _typeConstrInductive ->
return ty1
(TyConstr {}, TyInductive {}) -> unifyTypes @t @e ty2 ty1
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag -> do
flds <- zipWithM (unifyTypes @t @e) (c1 ^. typeConstrFields) (c2 ^. typeConstrFields)
return $ TyConstr (set typeConstrFields flds c1)
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive ->
return $ TyInductive (TypeInductive (c1 ^. typeConstrInductive))
(TyFun t1, TyFun t2)
| length (t1 ^. typeFunArgs) == length (t2 ^. typeFunArgs) -> do
let args1 = toList (t1 ^. typeFunArgs)
args2 = toList (t2 ^. typeFunArgs)
tgt1 = t1 ^. typeFunTarget
tgt2 = t2 ^. typeFunTarget
args <- zipWithM (unifyTypes @t @e) args1 args2
tgt <- unifyTypes @t @e tgt1 tgt2
return $ TyFun (TypeFun (nonEmpty' args) tgt)
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
return $ TyInteger (TypeInteger (unifyBounds min l1 l2) (unifyBounds max u1 u2))
where
unifyBounds :: (Integer -> Integer -> Integer) -> Maybe Integer -> Maybe Integer -> Maybe Integer
unifyBounds _ Nothing _ = Nothing
unifyBounds _ _ Nothing = Nothing
unifyBounds f (Just x) (Just y) = Just (f x y)
(TyBool {}, TyBool {})
| ty1 == ty2 -> return ty1
(TyString, TyString) -> return TyString
(TyUnit, TyUnit) -> return TyUnit
(TyVoid, TyVoid) -> return TyVoid
(TyInductive {}, TyInductive {})
| ty1 == ty2 -> return ty1
(TyUnit, _) -> err
(_, TyUnit) -> err
(TyVoid, _) -> err
(_, TyVoid) -> err
(TyInteger {}, _) -> err
(_, TyInteger {}) -> err
(TyString, _) -> err
(_, TyString) -> err
(TyBool {}, _) -> err
(_, TyBool {}) -> err
(TyFun {}, _) -> err
(_, TyFun {}) -> err
(TyInductive {}, _) -> err
(_, TyConstr {}) -> err
where
err :: Sem r a
err = do
loc <- ask
tab <- ask @(InfoTable' t e)
throw $ TreeError loc ("not unifiable: " <> ppTrace' (defaultOptions tab) ty1 <> ", " <> ppTrace' (defaultOptions tab) ty2)
unifyTypes' :: forall t e r. (Member (Error TreeError) r) => Maybe Location -> InfoTable' t e -> Type -> Type -> Sem r Type
unifyTypes' loc tab ty1 ty2 =
runReader loc $
runReader tab $
-- The `if` is to ensure correct behaviour with dynamic type targets. E.g.
-- `(A, B) -> *` should unify with `A -> B -> C -> D`.
if
| tgt1 == TyDynamic || tgt2 == TyDynamic ->
unifyTypes @t @e (curryType ty1) (curryType ty2)
| otherwise ->
unifyTypes @t @e ty1 ty2
where
tgt1 = typeTarget (uncurryType ty1)
tgt2 = typeTarget (uncurryType ty2)

View File

@ -7,8 +7,8 @@ where
import Juvix.Compiler.Tree.Data.InfoTable
import Juvix.Compiler.Tree.Transformation
toNockma :: InfoTable -> Sem r InfoTable
toNockma :: (Member (Error JuvixError) r) => InfoTable -> Sem r InfoTable
toNockma = applyTransformations toNockmaTransformations
toAsm :: InfoTable -> Sem r InfoTable
toAsm :: (Member (Error JuvixError) r) => InfoTable -> Sem r InfoTable
toAsm = applyTransformations toAsmTransformations

View File

@ -6,13 +6,15 @@ module Juvix.Compiler.Tree.Transformation
where
import Juvix.Compiler.Tree.Data.TransformationId
import Juvix.Compiler.Tree.Error
import Juvix.Compiler.Tree.Transformation.Apply
import Juvix.Compiler.Tree.Transformation.Base
import Juvix.Compiler.Tree.Transformation.FilterUnreachable
import Juvix.Compiler.Tree.Transformation.Identity
import Juvix.Compiler.Tree.Transformation.TempHeight
import Juvix.Compiler.Tree.Transformation.Validate
applyTransformations :: forall r. [TransformationId] -> InfoTable -> Sem r InfoTable
applyTransformations :: forall r. (Member (Error JuvixError) r) => [TransformationId] -> InfoTable -> Sem r InfoTable
applyTransformations ts tbl = foldM (flip appTrans) tbl ts
where
appTrans :: TransformationId -> InfoTable -> Sem r InfoTable
@ -23,3 +25,4 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
Apply -> return . computeApply
TempHeight -> return . computeTempHeight
FilterUnreachable -> return . filterUnreachable
Validate -> mapError (JuvixError @TreeError) . validate

View File

@ -0,0 +1,266 @@
module Juvix.Compiler.Tree.Transformation.Validate where
import Juvix.Compiler.Core.Data.BinderList qualified as BL
import Juvix.Compiler.Tree.Error
import Juvix.Compiler.Tree.Extra.Base (getNodeLocation)
import Juvix.Compiler.Tree.Extra.Recursors
import Juvix.Compiler.Tree.Extra.Type
import Juvix.Compiler.Tree.Transformation.Base
inferType :: forall r. (Member (Error TreeError) r) => InfoTable -> FunctionInfo -> Node -> Sem r Type
inferType tab funInfo = goInfer mempty
where
goInfer :: BinderList Type -> Node -> Sem r Type
goInfer bl = \case
Binop x -> goBinop bl x
Unop x -> goUnop bl x
Const x -> goConst bl x
MemRef x -> goMemRef bl x
AllocConstr x -> goAllocConstr bl x
AllocClosure x -> goAllocClosure bl x
ExtendClosure x -> goExtendClosure bl x
Call x -> goCall bl x
CallClosures x -> goCallClosures bl x
Branch x -> goBranch bl x
Case x -> goCase bl x
Save x -> goSave bl x
goBinop :: BinderList Type -> NodeBinop -> Sem r Type
goBinop bl NodeBinop {..} = case _nodeBinopOpcode of
IntAdd -> checkBinop mkTypeInteger mkTypeInteger mkTypeInteger
IntSub -> checkBinop mkTypeInteger mkTypeInteger mkTypeInteger
IntMul -> checkBinop mkTypeInteger mkTypeInteger mkTypeInteger
IntDiv -> checkBinop mkTypeInteger mkTypeInteger mkTypeInteger
IntMod -> checkBinop mkTypeInteger mkTypeInteger mkTypeInteger
IntLt -> checkBinop mkTypeInteger mkTypeInteger mkTypeBool
IntLe -> checkBinop mkTypeInteger mkTypeInteger mkTypeBool
ValEq -> checkBinop TyDynamic TyDynamic mkTypeBool
StrConcat -> checkBinop TyString TyString TyString
OpSeq -> do
checkType bl _nodeBinopArg1 TyDynamic
goInfer bl _nodeBinopArg2
where
loc = _nodeBinopInfo ^. nodeInfoLocation
checkBinop :: Type -> Type -> Type -> Sem r Type
checkBinop ty1' ty2' rty = do
ty1 <- goInfer bl _nodeBinopArg1
ty2 <- goInfer bl _nodeBinopArg2
void $ unifyTypes' loc tab ty1 ty1'
void $ unifyTypes' loc tab ty2 ty2'
return rty
goUnop :: BinderList Type -> NodeUnop -> Sem r Type
goUnop bl NodeUnop {..} = case _nodeUnopOpcode of
OpShow -> checkUnop TyDynamic TyString
OpStrToInt -> checkUnop TyString mkTypeInteger
OpTrace -> goInfer bl _nodeUnopArg
OpFail -> checkUnop TyDynamic TyDynamic
OpArgsNum -> checkUnop TyDynamic mkTypeInteger
where
loc = _nodeUnopInfo ^. nodeInfoLocation
checkUnop :: Type -> Type -> Sem r Type
checkUnop ty rty = do
ty' <- goInfer bl _nodeUnopArg
void $ unifyTypes' loc tab ty ty'
return rty
goConst :: BinderList Type -> NodeConstant -> Sem r Type
goConst _ NodeConstant {..} = case _nodeConstant of
ConstInt {} -> return mkTypeInteger
ConstBool {} -> return mkTypeBool
ConstString {} -> return TyString
ConstUnit {} -> return TyUnit
ConstVoid {} -> return TyVoid
goMemRef :: BinderList Type -> NodeMemRef -> Sem r Type
goMemRef bl NodeMemRef {..} = case _nodeMemRef of
DRef d -> goDirectRef (_nodeMemRefInfo ^. nodeInfoLocation) bl d
ConstrRef x -> goField bl x
goDirectRef :: Maybe Location -> BinderList Type -> DirectRef -> Sem r Type
goDirectRef loc bl = \case
ArgRef x -> goArgRef loc bl x
TempRef RefTemp {..} -> goTempRef bl _refTempOffsetRef
goArgRef :: Maybe Location -> BinderList Type -> OffsetRef -> Sem r Type
goArgRef loc _ OffsetRef {..}
| _offsetRefOffset < length tys = return $ tys !! _offsetRefOffset
| typeTarget (funInfo ^. functionType) == TyDynamic = return TyDynamic
| otherwise =
throw $
TreeError
{ _treeErrorLoc = loc,
_treeErrorMsg = "Wrong target type"
}
where
tys = typeArgs (funInfo ^. functionType)
goTempRef :: BinderList Type -> OffsetRef -> Sem r Type
goTempRef bl OffsetRef {..} = return $ BL.lookupLevel _offsetRefOffset bl
goField :: BinderList Type -> Field -> Sem r Type
goField _ Field {..}
| _fieldOffset < length tys = return $ tys !! _fieldOffset
| otherwise = return TyDynamic
where
ci = lookupConstrInfo tab _fieldTag
tys = typeArgs (ci ^. constructorType)
goAllocConstr :: BinderList Type -> NodeAllocConstr -> Sem r Type
goAllocConstr bl NodeAllocConstr {..}
| length _nodeAllocConstrArgs == length tys = do
forM_ (zipExact _nodeAllocConstrArgs tys) (uncurry (checkType bl))
return $ typeTarget (ci ^. constructorType)
| otherwise =
throw $
TreeError
{ _treeErrorLoc = _nodeAllocConstrInfo ^. nodeInfoLocation,
_treeErrorMsg = ""
}
where
ci = lookupConstrInfo tab _nodeAllocConstrTag
tys = typeArgs (ci ^. constructorType)
goAllocClosure :: BinderList Type -> NodeAllocClosure -> Sem r Type
goAllocClosure bl NodeAllocClosure {..}
| n <= fi ^. functionArgsNum = do
forM_ (zipExact _nodeAllocClosureArgs (take n tys)) (uncurry (checkType bl))
return $ mkTypeFun (drop n tys) (typeTarget (fi ^. functionType))
| otherwise =
throw $
TreeError
{ _treeErrorLoc = _nodeAllocClosureInfo ^. nodeInfoLocation,
_treeErrorMsg = "Wrong number of arguments"
}
where
n = length _nodeAllocClosureArgs
fi = lookupFunInfo tab _nodeAllocClosureFunSymbol
tys = typeArgs (fi ^. functionType)
goExtendClosure :: BinderList Type -> NodeExtendClosure -> Sem r Type
goExtendClosure bl NodeExtendClosure {..} = do
ty <- goInfer bl _nodeExtendClosureFun
let tys = typeArgs ty
m = length tys
n = length _nodeExtendClosureArgs
if
| n < m -> do
forM_ (zipExact (toList _nodeExtendClosureArgs) (take n tys)) (uncurry (checkType bl))
return $ mkTypeFun (drop n tys) (typeTarget ty)
| typeTarget ty == TyDynamic -> do
let tys' = tys ++ replicate (n - m) TyDynamic
forM_ (zipExact (toList _nodeExtendClosureArgs) tys') (uncurry (checkType bl))
return $ typeTarget ty
| otherwise ->
throw $
TreeError
{ _treeErrorLoc = _nodeExtendClosureInfo ^. nodeInfoLocation,
_treeErrorMsg = "Too many arguments"
}
goCall :: BinderList Type -> NodeCall -> Sem r Type
goCall bl NodeCall {..} = case _nodeCallType of
CallFun sym
| n == fi ^. functionArgsNum -> do
unless (n == 0) $
forM_ (zipExact _nodeCallArgs tys) (uncurry (checkType bl))
return $ mkTypeFun (drop n tys) (typeTarget (fi ^. functionType))
| otherwise ->
throw $
TreeError
{ _treeErrorLoc = _nodeCallInfo ^. nodeInfoLocation,
_treeErrorMsg = "Wrong number of arguments"
}
where
n = length _nodeCallArgs
fi = lookupFunInfo tab sym
tys = typeArgs (fi ^. functionType)
CallClosure cl -> do
ty <- goInfer bl cl
let tys = typeArgs ty
n = length _nodeCallArgs
when (length tys > n) $
throw $
TreeError
{ _treeErrorLoc = _nodeCallInfo ^. nodeInfoLocation,
_treeErrorMsg = "Too few arguments"
}
when (length tys < n && typeTarget ty /= TyDynamic) $
throw $
TreeError
{ _treeErrorLoc = _nodeCallInfo ^. nodeInfoLocation,
_treeErrorMsg = "Too many arguments"
}
let tys' = tys ++ replicate (n - length tys) TyDynamic
forM_ (zipExact _nodeCallArgs tys') (uncurry (checkType bl))
return $ typeTarget ty
goCallClosures :: BinderList Type -> NodeCallClosures -> Sem r Type
goCallClosures bl NodeCallClosures {..} = do
ty <- goInfer bl _nodeCallClosuresFun
go ty (toList _nodeCallClosuresArgs)
where
go :: Type -> [Node] -> Sem r Type
go ty args
| m == 0 =
return ty
| m <= n = do
forM_ (zipExact (take m args) tys) (uncurry (checkType bl))
go (typeTarget ty) (drop m args)
| otherwise = do
forM_ (zipExact args (take n tys)) (uncurry (checkType bl))
return $ mkTypeFun (drop n tys) (typeTarget ty)
where
tys = typeArgs ty
m = length tys
n = length args
goBranch :: BinderList Type -> NodeBranch -> Sem r Type
goBranch bl NodeBranch {..} = do
checkType bl _nodeBranchArg mkTypeBool
ty1 <- goInfer bl _nodeBranchTrue
ty2 <- goInfer bl _nodeBranchFalse
unifyTypes' (_nodeBranchInfo ^. nodeInfoLocation) tab ty1 ty2
goCase :: BinderList Type -> NodeCase -> Sem r Type
goCase bl NodeCase {..} = do
ity <- goInfer bl _nodeCaseArg
unless (ity == mkTypeInductive _nodeCaseInductive || ity == TyDynamic) $
throw $
TreeError
{ _treeErrorLoc = _nodeCaseInfo ^. nodeInfoLocation,
_treeErrorMsg = "Inductive type mismatch"
}
ty <- maybe (return TyDynamic) (goInfer bl) _nodeCaseDefault
go ity ty _nodeCaseBranches
where
go :: Type -> Type -> [CaseBranch] -> Sem r Type
go ity ty = \case
[] -> return ty
CaseBranch {..} : brs -> do
let bl' = if _caseBranchSave then BL.cons ity bl else bl
ty' <- goInfer bl' _caseBranchBody
ty'' <- unifyTypes' (_nodeCaseInfo ^. nodeInfoLocation) tab ty ty'
go ity ty'' brs
goSave :: BinderList Type -> NodeSave -> Sem r Type
goSave bl NodeSave {..} = do
ty <- goInfer bl _nodeSaveArg
goInfer (BL.cons ty bl) _nodeSaveBody
checkType :: BinderList Type -> Node -> Type -> Sem r ()
checkType bl node ty = do
ty' <- goInfer bl node
void $ unifyTypes' (getNodeLocation node) tab ty ty'
validateFunction :: (Member (Error TreeError) r) => InfoTable -> FunctionInfo -> Sem r FunctionInfo
validateFunction tab funInfo = do
ty <- inferType tab funInfo (funInfo ^. functionCode)
let ty' = if funInfo ^. functionArgsNum == 0 then funInfo ^. functionType else typeTarget (funInfo ^. functionType)
void $ unifyTypes' (funInfo ^. functionLocation) tab ty ty'
return funInfo
validate :: (Member (Error TreeError) r) => InfoTable -> Sem r InfoTable
validate tab = mapFunctionsM (validateFunction tab) tab

View File

@ -35,25 +35,31 @@ treeEvalAssertionParam evalParam mainFile expectedFile trans testTrans step = do
case runParser (toFilePath mainFile) s of
Left err -> assertFailure (show (pretty err))
Right tab0 -> do
unless (null trans) $
step "Transform"
let tab = run $ applyTransformations trans tab0
testTrans tab
case tab ^. infoMainFunction of
Just sym -> do
withTempDir'
( \dirPath -> do
let outputFile = dirPath <//> $(mkRelFile "out.out")
hout <- openFile (toFilePath outputFile) WriteMode
step "Evaluate"
evalParam hout sym tab
hClose hout
actualOutput <- readFile (toFilePath outputFile)
step "Compare expected and actual program output"
expected <- readFile (toFilePath expectedFile)
assertEqDiffText ("Check: RUN output = " <> toFilePath expectedFile) actualOutput expected
)
Nothing -> assertFailure "no 'main' function"
step "Validate"
case run $ runError @JuvixError $ applyTransformations [Validate] tab0 of
Left err -> assertFailure (show (pretty (fromJuvixError @GenericError err)))
Right tab1 -> do
unless (null trans) $
step "Transform"
case run $ runError @JuvixError $ applyTransformations trans tab1 of
Left err -> assertFailure (show (pretty (fromJuvixError @GenericError err)))
Right tab -> do
testTrans tab
case tab ^. infoMainFunction of
Just sym -> do
withTempDir'
( \dirPath -> do
let outputFile = dirPath <//> $(mkRelFile "out.out")
hout <- openFile (toFilePath outputFile) WriteMode
step "Evaluate"
evalParam hout sym tab
hClose hout
actualOutput <- readFile (toFilePath outputFile)
step "Compare expected and actual program output"
expected <- readFile (toFilePath expectedFile)
assertEqDiffText ("Check: RUN output = " <> toFilePath expectedFile) actualOutput expected
)
Nothing -> assertFailure "no 'main' function"
evalAssertion :: Handle -> Symbol -> InfoTable -> IO ()
evalAssertion hout sym tab = do

View File

@ -99,7 +99,7 @@ function uncurry(*, * -> *, *) {
tccall 2;
}
function pred_step(Pair) : (* -> *, *) -> * {
function pred_step(Pair) : Pair {
push arg[0].pair[1];
call isZero;
br {

View File

@ -101,7 +101,7 @@ function uncurry(*, * → *, *) : * {
ccall(arg[0], arg[1], arg[2])
}
function pred_step(Pair) : (* → *, *) → * {
function pred_step(Pair) : Pair {
br(call[isZero](arg[0].pair[1])) {
true: alloc[pair](arg[0].pair[0], calloc[uncurry](calloc[succ](arg[0].pair[1])))
false: alloc[pair](calloc[uncurry](calloc[succ](arg[0].pair[0])), calloc[uncurry](calloc[succ](arg[0].pair[1])))