mirror of
https://github.com/anoma/juvix.git
synced 2025-01-05 22:46:08 +03:00
Fix JuvixTree type unification (#2972)
* Closes #2954 * The problem was that the type validation algorithm was too strict for higher-order functions with a dynamic (unknown) target.
This commit is contained in:
parent
9c980d152a
commit
eb5b2e4595
@ -150,7 +150,7 @@ checkValueStack' loc tab tys mem = do
|
||||
mapM_
|
||||
( \(ty, idx) -> do
|
||||
let ty' = fromJust $ topValueStack idx mem
|
||||
unless (isSubtype' ty' ty) $
|
||||
unless (isSubtype ty' ty) $
|
||||
throw $
|
||||
AsmError loc $
|
||||
"type mismatch on value stack cell "
|
||||
|
@ -39,127 +39,127 @@ curryType ty = case typeArgs ty of
|
||||
in foldr (\tyarg ty'' -> mkTypeFun [tyarg] ty'') (typeTarget ty') tyargs
|
||||
|
||||
isSubtype :: Type -> Type -> Bool
|
||||
isSubtype ty1 ty2 = case (ty1, ty2) of
|
||||
(TyDynamic, _) -> True
|
||||
(_, TyDynamic) -> True
|
||||
(TyConstr TypeConstr {..}, TyInductive TypeInductive {..}) ->
|
||||
_typeConstrInductive == _typeInductiveSymbol
|
||||
(TyConstr c1, TyConstr c2) ->
|
||||
c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
|
||||
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag
|
||||
&& all (uncurry isSubtype) (zip (c1 ^. typeConstrFields) (c2 ^. typeConstrFields))
|
||||
(TyFun t1, TyFun t2) ->
|
||||
let l1 = toList (t1 ^. typeFunArgs)
|
||||
l2 = toList (t2 ^. typeFunArgs)
|
||||
r1 = t1 ^. typeFunTarget
|
||||
r2 = t2 ^. typeFunTarget
|
||||
in length l1 == length l2 && all (uncurry isSubtype) (zip l2 l1) && isSubtype r1 r2
|
||||
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
|
||||
checkBounds (>=) l1 l2 && checkBounds (<=) u1 u2
|
||||
where
|
||||
checkBounds :: (Integer -> Integer -> Bool) -> Maybe Integer -> Maybe Integer -> Bool
|
||||
checkBounds _ Nothing Nothing = True
|
||||
checkBounds _ Nothing (Just _) = False
|
||||
checkBounds _ (Just _) Nothing = True
|
||||
checkBounds cmp (Just x) (Just y) = cmp x y
|
||||
(TyBool {}, TyBool {}) -> True
|
||||
(TyString, TyString) -> True
|
||||
(TyField, TyField) -> True
|
||||
(TyByteArray, TyByteArray) -> True
|
||||
(TyUnit, TyUnit) -> True
|
||||
(TyVoid, TyVoid) -> True
|
||||
(TyInductive {}, TyInductive {}) -> ty1 == ty2
|
||||
(TyUnit, _) -> False
|
||||
(_, TyUnit) -> False
|
||||
(TyVoid, _) -> False
|
||||
(_, TyVoid) -> False
|
||||
(TyInteger {}, _) -> False
|
||||
(_, TyInteger {}) -> False
|
||||
(TyString, _) -> False
|
||||
(_, TyString) -> False
|
||||
(TyField, _) -> False
|
||||
(_, TyField) -> False
|
||||
(TyByteArray, _) -> False
|
||||
(_, TyByteArray) -> False
|
||||
(TyBool {}, _) -> False
|
||||
(_, TyBool {}) -> False
|
||||
(TyFun {}, _) -> False
|
||||
(_, TyFun {}) -> False
|
||||
(_, TyConstr {}) -> False
|
||||
|
||||
isSubtype' :: Type -> Type -> Bool
|
||||
isSubtype' ty1 ty2
|
||||
-- The guard is to ensure correct behaviour with dynamic type targets. E.g.
|
||||
-- `A -> B -> C -> D` should be a subtype of `(A, B) -> *`.
|
||||
| tgt1 == TyDynamic || tgt2 == TyDynamic =
|
||||
isSubtype
|
||||
(curryType ty1)
|
||||
(curryType ty2)
|
||||
where
|
||||
tgt1 = typeTarget (uncurryType ty1)
|
||||
tgt2 = typeTarget (uncurryType ty2)
|
||||
isSubtype' ty1 ty2 =
|
||||
isSubtype ty1 ty2
|
||||
isSubtype ty1 ty2 =
|
||||
let (ty1', ty2') =
|
||||
if
|
||||
| typeTarget (uncurryType ty1) == TyDynamic || typeTarget (uncurryType ty2) == TyDynamic ->
|
||||
(curryType ty1, curryType ty2)
|
||||
| otherwise ->
|
||||
(ty1, ty2)
|
||||
in case (ty1', ty2') of
|
||||
(TyDynamic, _) -> True
|
||||
(_, TyDynamic) -> True
|
||||
(TyConstr TypeConstr {..}, TyInductive TypeInductive {..}) ->
|
||||
_typeConstrInductive == _typeInductiveSymbol
|
||||
(TyConstr c1, TyConstr c2) ->
|
||||
c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
|
||||
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag
|
||||
&& all (uncurry isSubtype) (zip (c1 ^. typeConstrFields) (c2 ^. typeConstrFields))
|
||||
(TyFun t1, TyFun t2) ->
|
||||
let l1 = toList (t1 ^. typeFunArgs)
|
||||
l2 = toList (t2 ^. typeFunArgs)
|
||||
r1 = t1 ^. typeFunTarget
|
||||
r2 = t2 ^. typeFunTarget
|
||||
in length l1 == length l2 && all (uncurry isSubtype) (zip l2 l1) && isSubtype r1 r2
|
||||
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
|
||||
checkBounds (>=) l1 l2 && checkBounds (<=) u1 u2
|
||||
where
|
||||
checkBounds :: (Integer -> Integer -> Bool) -> Maybe Integer -> Maybe Integer -> Bool
|
||||
checkBounds _ Nothing Nothing = True
|
||||
checkBounds _ Nothing (Just _) = False
|
||||
checkBounds _ (Just _) Nothing = True
|
||||
checkBounds cmp (Just x) (Just y) = cmp x y
|
||||
(TyBool {}, TyBool {}) -> True
|
||||
(TyString, TyString) -> True
|
||||
(TyField, TyField) -> True
|
||||
(TyByteArray, TyByteArray) -> True
|
||||
(TyUnit, TyUnit) -> True
|
||||
(TyVoid, TyVoid) -> True
|
||||
(TyInductive {}, TyInductive {}) -> ty1 == ty2
|
||||
(TyUnit, _) -> False
|
||||
(_, TyUnit) -> False
|
||||
(TyVoid, _) -> False
|
||||
(_, TyVoid) -> False
|
||||
(TyInteger {}, _) -> False
|
||||
(_, TyInteger {}) -> False
|
||||
(TyString, _) -> False
|
||||
(_, TyString) -> False
|
||||
(TyField, _) -> False
|
||||
(_, TyField) -> False
|
||||
(TyByteArray, _) -> False
|
||||
(_, TyByteArray) -> False
|
||||
(TyBool {}, _) -> False
|
||||
(_, TyBool {}) -> False
|
||||
(TyFun {}, _) -> False
|
||||
(_, TyFun {}) -> False
|
||||
(_, TyConstr {}) -> False
|
||||
|
||||
unifyTypes :: forall t e r. (Members '[Error TreeError, Reader (Maybe Location), Reader (InfoTable' t e)] r) => Type -> Type -> Sem r Type
|
||||
unifyTypes ty1 ty2 = case (ty1, ty2) of
|
||||
(TyDynamic, x) -> return x
|
||||
(x, TyDynamic) -> return x
|
||||
(TyInductive TypeInductive {..}, TyConstr TypeConstr {..})
|
||||
| _typeInductiveSymbol == _typeConstrInductive ->
|
||||
return ty1
|
||||
(TyConstr {}, TyInductive {}) -> unifyTypes @t @e ty2 ty1
|
||||
(TyConstr c1, TyConstr c2)
|
||||
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
|
||||
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag -> do
|
||||
flds <- zipWithM (unifyTypes @t @e) (c1 ^. typeConstrFields) (c2 ^. typeConstrFields)
|
||||
return $ TyConstr (set typeConstrFields flds c1)
|
||||
(TyConstr c1, TyConstr c2)
|
||||
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive ->
|
||||
return $ TyInductive (TypeInductive (c1 ^. typeConstrInductive))
|
||||
(TyFun t1, TyFun t2)
|
||||
| length (t1 ^. typeFunArgs) == length (t2 ^. typeFunArgs) -> do
|
||||
let args1 = toList (t1 ^. typeFunArgs)
|
||||
args2 = toList (t2 ^. typeFunArgs)
|
||||
tgt1 = t1 ^. typeFunTarget
|
||||
tgt2 = t2 ^. typeFunTarget
|
||||
args <- zipWithM (unifyTypes @t @e) args1 args2
|
||||
tgt <- unifyTypes @t @e tgt1 tgt2
|
||||
return $ TyFun (TypeFun (nonEmpty' args) tgt)
|
||||
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
|
||||
return $ TyInteger (TypeInteger (unifyBounds min l1 l2) (unifyBounds max u1 u2))
|
||||
where
|
||||
unifyBounds :: (Integer -> Integer -> Integer) -> Maybe Integer -> Maybe Integer -> Maybe Integer
|
||||
unifyBounds _ Nothing _ = Nothing
|
||||
unifyBounds _ _ Nothing = Nothing
|
||||
unifyBounds f (Just x) (Just y) = Just (f x y)
|
||||
(TyBool {}, TyBool {})
|
||||
| ty1 == ty2 -> return ty1
|
||||
(TyString, TyString) -> return TyString
|
||||
(TyField, TyField) -> return TyField
|
||||
(TyByteArray, TyByteArray) -> return TyByteArray
|
||||
(TyUnit, TyUnit) -> return TyUnit
|
||||
(TyVoid, TyVoid) -> return TyVoid
|
||||
(TyInductive {}, TyInductive {})
|
||||
| ty1 == ty2 -> return ty1
|
||||
(TyUnit, _) -> err
|
||||
(_, TyUnit) -> err
|
||||
(TyVoid, _) -> err
|
||||
(_, TyVoid) -> err
|
||||
(TyInteger {}, _) -> err
|
||||
(_, TyInteger {}) -> err
|
||||
(TyString, _) -> err
|
||||
(_, TyString) -> err
|
||||
(TyField, _) -> err
|
||||
(_, TyField) -> err
|
||||
(TyByteArray, _) -> err
|
||||
(_, TyByteArray) -> err
|
||||
(TyBool {}, _) -> err
|
||||
(_, TyBool {}) -> err
|
||||
(TyFun {}, _) -> err
|
||||
(_, TyFun {}) -> err
|
||||
(TyInductive {}, _) -> err
|
||||
(_, TyConstr {}) -> err
|
||||
unifyTypes ty1 ty2 =
|
||||
let (ty1', ty2') =
|
||||
if
|
||||
| typeTarget (uncurryType ty1) == TyDynamic || typeTarget (uncurryType ty2) == TyDynamic ->
|
||||
(curryType ty1, curryType ty2)
|
||||
| otherwise ->
|
||||
(ty1, ty2)
|
||||
in case (ty1', ty2') of
|
||||
(TyDynamic, x) -> return x
|
||||
(x, TyDynamic) -> return x
|
||||
(TyInductive TypeInductive {..}, TyConstr TypeConstr {..})
|
||||
| _typeInductiveSymbol == _typeConstrInductive ->
|
||||
return ty1
|
||||
(TyConstr {}, TyInductive {}) -> unifyTypes @t @e ty2 ty1
|
||||
(TyConstr c1, TyConstr c2)
|
||||
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
|
||||
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag -> do
|
||||
flds <- zipWithM (unifyTypes @t @e) (c1 ^. typeConstrFields) (c2 ^. typeConstrFields)
|
||||
return $ TyConstr (set typeConstrFields flds c1)
|
||||
(TyConstr c1, TyConstr c2)
|
||||
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive ->
|
||||
return $ TyInductive (TypeInductive (c1 ^. typeConstrInductive))
|
||||
(TyFun t1, TyFun t2)
|
||||
| length (t1 ^. typeFunArgs) == length (t2 ^. typeFunArgs) -> do
|
||||
let args1 = toList (t1 ^. typeFunArgs)
|
||||
args2 = toList (t2 ^. typeFunArgs)
|
||||
tgt1 = t1 ^. typeFunTarget
|
||||
tgt2 = t2 ^. typeFunTarget
|
||||
args <- zipWithM (unifyTypes @t @e) args1 args2
|
||||
tgt <- unifyTypes @t @e tgt1 tgt2
|
||||
return $ TyFun (TypeFun (nonEmpty' args) tgt)
|
||||
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
|
||||
return $ TyInteger (TypeInteger (unifyBounds min l1 l2) (unifyBounds max u1 u2))
|
||||
where
|
||||
unifyBounds :: (Integer -> Integer -> Integer) -> Maybe Integer -> Maybe Integer -> Maybe Integer
|
||||
unifyBounds _ Nothing _ = Nothing
|
||||
unifyBounds _ _ Nothing = Nothing
|
||||
unifyBounds f (Just x) (Just y) = Just (f x y)
|
||||
(TyBool {}, TyBool {})
|
||||
| ty1 == ty2 -> return ty1
|
||||
(TyString, TyString) -> return TyString
|
||||
(TyField, TyField) -> return TyField
|
||||
(TyByteArray, TyByteArray) -> return TyByteArray
|
||||
(TyUnit, TyUnit) -> return TyUnit
|
||||
(TyVoid, TyVoid) -> return TyVoid
|
||||
(TyInductive {}, TyInductive {})
|
||||
| ty1 == ty2 -> return ty1
|
||||
(TyUnit, _) -> err
|
||||
(_, TyUnit) -> err
|
||||
(TyVoid, _) -> err
|
||||
(_, TyVoid) -> err
|
||||
(TyInteger {}, _) -> err
|
||||
(_, TyInteger {}) -> err
|
||||
(TyString, _) -> err
|
||||
(_, TyString) -> err
|
||||
(TyField, _) -> err
|
||||
(_, TyField) -> err
|
||||
(TyByteArray, _) -> err
|
||||
(_, TyByteArray) -> err
|
||||
(TyBool {}, _) -> err
|
||||
(_, TyBool {}) -> err
|
||||
(TyFun {}, _) -> err
|
||||
(_, TyFun {}) -> err
|
||||
(TyInductive {}, _) -> err
|
||||
(_, TyConstr {}) -> err
|
||||
where
|
||||
err :: Sem r a
|
||||
err = do
|
||||
|
@ -3,6 +3,7 @@ module Tree.Asm.Base where
|
||||
import Asm.Run.Base qualified as Asm
|
||||
import Base
|
||||
import Juvix.Compiler.Asm.Translation.FromTree qualified as Asm
|
||||
import Juvix.Compiler.Tree.Pipeline qualified as Tree
|
||||
import Juvix.Compiler.Tree.Translation.FromSource
|
||||
import Juvix.Data.PPOutput
|
||||
|
||||
@ -18,5 +19,8 @@ treeAsmAssertion mainFile expectedFile step = do
|
||||
Left err -> assertFailure (prettyString err)
|
||||
Right tabIni -> do
|
||||
step "Translate"
|
||||
let tab = Asm.fromTree tabIni
|
||||
Asm.asmRunAssertion' tab expectedFile step
|
||||
case run $ runError @JuvixError $ Tree.toAsm tabIni of
|
||||
Left err -> assertFailure (prettyString (fromJuvixError @GenericError err))
|
||||
Right tab -> do
|
||||
let tab' = Asm.fromTree tab
|
||||
Asm.asmRunAssertion' tab' expectedFile step
|
||||
|
@ -239,5 +239,10 @@ tests =
|
||||
"Test040: ByteArray"
|
||||
$(mkRelDir ".")
|
||||
$(mkRelFile "test040.jvt")
|
||||
$(mkRelFile "out/test040.out")
|
||||
$(mkRelFile "out/test040.out"),
|
||||
PosTest
|
||||
"Test041: Type unification"
|
||||
$(mkRelDir ".")
|
||||
$(mkRelFile "test041.jvt")
|
||||
$(mkRelFile "out/test041.out")
|
||||
]
|
||||
|
1
tests/Tree/positive/out/test041.out
Normal file
1
tests/Tree/positive/out/test041.out
Normal file
@ -0,0 +1 @@
|
||||
0
|
41
tests/Tree/positive/test041.jvt
Normal file
41
tests/Tree/positive/test041.jvt
Normal file
@ -0,0 +1,41 @@
|
||||
type Foldable {
|
||||
mkFoldable : ((* → * → *) → * → * → *) → Foldable;
|
||||
}
|
||||
|
||||
type Box {
|
||||
mkBox : * → Box;
|
||||
}
|
||||
|
||||
function lambda_16(integer, integer) : integer;
|
||||
function lambda_18((integer, integer) → integer, integer, Box) : integer;
|
||||
function foldableBoxintegerI() : Foldable;
|
||||
function go_17(integer) : integer;
|
||||
function main() : integer;
|
||||
|
||||
function lambda_16(_X : integer, _X' : integer) : integer {
|
||||
_X'
|
||||
}
|
||||
|
||||
function lambda_18(f : (integer, integer) → integer, ini : integer, _X : Box) : integer {
|
||||
case[Box](_X) {
|
||||
mkBox: save {
|
||||
call[go_17](tmp[0].mkBox[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function foldableBoxintegerI() : Foldable {
|
||||
alloc[mkFoldable](calloc[lambda_18]())
|
||||
}
|
||||
|
||||
function go_17(x' : integer) : integer {
|
||||
x'
|
||||
}
|
||||
|
||||
function main() : integer {
|
||||
case[Foldable](call[foldableBoxintegerI]()) {
|
||||
mkFoldable: save {
|
||||
ccall(tmp[0].mkFoldable[0], calloc[lambda_16](), 0, alloc[mkBox](0))
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user