1
1
mirror of https://github.com/anoma/juvix.git synced 2025-01-05 22:46:08 +03:00

Fix JuvixTree type unification (#2972)

* Closes #2954 
* The problem was that the type validation algorithm was too strict for
higher-order functions with a dynamic (unknown) target.
This commit is contained in:
Łukasz Czajka 2024-08-27 10:31:14 +02:00 committed by GitHub
parent 9c980d152a
commit eb5b2e4595
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 174 additions and 123 deletions

View File

@ -150,7 +150,7 @@ checkValueStack' loc tab tys mem = do
mapM_
( \(ty, idx) -> do
let ty' = fromJust $ topValueStack idx mem
unless (isSubtype' ty' ty) $
unless (isSubtype ty' ty) $
throw $
AsmError loc $
"type mismatch on value stack cell "

View File

@ -39,127 +39,127 @@ curryType ty = case typeArgs ty of
in foldr (\tyarg ty'' -> mkTypeFun [tyarg] ty'') (typeTarget ty') tyargs
isSubtype :: Type -> Type -> Bool
isSubtype ty1 ty2 = case (ty1, ty2) of
(TyDynamic, _) -> True
(_, TyDynamic) -> True
(TyConstr TypeConstr {..}, TyInductive TypeInductive {..}) ->
_typeConstrInductive == _typeInductiveSymbol
(TyConstr c1, TyConstr c2) ->
c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag
&& all (uncurry isSubtype) (zip (c1 ^. typeConstrFields) (c2 ^. typeConstrFields))
(TyFun t1, TyFun t2) ->
let l1 = toList (t1 ^. typeFunArgs)
l2 = toList (t2 ^. typeFunArgs)
r1 = t1 ^. typeFunTarget
r2 = t2 ^. typeFunTarget
in length l1 == length l2 && all (uncurry isSubtype) (zip l2 l1) && isSubtype r1 r2
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
checkBounds (>=) l1 l2 && checkBounds (<=) u1 u2
where
checkBounds :: (Integer -> Integer -> Bool) -> Maybe Integer -> Maybe Integer -> Bool
checkBounds _ Nothing Nothing = True
checkBounds _ Nothing (Just _) = False
checkBounds _ (Just _) Nothing = True
checkBounds cmp (Just x) (Just y) = cmp x y
(TyBool {}, TyBool {}) -> True
(TyString, TyString) -> True
(TyField, TyField) -> True
(TyByteArray, TyByteArray) -> True
(TyUnit, TyUnit) -> True
(TyVoid, TyVoid) -> True
(TyInductive {}, TyInductive {}) -> ty1 == ty2
(TyUnit, _) -> False
(_, TyUnit) -> False
(TyVoid, _) -> False
(_, TyVoid) -> False
(TyInteger {}, _) -> False
(_, TyInteger {}) -> False
(TyString, _) -> False
(_, TyString) -> False
(TyField, _) -> False
(_, TyField) -> False
(TyByteArray, _) -> False
(_, TyByteArray) -> False
(TyBool {}, _) -> False
(_, TyBool {}) -> False
(TyFun {}, _) -> False
(_, TyFun {}) -> False
(_, TyConstr {}) -> False
isSubtype' :: Type -> Type -> Bool
isSubtype' ty1 ty2
-- The guard is to ensure correct behaviour with dynamic type targets. E.g.
-- `A -> B -> C -> D` should be a subtype of `(A, B) -> *`.
| tgt1 == TyDynamic || tgt2 == TyDynamic =
isSubtype
(curryType ty1)
(curryType ty2)
where
tgt1 = typeTarget (uncurryType ty1)
tgt2 = typeTarget (uncurryType ty2)
isSubtype' ty1 ty2 =
isSubtype ty1 ty2
isSubtype ty1 ty2 =
let (ty1', ty2') =
if
| typeTarget (uncurryType ty1) == TyDynamic || typeTarget (uncurryType ty2) == TyDynamic ->
(curryType ty1, curryType ty2)
| otherwise ->
(ty1, ty2)
in case (ty1', ty2') of
(TyDynamic, _) -> True
(_, TyDynamic) -> True
(TyConstr TypeConstr {..}, TyInductive TypeInductive {..}) ->
_typeConstrInductive == _typeInductiveSymbol
(TyConstr c1, TyConstr c2) ->
c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag
&& all (uncurry isSubtype) (zip (c1 ^. typeConstrFields) (c2 ^. typeConstrFields))
(TyFun t1, TyFun t2) ->
let l1 = toList (t1 ^. typeFunArgs)
l2 = toList (t2 ^. typeFunArgs)
r1 = t1 ^. typeFunTarget
r2 = t2 ^. typeFunTarget
in length l1 == length l2 && all (uncurry isSubtype) (zip l2 l1) && isSubtype r1 r2
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
checkBounds (>=) l1 l2 && checkBounds (<=) u1 u2
where
checkBounds :: (Integer -> Integer -> Bool) -> Maybe Integer -> Maybe Integer -> Bool
checkBounds _ Nothing Nothing = True
checkBounds _ Nothing (Just _) = False
checkBounds _ (Just _) Nothing = True
checkBounds cmp (Just x) (Just y) = cmp x y
(TyBool {}, TyBool {}) -> True
(TyString, TyString) -> True
(TyField, TyField) -> True
(TyByteArray, TyByteArray) -> True
(TyUnit, TyUnit) -> True
(TyVoid, TyVoid) -> True
(TyInductive {}, TyInductive {}) -> ty1 == ty2
(TyUnit, _) -> False
(_, TyUnit) -> False
(TyVoid, _) -> False
(_, TyVoid) -> False
(TyInteger {}, _) -> False
(_, TyInteger {}) -> False
(TyString, _) -> False
(_, TyString) -> False
(TyField, _) -> False
(_, TyField) -> False
(TyByteArray, _) -> False
(_, TyByteArray) -> False
(TyBool {}, _) -> False
(_, TyBool {}) -> False
(TyFun {}, _) -> False
(_, TyFun {}) -> False
(_, TyConstr {}) -> False
unifyTypes :: forall t e r. (Members '[Error TreeError, Reader (Maybe Location), Reader (InfoTable' t e)] r) => Type -> Type -> Sem r Type
unifyTypes ty1 ty2 = case (ty1, ty2) of
(TyDynamic, x) -> return x
(x, TyDynamic) -> return x
(TyInductive TypeInductive {..}, TyConstr TypeConstr {..})
| _typeInductiveSymbol == _typeConstrInductive ->
return ty1
(TyConstr {}, TyInductive {}) -> unifyTypes @t @e ty2 ty1
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag -> do
flds <- zipWithM (unifyTypes @t @e) (c1 ^. typeConstrFields) (c2 ^. typeConstrFields)
return $ TyConstr (set typeConstrFields flds c1)
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive ->
return $ TyInductive (TypeInductive (c1 ^. typeConstrInductive))
(TyFun t1, TyFun t2)
| length (t1 ^. typeFunArgs) == length (t2 ^. typeFunArgs) -> do
let args1 = toList (t1 ^. typeFunArgs)
args2 = toList (t2 ^. typeFunArgs)
tgt1 = t1 ^. typeFunTarget
tgt2 = t2 ^. typeFunTarget
args <- zipWithM (unifyTypes @t @e) args1 args2
tgt <- unifyTypes @t @e tgt1 tgt2
return $ TyFun (TypeFun (nonEmpty' args) tgt)
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
return $ TyInteger (TypeInteger (unifyBounds min l1 l2) (unifyBounds max u1 u2))
where
unifyBounds :: (Integer -> Integer -> Integer) -> Maybe Integer -> Maybe Integer -> Maybe Integer
unifyBounds _ Nothing _ = Nothing
unifyBounds _ _ Nothing = Nothing
unifyBounds f (Just x) (Just y) = Just (f x y)
(TyBool {}, TyBool {})
| ty1 == ty2 -> return ty1
(TyString, TyString) -> return TyString
(TyField, TyField) -> return TyField
(TyByteArray, TyByteArray) -> return TyByteArray
(TyUnit, TyUnit) -> return TyUnit
(TyVoid, TyVoid) -> return TyVoid
(TyInductive {}, TyInductive {})
| ty1 == ty2 -> return ty1
(TyUnit, _) -> err
(_, TyUnit) -> err
(TyVoid, _) -> err
(_, TyVoid) -> err
(TyInteger {}, _) -> err
(_, TyInteger {}) -> err
(TyString, _) -> err
(_, TyString) -> err
(TyField, _) -> err
(_, TyField) -> err
(TyByteArray, _) -> err
(_, TyByteArray) -> err
(TyBool {}, _) -> err
(_, TyBool {}) -> err
(TyFun {}, _) -> err
(_, TyFun {}) -> err
(TyInductive {}, _) -> err
(_, TyConstr {}) -> err
unifyTypes ty1 ty2 =
let (ty1', ty2') =
if
| typeTarget (uncurryType ty1) == TyDynamic || typeTarget (uncurryType ty2) == TyDynamic ->
(curryType ty1, curryType ty2)
| otherwise ->
(ty1, ty2)
in case (ty1', ty2') of
(TyDynamic, x) -> return x
(x, TyDynamic) -> return x
(TyInductive TypeInductive {..}, TyConstr TypeConstr {..})
| _typeInductiveSymbol == _typeConstrInductive ->
return ty1
(TyConstr {}, TyInductive {}) -> unifyTypes @t @e ty2 ty1
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag -> do
flds <- zipWithM (unifyTypes @t @e) (c1 ^. typeConstrFields) (c2 ^. typeConstrFields)
return $ TyConstr (set typeConstrFields flds c1)
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive ->
return $ TyInductive (TypeInductive (c1 ^. typeConstrInductive))
(TyFun t1, TyFun t2)
| length (t1 ^. typeFunArgs) == length (t2 ^. typeFunArgs) -> do
let args1 = toList (t1 ^. typeFunArgs)
args2 = toList (t2 ^. typeFunArgs)
tgt1 = t1 ^. typeFunTarget
tgt2 = t2 ^. typeFunTarget
args <- zipWithM (unifyTypes @t @e) args1 args2
tgt <- unifyTypes @t @e tgt1 tgt2
return $ TyFun (TypeFun (nonEmpty' args) tgt)
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
return $ TyInteger (TypeInteger (unifyBounds min l1 l2) (unifyBounds max u1 u2))
where
unifyBounds :: (Integer -> Integer -> Integer) -> Maybe Integer -> Maybe Integer -> Maybe Integer
unifyBounds _ Nothing _ = Nothing
unifyBounds _ _ Nothing = Nothing
unifyBounds f (Just x) (Just y) = Just (f x y)
(TyBool {}, TyBool {})
| ty1 == ty2 -> return ty1
(TyString, TyString) -> return TyString
(TyField, TyField) -> return TyField
(TyByteArray, TyByteArray) -> return TyByteArray
(TyUnit, TyUnit) -> return TyUnit
(TyVoid, TyVoid) -> return TyVoid
(TyInductive {}, TyInductive {})
| ty1 == ty2 -> return ty1
(TyUnit, _) -> err
(_, TyUnit) -> err
(TyVoid, _) -> err
(_, TyVoid) -> err
(TyInteger {}, _) -> err
(_, TyInteger {}) -> err
(TyString, _) -> err
(_, TyString) -> err
(TyField, _) -> err
(_, TyField) -> err
(TyByteArray, _) -> err
(_, TyByteArray) -> err
(TyBool {}, _) -> err
(_, TyBool {}) -> err
(TyFun {}, _) -> err
(_, TyFun {}) -> err
(TyInductive {}, _) -> err
(_, TyConstr {}) -> err
where
err :: Sem r a
err = do

View File

@ -3,6 +3,7 @@ module Tree.Asm.Base where
import Asm.Run.Base qualified as Asm
import Base
import Juvix.Compiler.Asm.Translation.FromTree qualified as Asm
import Juvix.Compiler.Tree.Pipeline qualified as Tree
import Juvix.Compiler.Tree.Translation.FromSource
import Juvix.Data.PPOutput
@ -18,5 +19,8 @@ treeAsmAssertion mainFile expectedFile step = do
Left err -> assertFailure (prettyString err)
Right tabIni -> do
step "Translate"
let tab = Asm.fromTree tabIni
Asm.asmRunAssertion' tab expectedFile step
case run $ runError @JuvixError $ Tree.toAsm tabIni of
Left err -> assertFailure (prettyString (fromJuvixError @GenericError err))
Right tab -> do
let tab' = Asm.fromTree tab
Asm.asmRunAssertion' tab' expectedFile step

View File

@ -239,5 +239,10 @@ tests =
"Test040: ByteArray"
$(mkRelDir ".")
$(mkRelFile "test040.jvt")
$(mkRelFile "out/test040.out")
$(mkRelFile "out/test040.out"),
PosTest
"Test041: Type unification"
$(mkRelDir ".")
$(mkRelFile "test041.jvt")
$(mkRelFile "out/test041.out")
]

View File

@ -0,0 +1 @@
0

View File

@ -0,0 +1,41 @@
type Foldable {
mkFoldable : ((* → * → *) → * → * → *) → Foldable;
}
type Box {
mkBox : * → Box;
}
function lambda_16(integer, integer) : integer;
function lambda_18((integer, integer) → integer, integer, Box) : integer;
function foldableBoxintegerI() : Foldable;
function go_17(integer) : integer;
function main() : integer;
function lambda_16(_X : integer, _X' : integer) : integer {
_X'
}
function lambda_18(f : (integer, integer) → integer, ini : integer, _X : Box) : integer {
case[Box](_X) {
mkBox: save {
call[go_17](tmp[0].mkBox[0])
}
}
}
function foldableBoxintegerI() : Foldable {
alloc[mkFoldable](calloc[lambda_18]())
}
function go_17(x' : integer) : integer {
x'
}
function main() : integer {
case[Foldable](call[foldableBoxintegerI]()) {
mkFoldable: save {
ccall(tmp[0].mkFoldable[0], calloc[lambda_16](), 0, alloc[mkBox](0))
}
}
}