From acea6615a481fa74be421e7d08e17f04efd5bfda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Czajka?= <62751+lukaszcz@users.noreply.github.com> Date: Wed, 25 Jan 2023 18:57:47 +0100 Subject: [PATCH] Lazy boolean operators (#1743) Closes #1701 --- c-runtime/builtins/bool.h | 14 +-- juvix-stdlib | 2 +- .../Abstract/Translation/FromConcrete.hs | 2 + .../Compiler/Backend/C/Data/BuiltinTable.hs | 2 + src/Juvix/Compiler/Backend/C/Data/CNames.hs | 6 + src/Juvix/Compiler/Builtins/Bool.hs | 75 ++++++++--- src/Juvix/Compiler/Builtins/Effect.hs | 36 ++++++ src/Juvix/Compiler/Builtins/Nat.hs | 117 ++++++++++++------ src/Juvix/Compiler/Concrete/Data/Builtins.hs | 4 + .../Compiler/Core/Translation/FromInternal.hs | 12 ++ src/Juvix/Extra/Strings.hs | 8 +- test/Compilation/Positive.hs | 2 +- tests/Compilation/positive/out/test006.out | 2 + tests/Compilation/positive/test006.juvix | 7 +- tests/Internal/positive/BuiltinIf.juvix | 2 +- 15 files changed, 223 insertions(+), 68 deletions(-) diff --git a/c-runtime/builtins/bool.h b/c-runtime/builtins/bool.h index b5f00466a..29c15ece4 100644 --- a/c-runtime/builtins/bool.h +++ b/c-runtime/builtins/bool.h @@ -8,14 +8,14 @@ typedef bool prim_bool; #define prim_true true #define prim_false false -bool is_prim_true(prim_bool b) { - return b == prim_true; -} +bool is_prim_true(prim_bool b) { return b == prim_true; } -bool is_prim_false(prim_bool b) { - return b == prim_false; -} +bool is_prim_false(prim_bool b) { return b == prim_false; } #define prim_if(b, ifThen, ifElse) (b ? ifThen : ifElse) -#endif // BOOL_H_ +#define prim_or(a, b) ((a) || (b)) + +#define prim_and(a, b) ((a) && (b)) + +#endif // BOOL_H_ diff --git a/juvix-stdlib b/juvix-stdlib index 7e54415ff..d494ecbf9 160000 --- a/juvix-stdlib +++ b/juvix-stdlib @@ -1 +1 @@ -Subproject commit 7e54415ffa11a89f36c2110ed8923284a53b2ce0 +Subproject commit d494ecbf9bcf87a0ab7a1f2deb8d72d2d5fae7dd diff --git a/src/Juvix/Compiler/Abstract/Translation/FromConcrete.hs b/src/Juvix/Compiler/Abstract/Translation/FromConcrete.hs index 0ad837154..92a9381d8 100644 --- a/src/Juvix/Compiler/Abstract/Translation/FromConcrete.hs +++ b/src/Juvix/Compiler/Abstract/Translation/FromConcrete.hs @@ -257,6 +257,8 @@ registerBuiltinFunction d = \case BuiltinNatLt -> registerNatLt d BuiltinNatEq -> registerNatEq d BuiltinBoolIf -> registerIf d + BuiltinBoolOr -> registerOr d + BuiltinBoolAnd -> registerAnd d registerBuiltinAxiom :: (Members '[InfoTableBuilder, Error ScoperError, Builtins] r) => diff --git a/src/Juvix/Compiler/Backend/C/Data/BuiltinTable.hs b/src/Juvix/Compiler/Backend/C/Data/BuiltinTable.hs index 457cb89b1..98fdb1979 100644 --- a/src/Juvix/Compiler/Backend/C/Data/BuiltinTable.hs +++ b/src/Juvix/Compiler/Backend/C/Data/BuiltinTable.hs @@ -37,6 +37,8 @@ builtinFunctionName = \case BuiltinNatLt -> Just natlt BuiltinNatEq -> Just nateq BuiltinBoolIf -> Just boolif + BuiltinBoolOr -> Just boolor + BuiltinBoolAnd -> Just booland builtinName :: BuiltinPrim -> Maybe Text builtinName = \case diff --git a/src/Juvix/Compiler/Backend/C/Data/CNames.hs b/src/Juvix/Compiler/Backend/C/Data/CNames.hs index c6196d841..da99577fb 100644 --- a/src/Juvix/Compiler/Backend/C/Data/CNames.hs +++ b/src/Juvix/Compiler/Backend/C/Data/CNames.hs @@ -74,6 +74,12 @@ nateq = primPrefix "nateq" boolif :: Text boolif = primPrefix "if" +boolor :: Text +boolor = primPrefix "or" + +booland :: Text +booland = primPrefix "and" + funField :: Text funField = "fun" diff --git a/src/Juvix/Compiler/Builtins/Bool.hs b/src/Juvix/Compiler/Builtins/Bool.hs index 171c1028b..625141190 100644 --- a/src/Juvix/Compiler/Builtins/Bool.hs +++ b/src/Juvix/Compiler/Builtins/Bool.hs @@ -1,6 +1,5 @@ module Juvix.Compiler.Builtins.Bool where -import Data.HashSet qualified as HashSet import Juvix.Compiler.Abstract.Extra import Juvix.Compiler.Abstract.Pretty import Juvix.Compiler.Builtins.Effect @@ -36,34 +35,74 @@ registerIf f = do bool_ <- getBuiltinName (getLoc f) BuiltinBool true_ <- toExpression <$> getBuiltinName (getLoc f) BuiltinBoolTrue false_ <- toExpression <$> getBuiltinName (getLoc f) BuiltinBoolFalse - vart <- freshVar "t" let if_ = f ^. funDefName - ty = f ^. funDefTypeSig - freeTVars = HashSet.fromList [vart] u = ExpressionUniverse (Universe {_universeLevel = Nothing, _universeLoc = error "Universe with no location"}) - unless (((u <>--> bool_ --> vart --> vart --> vart) ==% ty) freeTVars) (error "Bool if has the wrong type signature") - registerBuiltin BuiltinBoolIf if_ + vart <- freshVar "t" vare <- freshVar "e" hole <- freshHole let e = toExpression vare - freeVars = HashSet.fromList [vare] - (=%) :: (IsExpression a, IsExpression b) => a -> b -> Bool - a =% b = (a ==% b) freeVars exClauses :: [(Expression, Expression)] exClauses = [ (if_ @@ true_ @@ e @@ hole, e), (if_ @@ false_ @@ hole @@ e, e) ] - clauses :: [(Expression, Expression)] - clauses = - [ (clauseLhsAsExpression c, c ^. clauseBody) - | c <- toList (f ^. funDefClauses) + registerFun + FunInfo + { _funInfoDef = f, + _funInfoBuiltin = BuiltinBoolIf, + _funInfoSignature = u <>--> bool_ --> vart --> vart --> vart, + _funInfoClauses = exClauses, + _funInfoFreeVars = [vare], + _funInfoFreeTypeVars = [vart] + } + +registerOr :: Members '[Builtins, NameIdGen] r => FunctionDef -> Sem r () +registerOr f = do + bool_ <- getBuiltinName (getLoc f) BuiltinBool + true_ <- toExpression <$> getBuiltinName (getLoc f) BuiltinBoolTrue + false_ <- toExpression <$> getBuiltinName (getLoc f) BuiltinBoolFalse + let or_ = f ^. funDefName + vare <- freshVar "e" + hole <- freshHole + let e = toExpression vare + exClauses :: [(Expression, Expression)] + exClauses = + [ (or_ @@ true_ @@ hole, true_), + (or_ @@ false_ @@ e, e) ] - case zipExactMay exClauses clauses of - Nothing -> error "Bool if has the wrong number of clauses" - Just z -> forM_ z $ \((exLhs, exBody), (lhs, body)) -> do - unless (exLhs =% lhs) (error "clause lhs does not match") - unless (exBody =% body) (error $ "clause body does not match " <> ppTrace exBody <> " | " <> ppTrace body) + registerFun + FunInfo + { _funInfoDef = f, + _funInfoBuiltin = BuiltinBoolOr, + _funInfoSignature = bool_ --> bool_ --> bool_, + _funInfoClauses = exClauses, + _funInfoFreeVars = [vare], + _funInfoFreeTypeVars = [] + } + +registerAnd :: Members '[Builtins, NameIdGen] r => FunctionDef -> Sem r () +registerAnd f = do + bool_ <- getBuiltinName (getLoc f) BuiltinBool + true_ <- toExpression <$> getBuiltinName (getLoc f) BuiltinBoolTrue + false_ <- toExpression <$> getBuiltinName (getLoc f) BuiltinBoolFalse + let and_ = f ^. funDefName + vare <- freshVar "e" + hole <- freshHole + let e = toExpression vare + exClauses :: [(Expression, Expression)] + exClauses = + [ (and_ @@ true_ @@ e, e), + (and_ @@ false_ @@ hole, false_) + ] + registerFun + FunInfo + { _funInfoDef = f, + _funInfoBuiltin = BuiltinBoolAnd, + _funInfoSignature = bool_ --> bool_ --> bool_, + _funInfoClauses = exClauses, + _funInfoFreeVars = [vare], + _funInfoFreeTypeVars = [] + } registerBoolPrint :: (Members '[Builtins] r) => AxiomDef -> Sem r () registerBoolPrint f = do diff --git a/src/Juvix/Compiler/Builtins/Effect.hs b/src/Juvix/Compiler/Builtins/Effect.hs index 4b209b207..cfed18739 100644 --- a/src/Juvix/Compiler/Builtins/Effect.hs +++ b/src/Juvix/Compiler/Builtins/Effect.hs @@ -3,7 +3,9 @@ module Juvix.Compiler.Builtins.Effect ) where +import Data.HashSet qualified as HashSet import Juvix.Compiler.Abstract.Extra +import Juvix.Compiler.Abstract.Pretty import Juvix.Compiler.Builtins.Error import Juvix.Prelude @@ -61,3 +63,37 @@ re = reinterpret $ \case runBuiltins :: (Member (Error JuvixError) r) => BuiltinsState -> Sem (Builtins ': r) a -> Sem r (BuiltinsState, a) runBuiltins s = runState s . re + +data FunInfo = FunInfo + { _funInfoDef :: FunctionDef, + _funInfoBuiltin :: BuiltinFunction, + _funInfoSignature :: Expression, + _funInfoClauses :: [(Expression, Expression)], + _funInfoFreeVars :: [VarName], + _funInfoFreeTypeVars :: [VarName] + } + +makeLenses ''FunInfo + +registerFun :: + Members '[Builtins, NameIdGen] r => + FunInfo -> + Sem r () +registerFun fi = do + let op = fi ^. funInfoDef . funDefName + ty = fi ^. funInfoDef . funDefTypeSig + sig = fi ^. funInfoSignature + unless ((sig ==% ty) (HashSet.fromList (fi ^. funInfoFreeTypeVars))) (error "builtin has the wrong type signature") + registerBuiltin (fi ^. funInfoBuiltin) op + let freeVars = HashSet.fromList (fi ^. funInfoFreeVars) + a =% b = (a ==% b) freeVars + clauses :: [(Expression, Expression)] + clauses = + [ (clauseLhsAsExpression c, c ^. clauseBody) + | c <- toList (fi ^. funInfoDef . funDefClauses) + ] + case zipExactMay (fi ^. funInfoClauses) clauses of + Nothing -> error "builtin has the wrong number of clauses" + Just z -> forM_ z $ \((exLhs, exBody), (lhs, body)) -> do + unless (exLhs =% lhs) (error "clause lhs does not match") + unless (exBody =% body) (error $ "clause body does not match " <> ppTrace exBody <> " | " <> ppTrace body) diff --git a/src/Juvix/Compiler/Builtins/Nat.hs b/src/Juvix/Compiler/Builtins/Nat.hs index b58e62551..9cde3f121 100644 --- a/src/Juvix/Compiler/Builtins/Nat.hs +++ b/src/Juvix/Compiler/Builtins/Nat.hs @@ -1,6 +1,5 @@ module Juvix.Compiler.Builtins.Nat where -import Data.HashSet qualified as HashSet import Juvix.Compiler.Abstract.Extra import Juvix.Compiler.Abstract.Pretty import Juvix.Compiler.Builtins.Effect @@ -39,32 +38,6 @@ registerNatPrint f = do unless (f ^. axiomType === (nat --> io)) (error "Nat print has the wrong type signature") registerBuiltin BuiltinNatPrint (f ^. axiomName) -registerNatFun :: - (Members '[Builtins, NameIdGen] r) => - FunctionDef -> - BuiltinFunction -> - Expression -> - [(Expression, Expression)] -> - [VarName] -> - Sem r () -registerNatFun f blt sig exClauses fvs = do - let op = f ^. funDefName - ty = f ^. funDefTypeSig - unless (ty === sig) (error "builtin has the wrong type signature") - registerBuiltin blt op - let freeVars = HashSet.fromList fvs - a =% b = (a ==% b) freeVars - clauses :: [(Expression, Expression)] - clauses = - [ (clauseLhsAsExpression c, c ^. clauseBody) - | c <- toList (f ^. funDefClauses) - ] - case zipExactMay exClauses clauses of - Nothing -> error "builtin has the wrong number of clauses" - Just z -> forM_ z $ \((exLhs, exBody), (lhs, body)) -> do - unless (exLhs =% lhs) (error "clause lhs does not match") - unless (exBody =% body) (error $ "clause body does not match " <> ppTrace exBody <> " | " <> ppTrace body) - registerNatPlus :: (Members '[Builtins, NameIdGen] r) => FunctionDef -> Sem r () registerNatPlus f = do nat <- getBuiltinName (getLoc f) BuiltinNat @@ -82,7 +55,15 @@ registerNatPlus f = do [ (zero .+. m, m), ((suc @@ n) .+. m, suc @@ (n .+. m)) ] - registerNatFun f BuiltinNatPlus (nat --> nat --> nat) exClauses [varn, varm] + registerFun + FunInfo + { _funInfoDef = f, + _funInfoBuiltin = BuiltinNatPlus, + _funInfoSignature = nat --> nat --> nat, + _funInfoClauses = exClauses, + _funInfoFreeVars = [varn, varm], + _funInfoFreeTypeVars = [] + } registerNatMul :: (Members '[Builtins, NameIdGen] r) => FunctionDef -> Sem r () registerNatMul f = do @@ -103,7 +84,15 @@ registerNatMul f = do [ (zero .*. h, zero), ((suc @@ n) .*. m, plus @@ m @@ (n .*. m)) ] - registerNatFun f BuiltinNatMul (nat --> nat --> nat) exClauses [varn, varm] + registerFun + FunInfo + { _funInfoDef = f, + _funInfoBuiltin = BuiltinNatMul, + _funInfoSignature = nat --> nat --> nat, + _funInfoClauses = exClauses, + _funInfoFreeVars = [varn, varm], + _funInfoFreeTypeVars = [] + } registerNatSub :: (Members '[Builtins, NameIdGen] r) => FunctionDef -> Sem r () registerNatSub f = do @@ -124,7 +113,15 @@ registerNatSub f = do (n .-. zero, n), ((suc @@ n) .-. (suc @@ m), n .-. m) ] - registerNatFun f BuiltinNatSub (nat --> nat --> nat) exClauses [varn, varm] + registerFun + FunInfo + { _funInfoDef = f, + _funInfoBuiltin = BuiltinNatSub, + _funInfoSignature = nat --> nat --> nat, + _funInfoClauses = exClauses, + _funInfoFreeVars = [varn, varm], + _funInfoFreeTypeVars = [] + } registerNatUDiv :: (Members '[Builtins, NameIdGen] r) => FunctionDef -> Sem r () registerNatUDiv f = do @@ -145,7 +142,15 @@ registerNatUDiv f = do [ (zero ./. h, zero), (n ./. m, suc @@ ((sub @@ n @@ m) ./. m)) ] - registerNatFun f BuiltinNatUDiv (nat --> nat --> nat) exClauses [varn, varm] + registerFun + FunInfo + { _funInfoDef = f, + _funInfoBuiltin = BuiltinNatUDiv, + _funInfoSignature = nat --> nat --> nat, + _funInfoClauses = exClauses, + _funInfoFreeVars = [varn, varm], + _funInfoFreeTypeVars = [] + } registerNatDiv :: (Members '[Builtins, NameIdGen] r) => FunctionDef -> Sem r () registerNatDiv f = do @@ -164,7 +169,15 @@ registerNatDiv f = do exClauses = [ (n ./. m, udiv @@ (sub @@ (suc @@ n) @@ m) @@ m) ] - registerNatFun f BuiltinNatDiv (nat --> nat --> nat) exClauses [varn, varm] + registerFun + FunInfo + { _funInfoDef = f, + _funInfoBuiltin = BuiltinNatDiv, + _funInfoSignature = nat --> nat --> nat, + _funInfoClauses = exClauses, + _funInfoFreeVars = [varn, varm], + _funInfoFreeTypeVars = [] + } registerNatMod :: (Members '[Builtins, NameIdGen] r) => FunctionDef -> Sem r () registerNatMod f = do @@ -180,7 +193,15 @@ registerNatMod f = do exClauses = [ (modop @@ n @@ m, sub @@ n @@ (mul @@ (divop @@ n @@ m) @@ m)) ] - registerNatFun f BuiltinNatMod (nat --> nat --> nat) exClauses [varn, varm] + registerFun + FunInfo + { _funInfoDef = f, + _funInfoBuiltin = BuiltinNatMod, + _funInfoSignature = nat --> nat --> nat, + _funInfoClauses = exClauses, + _funInfoFreeVars = [varn, varm], + _funInfoFreeTypeVars = [] + } registerNatLe :: (Members '[Builtins, NameIdGen] r) => FunctionDef -> Sem r () registerNatLe f = do @@ -204,7 +225,15 @@ registerNatLe f = do (h .<=. zero, false), ((suc @@ n) .<=. (suc @@ m), n .<=. m) ] - registerNatFun f BuiltinNatLe (nat --> nat --> tybool) exClauses [varn, varm] + registerFun + FunInfo + { _funInfoDef = f, + _funInfoBuiltin = BuiltinNatLe, + _funInfoSignature = nat --> nat --> tybool, + _funInfoClauses = exClauses, + _funInfoFreeVars = [varn, varm], + _funInfoFreeTypeVars = [] + } registerNatLt :: (Members '[Builtins, NameIdGen] r) => FunctionDef -> Sem r () registerNatLt f = do @@ -221,7 +250,15 @@ registerNatLt f = do exClauses = [ (lt @@ n @@ m, le @@ (suc @@ n) @@ m) ] - registerNatFun f BuiltinNatLt (nat --> nat --> tybool) exClauses [varn, varm] + registerFun + FunInfo + { _funInfoDef = f, + _funInfoBuiltin = BuiltinNatLt, + _funInfoSignature = nat --> nat --> tybool, + _funInfoClauses = exClauses, + _funInfoFreeVars = [varn, varm], + _funInfoFreeTypeVars = [] + } registerNatEq :: (Members '[Builtins, NameIdGen] r) => FunctionDef -> Sem r () registerNatEq f = do @@ -246,4 +283,12 @@ registerNatEq f = do (h .==. zero, false), ((suc @@ n) .==. (suc @@ m), n .==. m) ] - registerNatFun f BuiltinNatEq (nat --> nat --> tybool) exClauses [varn, varm] + registerFun + FunInfo + { _funInfoDef = f, + _funInfoBuiltin = BuiltinNatEq, + _funInfoSignature = nat --> nat --> tybool, + _funInfoClauses = exClauses, + _funInfoFreeVars = [varn, varm], + _funInfoFreeTypeVars = [] + } diff --git a/src/Juvix/Compiler/Concrete/Data/Builtins.hs b/src/Juvix/Compiler/Concrete/Data/Builtins.hs index 6f9ee2837..aac1e08d0 100644 --- a/src/Juvix/Compiler/Concrete/Data/Builtins.hs +++ b/src/Juvix/Compiler/Concrete/Data/Builtins.hs @@ -72,6 +72,8 @@ data BuiltinFunction | BuiltinNatLt | BuiltinNatEq | BuiltinBoolIf + | BuiltinBoolOr + | BuiltinBoolAnd deriving stock (Show, Eq, Ord, Enum, Bounded, Generic, Data) instance Hashable BuiltinFunction @@ -88,6 +90,8 @@ instance Pretty BuiltinFunction where BuiltinNatLt -> Str.natLt BuiltinNatEq -> Str.natEq BuiltinBoolIf -> Str.boolIf + BuiltinBoolOr -> Str.boolOr + BuiltinBoolAnd -> Str.boolAnd data BuiltinAxiom = BuiltinNatPrint diff --git a/src/Juvix/Compiler/Core/Translation/FromInternal.hs b/src/Juvix/Compiler/Core/Translation/FromInternal.hs index f028a864e..097b6d210 100644 --- a/src/Juvix/Compiler/Core/Translation/FromInternal.hs +++ b/src/Juvix/Compiler/Core/Translation/FromInternal.hs @@ -661,6 +661,18 @@ goApplication a = do case as of (_ : v : b1 : b2 : xs) -> return (mkApps' (mkIf' sym v b1 b2) xs) _ -> error "if must be called with 3 arguments" + Just Internal.BuiltinBoolOr -> do + sym <- getBoolSymbol + as <- exprArgs + case as of + (x : y : xs) -> return (mkApps' (mkIf' sym x (mkConstr' (BuiltinTag TagTrue) []) y) xs) + _ -> error "|| must be called with 2 arguments" + Just Internal.BuiltinBoolAnd -> do + sym <- getBoolSymbol + as <- exprArgs + case as of + (x : y : xs) -> return (mkApps' (mkIf' sym x y (mkConstr' (BuiltinTag TagFalse) [])) xs) + _ -> error "&& must be called with 2 arguments" _ -> app _ -> app diff --git a/src/Juvix/Extra/Strings.hs b/src/Juvix/Extra/Strings.hs index cbf24c087..b9fe7255c 100644 --- a/src/Juvix/Extra/Strings.hs +++ b/src/Juvix/Extra/Strings.hs @@ -143,7 +143,13 @@ natEq = "nat-eq" boolIf :: (IsString s) => s boolIf = "bool-if" -builtin :: (IsString s) => s +boolOr :: IsString s => s +boolOr = "bool-or" + +boolAnd :: IsString s => s +boolAnd = "bool-and" + +builtin :: IsString s => s builtin = "builtin" type_ :: (IsString s) => s diff --git a/test/Compilation/Positive.hs b/test/Compilation/Positive.hs index 6369d3600..e8266f775 100644 --- a/test/Compilation/Positive.hs +++ b/test/Compilation/Positive.hs @@ -80,7 +80,7 @@ tests = $(mkRelFile "test005.juvix") $(mkRelFile "out/test005.out"), posTest - "If-then-else" + "If-then-else and lazy boolean operators" $(mkRelDir ".") $(mkRelFile "test006.juvix") $(mkRelFile "out/test006.out"), diff --git a/tests/Compilation/positive/out/test006.out b/tests/Compilation/positive/out/test006.out index 0cfbf0888..87f0e6408 100644 --- a/tests/Compilation/positive/out/test006.out +++ b/tests/Compilation/positive/out/test006.out @@ -1 +1,3 @@ 2 +true +false diff --git a/tests/Compilation/positive/test006.juvix b/tests/Compilation/positive/test006.juvix index 7f94ac288..a5ded90bd 100644 --- a/tests/Compilation/positive/test006.juvix +++ b/tests/Compilation/positive/test006.juvix @@ -1,4 +1,4 @@ --- if then else +-- if-then-else and lazy boolean operators module test006; open import Stdlib.Prelude; @@ -9,7 +9,8 @@ loop : Nat; loop := loop; main : IO; -main := printNatLn $ (if (3 > 0) 1 loop) + (if (2 < 1) loop (if (7 >= 8) loop 1)); - +main := printNatLn ((if (3 > 0) 1 loop) + (if (2 < 1) loop (if (7 >= 8) loop 1))) >> + printBoolLn (2 > 0 || loop == 0) >> + printBoolLn (2 < 0 && loop == 0); end; diff --git a/tests/Internal/positive/BuiltinIf.juvix b/tests/Internal/positive/BuiltinIf.juvix index ac3878481..e9afc0c10 100644 --- a/tests/Internal/positive/BuiltinIf.juvix +++ b/tests/Internal/positive/BuiltinIf.juvix @@ -3,6 +3,6 @@ module BuiltinIf; open import Stdlib.Prelude; main : Bool; -main := if false ((&&) false) ((&&) true) true; +main := if false (and false) (and true) true; end;