1
1
mirror of https://github.com/anoma/juvix.git synced 2025-01-05 22:46:08 +03:00

Lazy boolean operators (#1743)

Closes #1701
This commit is contained in:
Łukasz Czajka 2023-01-25 18:57:47 +01:00 committed by GitHub
parent cd2af04601
commit acea6615a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 223 additions and 68 deletions

View File

@ -8,14 +8,14 @@ typedef bool prim_bool;
#define prim_true true
#define prim_false false
bool is_prim_true(prim_bool b) {
return b == prim_true;
}
bool is_prim_true(prim_bool b) { return b == prim_true; }
bool is_prim_false(prim_bool b) {
return b == prim_false;
}
bool is_prim_false(prim_bool b) { return b == prim_false; }
#define prim_if(b, ifThen, ifElse) (b ? ifThen : ifElse)
#endif // BOOL_H_
#define prim_or(a, b) ((a) || (b))
#define prim_and(a, b) ((a) && (b))
#endif // BOOL_H_

@ -1 +1 @@
Subproject commit 7e54415ffa11a89f36c2110ed8923284a53b2ce0
Subproject commit d494ecbf9bcf87a0ab7a1f2deb8d72d2d5fae7dd

View File

@ -257,6 +257,8 @@ registerBuiltinFunction d = \case
BuiltinNatLt -> registerNatLt d
BuiltinNatEq -> registerNatEq d
BuiltinBoolIf -> registerIf d
BuiltinBoolOr -> registerOr d
BuiltinBoolAnd -> registerAnd d
registerBuiltinAxiom ::
(Members '[InfoTableBuilder, Error ScoperError, Builtins] r) =>

View File

@ -37,6 +37,8 @@ builtinFunctionName = \case
BuiltinNatLt -> Just natlt
BuiltinNatEq -> Just nateq
BuiltinBoolIf -> Just boolif
BuiltinBoolOr -> Just boolor
BuiltinBoolAnd -> Just booland
builtinName :: BuiltinPrim -> Maybe Text
builtinName = \case

View File

@ -74,6 +74,12 @@ nateq = primPrefix "nateq"
boolif :: Text
boolif = primPrefix "if"
boolor :: Text
boolor = primPrefix "or"
booland :: Text
booland = primPrefix "and"
funField :: Text
funField = "fun"

View File

@ -1,6 +1,5 @@
module Juvix.Compiler.Builtins.Bool where
import Data.HashSet qualified as HashSet
import Juvix.Compiler.Abstract.Extra
import Juvix.Compiler.Abstract.Pretty
import Juvix.Compiler.Builtins.Effect
@ -36,34 +35,74 @@ registerIf f = do
bool_ <- getBuiltinName (getLoc f) BuiltinBool
true_ <- toExpression <$> getBuiltinName (getLoc f) BuiltinBoolTrue
false_ <- toExpression <$> getBuiltinName (getLoc f) BuiltinBoolFalse
vart <- freshVar "t"
let if_ = f ^. funDefName
ty = f ^. funDefTypeSig
freeTVars = HashSet.fromList [vart]
u = ExpressionUniverse (Universe {_universeLevel = Nothing, _universeLoc = error "Universe with no location"})
unless (((u <>--> bool_ --> vart --> vart --> vart) ==% ty) freeTVars) (error "Bool if has the wrong type signature")
registerBuiltin BuiltinBoolIf if_
vart <- freshVar "t"
vare <- freshVar "e"
hole <- freshHole
let e = toExpression vare
freeVars = HashSet.fromList [vare]
(=%) :: (IsExpression a, IsExpression b) => a -> b -> Bool
a =% b = (a ==% b) freeVars
exClauses :: [(Expression, Expression)]
exClauses =
[ (if_ @@ true_ @@ e @@ hole, e),
(if_ @@ false_ @@ hole @@ e, e)
]
clauses :: [(Expression, Expression)]
clauses =
[ (clauseLhsAsExpression c, c ^. clauseBody)
| c <- toList (f ^. funDefClauses)
registerFun
FunInfo
{ _funInfoDef = f,
_funInfoBuiltin = BuiltinBoolIf,
_funInfoSignature = u <>--> bool_ --> vart --> vart --> vart,
_funInfoClauses = exClauses,
_funInfoFreeVars = [vare],
_funInfoFreeTypeVars = [vart]
}
registerOr :: Members '[Builtins, NameIdGen] r => FunctionDef -> Sem r ()
registerOr f = do
bool_ <- getBuiltinName (getLoc f) BuiltinBool
true_ <- toExpression <$> getBuiltinName (getLoc f) BuiltinBoolTrue
false_ <- toExpression <$> getBuiltinName (getLoc f) BuiltinBoolFalse
let or_ = f ^. funDefName
vare <- freshVar "e"
hole <- freshHole
let e = toExpression vare
exClauses :: [(Expression, Expression)]
exClauses =
[ (or_ @@ true_ @@ hole, true_),
(or_ @@ false_ @@ e, e)
]
case zipExactMay exClauses clauses of
Nothing -> error "Bool if has the wrong number of clauses"
Just z -> forM_ z $ \((exLhs, exBody), (lhs, body)) -> do
unless (exLhs =% lhs) (error "clause lhs does not match")
unless (exBody =% body) (error $ "clause body does not match " <> ppTrace exBody <> " | " <> ppTrace body)
registerFun
FunInfo
{ _funInfoDef = f,
_funInfoBuiltin = BuiltinBoolOr,
_funInfoSignature = bool_ --> bool_ --> bool_,
_funInfoClauses = exClauses,
_funInfoFreeVars = [vare],
_funInfoFreeTypeVars = []
}
registerAnd :: Members '[Builtins, NameIdGen] r => FunctionDef -> Sem r ()
registerAnd f = do
bool_ <- getBuiltinName (getLoc f) BuiltinBool
true_ <- toExpression <$> getBuiltinName (getLoc f) BuiltinBoolTrue
false_ <- toExpression <$> getBuiltinName (getLoc f) BuiltinBoolFalse
let and_ = f ^. funDefName
vare <- freshVar "e"
hole <- freshHole
let e = toExpression vare
exClauses :: [(Expression, Expression)]
exClauses =
[ (and_ @@ true_ @@ e, e),
(and_ @@ false_ @@ hole, false_)
]
registerFun
FunInfo
{ _funInfoDef = f,
_funInfoBuiltin = BuiltinBoolAnd,
_funInfoSignature = bool_ --> bool_ --> bool_,
_funInfoClauses = exClauses,
_funInfoFreeVars = [vare],
_funInfoFreeTypeVars = []
}
registerBoolPrint :: (Members '[Builtins] r) => AxiomDef -> Sem r ()
registerBoolPrint f = do

View File

@ -3,7 +3,9 @@ module Juvix.Compiler.Builtins.Effect
)
where
import Data.HashSet qualified as HashSet
import Juvix.Compiler.Abstract.Extra
import Juvix.Compiler.Abstract.Pretty
import Juvix.Compiler.Builtins.Error
import Juvix.Prelude
@ -61,3 +63,37 @@ re = reinterpret $ \case
runBuiltins :: (Member (Error JuvixError) r) => BuiltinsState -> Sem (Builtins ': r) a -> Sem r (BuiltinsState, a)
runBuiltins s = runState s . re
data FunInfo = FunInfo
{ _funInfoDef :: FunctionDef,
_funInfoBuiltin :: BuiltinFunction,
_funInfoSignature :: Expression,
_funInfoClauses :: [(Expression, Expression)],
_funInfoFreeVars :: [VarName],
_funInfoFreeTypeVars :: [VarName]
}
makeLenses ''FunInfo
registerFun ::
Members '[Builtins, NameIdGen] r =>
FunInfo ->
Sem r ()
registerFun fi = do
let op = fi ^. funInfoDef . funDefName
ty = fi ^. funInfoDef . funDefTypeSig
sig = fi ^. funInfoSignature
unless ((sig ==% ty) (HashSet.fromList (fi ^. funInfoFreeTypeVars))) (error "builtin has the wrong type signature")
registerBuiltin (fi ^. funInfoBuiltin) op
let freeVars = HashSet.fromList (fi ^. funInfoFreeVars)
a =% b = (a ==% b) freeVars
clauses :: [(Expression, Expression)]
clauses =
[ (clauseLhsAsExpression c, c ^. clauseBody)
| c <- toList (fi ^. funInfoDef . funDefClauses)
]
case zipExactMay (fi ^. funInfoClauses) clauses of
Nothing -> error "builtin has the wrong number of clauses"
Just z -> forM_ z $ \((exLhs, exBody), (lhs, body)) -> do
unless (exLhs =% lhs) (error "clause lhs does not match")
unless (exBody =% body) (error $ "clause body does not match " <> ppTrace exBody <> " | " <> ppTrace body)

View File

@ -1,6 +1,5 @@
module Juvix.Compiler.Builtins.Nat where
import Data.HashSet qualified as HashSet
import Juvix.Compiler.Abstract.Extra
import Juvix.Compiler.Abstract.Pretty
import Juvix.Compiler.Builtins.Effect
@ -39,32 +38,6 @@ registerNatPrint f = do
unless (f ^. axiomType === (nat --> io)) (error "Nat print has the wrong type signature")
registerBuiltin BuiltinNatPrint (f ^. axiomName)
registerNatFun ::
(Members '[Builtins, NameIdGen] r) =>
FunctionDef ->
BuiltinFunction ->
Expression ->
[(Expression, Expression)] ->
[VarName] ->
Sem r ()
registerNatFun f blt sig exClauses fvs = do
let op = f ^. funDefName
ty = f ^. funDefTypeSig
unless (ty === sig) (error "builtin has the wrong type signature")
registerBuiltin blt op
let freeVars = HashSet.fromList fvs
a =% b = (a ==% b) freeVars
clauses :: [(Expression, Expression)]
clauses =
[ (clauseLhsAsExpression c, c ^. clauseBody)
| c <- toList (f ^. funDefClauses)
]
case zipExactMay exClauses clauses of
Nothing -> error "builtin has the wrong number of clauses"
Just z -> forM_ z $ \((exLhs, exBody), (lhs, body)) -> do
unless (exLhs =% lhs) (error "clause lhs does not match")
unless (exBody =% body) (error $ "clause body does not match " <> ppTrace exBody <> " | " <> ppTrace body)
registerNatPlus :: (Members '[Builtins, NameIdGen] r) => FunctionDef -> Sem r ()
registerNatPlus f = do
nat <- getBuiltinName (getLoc f) BuiltinNat
@ -82,7 +55,15 @@ registerNatPlus f = do
[ (zero .+. m, m),
((suc @@ n) .+. m, suc @@ (n .+. m))
]
registerNatFun f BuiltinNatPlus (nat --> nat --> nat) exClauses [varn, varm]
registerFun
FunInfo
{ _funInfoDef = f,
_funInfoBuiltin = BuiltinNatPlus,
_funInfoSignature = nat --> nat --> nat,
_funInfoClauses = exClauses,
_funInfoFreeVars = [varn, varm],
_funInfoFreeTypeVars = []
}
registerNatMul :: (Members '[Builtins, NameIdGen] r) => FunctionDef -> Sem r ()
registerNatMul f = do
@ -103,7 +84,15 @@ registerNatMul f = do
[ (zero .*. h, zero),
((suc @@ n) .*. m, plus @@ m @@ (n .*. m))
]
registerNatFun f BuiltinNatMul (nat --> nat --> nat) exClauses [varn, varm]
registerFun
FunInfo
{ _funInfoDef = f,
_funInfoBuiltin = BuiltinNatMul,
_funInfoSignature = nat --> nat --> nat,
_funInfoClauses = exClauses,
_funInfoFreeVars = [varn, varm],
_funInfoFreeTypeVars = []
}
registerNatSub :: (Members '[Builtins, NameIdGen] r) => FunctionDef -> Sem r ()
registerNatSub f = do
@ -124,7 +113,15 @@ registerNatSub f = do
(n .-. zero, n),
((suc @@ n) .-. (suc @@ m), n .-. m)
]
registerNatFun f BuiltinNatSub (nat --> nat --> nat) exClauses [varn, varm]
registerFun
FunInfo
{ _funInfoDef = f,
_funInfoBuiltin = BuiltinNatSub,
_funInfoSignature = nat --> nat --> nat,
_funInfoClauses = exClauses,
_funInfoFreeVars = [varn, varm],
_funInfoFreeTypeVars = []
}
registerNatUDiv :: (Members '[Builtins, NameIdGen] r) => FunctionDef -> Sem r ()
registerNatUDiv f = do
@ -145,7 +142,15 @@ registerNatUDiv f = do
[ (zero ./. h, zero),
(n ./. m, suc @@ ((sub @@ n @@ m) ./. m))
]
registerNatFun f BuiltinNatUDiv (nat --> nat --> nat) exClauses [varn, varm]
registerFun
FunInfo
{ _funInfoDef = f,
_funInfoBuiltin = BuiltinNatUDiv,
_funInfoSignature = nat --> nat --> nat,
_funInfoClauses = exClauses,
_funInfoFreeVars = [varn, varm],
_funInfoFreeTypeVars = []
}
registerNatDiv :: (Members '[Builtins, NameIdGen] r) => FunctionDef -> Sem r ()
registerNatDiv f = do
@ -164,7 +169,15 @@ registerNatDiv f = do
exClauses =
[ (n ./. m, udiv @@ (sub @@ (suc @@ n) @@ m) @@ m)
]
registerNatFun f BuiltinNatDiv (nat --> nat --> nat) exClauses [varn, varm]
registerFun
FunInfo
{ _funInfoDef = f,
_funInfoBuiltin = BuiltinNatDiv,
_funInfoSignature = nat --> nat --> nat,
_funInfoClauses = exClauses,
_funInfoFreeVars = [varn, varm],
_funInfoFreeTypeVars = []
}
registerNatMod :: (Members '[Builtins, NameIdGen] r) => FunctionDef -> Sem r ()
registerNatMod f = do
@ -180,7 +193,15 @@ registerNatMod f = do
exClauses =
[ (modop @@ n @@ m, sub @@ n @@ (mul @@ (divop @@ n @@ m) @@ m))
]
registerNatFun f BuiltinNatMod (nat --> nat --> nat) exClauses [varn, varm]
registerFun
FunInfo
{ _funInfoDef = f,
_funInfoBuiltin = BuiltinNatMod,
_funInfoSignature = nat --> nat --> nat,
_funInfoClauses = exClauses,
_funInfoFreeVars = [varn, varm],
_funInfoFreeTypeVars = []
}
registerNatLe :: (Members '[Builtins, NameIdGen] r) => FunctionDef -> Sem r ()
registerNatLe f = do
@ -204,7 +225,15 @@ registerNatLe f = do
(h .<=. zero, false),
((suc @@ n) .<=. (suc @@ m), n .<=. m)
]
registerNatFun f BuiltinNatLe (nat --> nat --> tybool) exClauses [varn, varm]
registerFun
FunInfo
{ _funInfoDef = f,
_funInfoBuiltin = BuiltinNatLe,
_funInfoSignature = nat --> nat --> tybool,
_funInfoClauses = exClauses,
_funInfoFreeVars = [varn, varm],
_funInfoFreeTypeVars = []
}
registerNatLt :: (Members '[Builtins, NameIdGen] r) => FunctionDef -> Sem r ()
registerNatLt f = do
@ -221,7 +250,15 @@ registerNatLt f = do
exClauses =
[ (lt @@ n @@ m, le @@ (suc @@ n) @@ m)
]
registerNatFun f BuiltinNatLt (nat --> nat --> tybool) exClauses [varn, varm]
registerFun
FunInfo
{ _funInfoDef = f,
_funInfoBuiltin = BuiltinNatLt,
_funInfoSignature = nat --> nat --> tybool,
_funInfoClauses = exClauses,
_funInfoFreeVars = [varn, varm],
_funInfoFreeTypeVars = []
}
registerNatEq :: (Members '[Builtins, NameIdGen] r) => FunctionDef -> Sem r ()
registerNatEq f = do
@ -246,4 +283,12 @@ registerNatEq f = do
(h .==. zero, false),
((suc @@ n) .==. (suc @@ m), n .==. m)
]
registerNatFun f BuiltinNatEq (nat --> nat --> tybool) exClauses [varn, varm]
registerFun
FunInfo
{ _funInfoDef = f,
_funInfoBuiltin = BuiltinNatEq,
_funInfoSignature = nat --> nat --> tybool,
_funInfoClauses = exClauses,
_funInfoFreeVars = [varn, varm],
_funInfoFreeTypeVars = []
}

View File

@ -72,6 +72,8 @@ data BuiltinFunction
| BuiltinNatLt
| BuiltinNatEq
| BuiltinBoolIf
| BuiltinBoolOr
| BuiltinBoolAnd
deriving stock (Show, Eq, Ord, Enum, Bounded, Generic, Data)
instance Hashable BuiltinFunction
@ -88,6 +90,8 @@ instance Pretty BuiltinFunction where
BuiltinNatLt -> Str.natLt
BuiltinNatEq -> Str.natEq
BuiltinBoolIf -> Str.boolIf
BuiltinBoolOr -> Str.boolOr
BuiltinBoolAnd -> Str.boolAnd
data BuiltinAxiom
= BuiltinNatPrint

View File

@ -661,6 +661,18 @@ goApplication a = do
case as of
(_ : v : b1 : b2 : xs) -> return (mkApps' (mkIf' sym v b1 b2) xs)
_ -> error "if must be called with 3 arguments"
Just Internal.BuiltinBoolOr -> do
sym <- getBoolSymbol
as <- exprArgs
case as of
(x : y : xs) -> return (mkApps' (mkIf' sym x (mkConstr' (BuiltinTag TagTrue) []) y) xs)
_ -> error "|| must be called with 2 arguments"
Just Internal.BuiltinBoolAnd -> do
sym <- getBoolSymbol
as <- exprArgs
case as of
(x : y : xs) -> return (mkApps' (mkIf' sym x y (mkConstr' (BuiltinTag TagFalse) [])) xs)
_ -> error "&& must be called with 2 arguments"
_ -> app
_ -> app

View File

@ -143,7 +143,13 @@ natEq = "nat-eq"
boolIf :: (IsString s) => s
boolIf = "bool-if"
builtin :: (IsString s) => s
boolOr :: IsString s => s
boolOr = "bool-or"
boolAnd :: IsString s => s
boolAnd = "bool-and"
builtin :: IsString s => s
builtin = "builtin"
type_ :: (IsString s) => s

View File

@ -80,7 +80,7 @@ tests =
$(mkRelFile "test005.juvix")
$(mkRelFile "out/test005.out"),
posTest
"If-then-else"
"If-then-else and lazy boolean operators"
$(mkRelDir ".")
$(mkRelFile "test006.juvix")
$(mkRelFile "out/test006.out"),

View File

@ -1 +1,3 @@
2
true
false

View File

@ -1,4 +1,4 @@
-- if then else
-- if-then-else and lazy boolean operators
module test006;
open import Stdlib.Prelude;
@ -9,7 +9,8 @@ loop : Nat;
loop := loop;
main : IO;
main := printNatLn $ (if (3 > 0) 1 loop) + (if (2 < 1) loop (if (7 >= 8) loop 1));
main := printNatLn ((if (3 > 0) 1 loop) + (if (2 < 1) loop (if (7 >= 8) loop 1))) >>
printBoolLn (2 > 0 || loop == 0) >>
printBoolLn (2 < 0 && loop == 0);
end;

View File

@ -3,6 +3,6 @@ module BuiltinIf;
open import Stdlib.Prelude;
main : Bool;
main := if false ((&&) false) ((&&) true) true;
main := if false (and false) (and true) true;
end;