catala/compiler/scopelang/dependency.ml
Louis Gesbert 41d6d3cbe9 Make scopes directly callable
Quite a few changes are included here, some of which have some extra
implications visible in the language:

- adds the `Scope of { -- input_v: value; ... }` construct in the language

- handle it down the pipeline:
  * `ScopeCall` in the surface AST
  * `EScopeCall` in desugared and scopelang
  * expressions are now traversed to detect dependencies between scopes
  * transformed into a normal function call in dcalc

- defining a scope now implicitely defines a structure with the same name, with
  the output variables of the scope defined as fields. This allows us to type
  the return value from a scope call and access its fields easily.
  * the implications are mostly in surface/name_resolution.ml code-wise
  * the `Scope_out` struct that was defined in scope_to_dcalc is no longer
    needed/used and the fields are no longer renamed (changes some outputs; the
    explicit suffix for variables with multiple states is ignored as well)
  * one benefit is that disambiguation works just like for structures when there
    are conflicts on field names
  * however, it's now a conflict if a scope and a structure have the same
    name (side-note: issues with conflicting enum / struct names or scope
    variables / subscope names were silent and are now properly reported)

- you can consequently use scope names as types for variables as well. Writing
  literals is not allowed though, they can only be obtained by calling the
  scope.

Remaining TODOs:

- context variables are not handled properly at the moment

- error handling on invalid calls

- tests show a small error message regression; lots of examples will need
  tweaking to avoid scope/struct name or struct fields / output variable
  conflicts

- add a `->` syntax to make struct field access distinct from scope output var
  access, enforced with typing. This is expected to reduce confusion of users
  and add a little typing precision.

- document the new syntax & implications (tutorial, cheat-sheet)

- a consequence of the changes is that subscope variables also can now be typed.
  A possible future evolution / simplification would be to rewrite subscopes as
  explicit scope calls early in the pipeline. That could also allow to manipulate
  them as expressions (bind them in let-ins, return them...)
2022-10-21 17:17:26 +02:00

276 lines
9.5 KiB
OCaml

(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
(** Graph representation of the dependencies between scopes in the Catala
program. Vertices are functions, x -> y if x is used in the definition of y. *)
open Utils
open Shared_ast
module SVertex = ScopeName
(** On the edges, the label is the expression responsible for the use of the
function *)
module SEdge = struct
type t = Pos.t
let compare = compare
let default = Pos.no_pos
end
module SDependencies =
Graph.Persistent.Digraph.ConcreteBidirectionalLabeled (SVertex) (SEdge)
module STopologicalTraversal = Graph.Topological.Make (SDependencies)
module SSCC = Graph.Components.Make (SDependencies)
(** Tarjan's stongly connected components algorithm, provided by OCamlGraph *)
let rec expr_used_scopes e =
let recurse_subterms e =
Expr.shallow_fold
(fun e -> ScopeMap.union (fun _ x _ -> Some x) (expr_used_scopes e))
e ScopeMap.empty
in
match e with
| (EScopeCall (scope, _), m) as e ->
ScopeMap.add scope (Expr.mark_pos m) (recurse_subterms e)
| EAbs (binder, _), _ ->
let _, body = Bindlib.unmbind binder in
expr_used_scopes body
| e -> recurse_subterms e
let rule_used_scopes = function
| Ast.Assertion e | Ast.Definition (_, _, _, e) ->
(* TODO: maybe this info could be passed on from previous passes without
walking through all exprs again *)
expr_used_scopes e
| Ast.Call (subscope, subindex, _) ->
ScopeMap.singleton subscope
(Marked.get_mark (SubScopeName.get_info subindex))
let build_program_dep_graph (prgm : 'm Ast.program) : SDependencies.t =
let g = SDependencies.empty in
let g =
ScopeMap.fold
(fun v _ g -> SDependencies.add_vertex g v)
prgm.program_scopes g
in
ScopeMap.fold
(fun scope_name scope g ->
List.fold_left
(fun g rule ->
let used_scopes = rule_used_scopes rule in
if ScopeMap.mem scope_name used_scopes then
Errors.raise_spanned_error
(Marked.get_mark (ScopeName.get_info scope.Ast.scope_decl_name))
"The scope %a is calling into itself as a subscope, which is \
forbidden since Catala does not provide recursion"
ScopeName.format_t scope.Ast.scope_decl_name;
ScopeMap.fold
(fun used_scope pos g ->
let edge = SDependencies.E.create used_scope pos scope_name in
SDependencies.add_edge_e g edge)
used_scopes g)
g scope.Ast.scope_decl_rules)
prgm.program_scopes g
let check_for_cycle_in_scope (g : SDependencies.t) : unit =
(* if there is a cycle, there will be an strongly connected component of
cardinality > 1 *)
let sccs = SSCC.scc_list g in
if List.length sccs < SDependencies.nb_vertex g then
let scc = List.find (fun scc -> List.length scc > 1) sccs in
let spans =
List.flatten
(List.map
(fun v ->
let var_str, var_info =
Format.asprintf "%a" ScopeName.format_t v, ScopeName.get_info v
in
let succs = SDependencies.succ_e g v in
let _, edge_pos, succ =
List.find (fun (_, _, succ) -> List.mem succ scc) succs
in
let succ_str = Format.asprintf "%a" ScopeName.format_t succ in
[
( Some ("Cycle variable " ^ var_str ^ ", declared:"),
Marked.get_mark var_info );
( Some
("Used here in the definition of another cycle variable "
^ succ_str
^ ":"),
edge_pos );
])
scc)
in
Errors.raise_multispanned_error spans
"Cyclic dependency detected between scopes!"
let get_scope_ordering (g : SDependencies.t) : ScopeName.t list =
List.rev (STopologicalTraversal.fold (fun sd acc -> sd :: acc) g [])
module TVertex = struct
type t = Struct of StructName.t | Enum of EnumName.t
let hash x =
match x with Struct x -> StructName.hash x | Enum x -> EnumName.hash x
let compare x y =
match x, y with
| Struct x, Struct y -> StructName.compare x y
| Enum x, Enum y -> EnumName.compare x y
| Struct _, Enum _ -> 1
| Enum _, Struct _ -> -1
let equal x y =
match x, y with
| Struct x, Struct y -> StructName.compare x y = 0
| Enum x, Enum y -> EnumName.compare x y = 0
| _ -> false
let format_t (fmt : Format.formatter) (x : t) : unit =
match x with
| Struct x -> StructName.format_t fmt x
| Enum x -> EnumName.format_t fmt x
let get_info (x : t) =
match x with
| Struct x -> StructName.get_info x
| Enum x -> EnumName.get_info x
end
module TVertexSet = Set.Make (TVertex)
(** On the edges, the label is the expression responsible for the use of the
function *)
module TEdge = struct
type t = Pos.t
let compare = compare
let default = Pos.no_pos
end
module TDependencies =
Graph.Persistent.Digraph.ConcreteBidirectionalLabeled (TVertex) (TEdge)
module TTopologicalTraversal = Graph.Topological.Make (TDependencies)
module TSCC = Graph.Components.Make (TDependencies)
(** Tarjan's stongly connected components algorithm, provided by OCamlGraph *)
let rec get_structs_or_enums_in_type (t : typ) : TVertexSet.t =
match Marked.unmark t with
| TStruct s -> TVertexSet.singleton (TVertex.Struct s)
| TEnum e -> TVertexSet.singleton (TVertex.Enum e)
| TArrow (t1, t2) ->
TVertexSet.union
(get_structs_or_enums_in_type t1)
(get_structs_or_enums_in_type t2)
| TLit _ | TAny -> TVertexSet.empty
| TOption t1 | TArray t1 -> get_structs_or_enums_in_type t1
| TTuple ts ->
List.fold_left
(fun acc t -> TVertexSet.union acc (get_structs_or_enums_in_type t))
TVertexSet.empty ts
let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t
=
let g = TDependencies.empty in
let g =
StructMap.fold
(fun s fields g ->
List.fold_left
(fun g (_, typ) ->
let def = TVertex.Struct s in
let g = TDependencies.add_vertex g def in
let used = get_structs_or_enums_in_type typ in
TVertexSet.fold
(fun used g ->
if TVertex.equal used def then
Errors.raise_spanned_error (Marked.get_mark typ)
"The type %a is defined using itself, which is forbidden \
since Catala does not provide recursive types"
TVertex.format_t used
else
let edge =
TDependencies.E.create used (Marked.get_mark typ) def
in
TDependencies.add_edge_e g edge)
used g)
g fields)
structs g
in
let g =
EnumMap.fold
(fun e cases g ->
List.fold_left
(fun g (_, typ) ->
let def = TVertex.Enum e in
let g = TDependencies.add_vertex g def in
let used = get_structs_or_enums_in_type typ in
TVertexSet.fold
(fun used g ->
if TVertex.equal used def then
Errors.raise_spanned_error (Marked.get_mark typ)
"The type %a is defined using itself, which is forbidden \
since Catala does not provide recursive types"
TVertex.format_t used
else
let edge =
TDependencies.E.create used (Marked.get_mark typ) def
in
TDependencies.add_edge_e g edge)
used g)
g cases)
enums g
in
g
let check_type_cycles (structs : struct_ctx) (enums : enum_ctx) : TVertex.t list
=
let g = build_type_graph structs enums in
(* if there is a cycle, there will be an strongly connected component of
cardinality > 1 *)
let sccs = TSCC.scc_list g in
(if List.length sccs < TDependencies.nb_vertex g then
let scc = List.find (fun scc -> List.length scc > 1) sccs in
let spans =
List.flatten
(List.map
(fun v ->
let var_str, var_info =
Format.asprintf "%a" TVertex.format_t v, TVertex.get_info v
in
let succs = TDependencies.succ_e g v in
let _, edge_pos, succ =
List.find (fun (_, _, succ) -> List.mem succ scc) succs
in
let succ_str = Format.asprintf "%a" TVertex.format_t succ in
[
( Some ("Cycle type " ^ var_str ^ ", declared:"),
Marked.get_mark var_info );
( Some
("Used here in the definition of another cycle type "
^ succ_str
^ ":"),
edge_pos );
])
scc)
in
Errors.raise_multispanned_error spans
"Cyclic dependency detected between types!");
List.rev (TTopologicalTraversal.fold (fun v acc -> v :: acc) g [])