From 441dd54ec37236bede03121439be737dd8d3502f Mon Sep 17 00:00:00 2001 From: vbot Date: Thu, 20 Jun 2024 15:38:21 +0200 Subject: [PATCH] Refactor suggestions --- compiler/catala_utils/suggestions.ml | 71 ++++++++-------------- compiler/catala_utils/suggestions.mli | 19 ++++-- compiler/desugared/from_surface.ml | 3 +- compiler/shared_ast/typing.ml | 10 ++- compiler/surface/parser_driver.ml | 13 ++-- tests/default/bad/verbose_errors.catala_en | 7 +-- 6 files changed, 60 insertions(+), 63 deletions(-) diff --git a/compiler/catala_utils/suggestions.ml b/compiler/catala_utils/suggestions.ml index 0b2551e0..ee6fca0d 100644 --- a/compiler/catala_utils/suggestions.ml +++ b/compiler/catala_utils/suggestions.ml @@ -48,51 +48,34 @@ let levenshtein_distance (s : string) (t : string) : int = d.(m).(n) -(*We create a list composed by strings that satisfy the following rule : they - have the same levenshtein distance, which is the minimum distance between the - reference word "keyword" and all the strings in "candidates" (with the - condition that this minimum is equal to or less than one third of the length - of keyword + 1, in order to get suggestions close to "keyword")*) -let suggestion_minimum_levenshtein_distance_association - (candidates : string list) - (keyword : string) : string list = - let rec strings_minimum_levenshtein_distance - (minimum : int) - (result : string list) - (candidates' : string list) : string list = - (*As we iterate through the "candidates'" list, we create a list "result" - with all strings that have the last minimum levenshtein distance found - ("minimum").*) - match candidates' with - (*When a new minimum levenshtein distance is found, the new result list is - our new element "current_string" followed by strings that have the same - minimum distance. 
It will be the "result" list if there is no levenshtein - distance smaller than this new minimum.*) - | current_string :: tail -> - let current_levenshtein_distance = - levenshtein_distance current_string keyword - in - if current_levenshtein_distance < minimum then - strings_minimum_levenshtein_distance current_levenshtein_distance - [current_string] tail - (*The "result" list is updated (we append "current_string" to "result") - when a new string shares the same minimum levenshtein distance - "minimum"*) - else if current_levenshtein_distance = minimum then - strings_minimum_levenshtein_distance minimum - (result @ [current_string]) - tail - (*If a levenshtein distance greater than the minimum is found, "result" - doesn't change*) - else strings_minimum_levenshtein_distance minimum result tail - (*The "result" list is returned at the end of the "candidates'" list.*) - | [] -> result +module M = Stdlib.Map.Make (Int) + +let compute_candidates (candidates : string list) (word : string) : + string list M.t = + List.fold_left + (fun m candidate -> + let distance = levenshtein_distance word candidate in + M.update distance + (function None -> Some [candidate] | Some l -> Some (candidate :: l)) + m) + M.empty candidates + +let best_candidates candidates word = + let candidates = compute_candidates candidates word in + M.choose_opt candidates |> function None -> [] | Some (_, l) -> List.rev l + +let sorted_candidates ?(max_elements = 5) suggs given = + let rec sub acc n = function + | [] -> List.rev acc + | x :: t when n > 0 -> sub (x :: acc) (pred n) t + | _ -> List.rev acc in - strings_minimum_levenshtein_distance - (1 + (String.length keyword / 3)) - (*In order to select suggestions that are not too far away from the - keyword*) - [] candidates + let candidates = + List.map + (fun (_, l) -> List.rev l) + (M.bindings (compute_candidates suggs given)) + in + List.concat candidates |> sub [] max_elements let format (ppf : Format.formatter) (suggestions_list : string 
list) = match suggestions_list with diff --git a/compiler/catala_utils/suggestions.mli b/compiler/catala_utils/suggestions.mli index 66527c5a..63ca1851 100644 --- a/compiler/catala_utils/suggestions.mli +++ b/compiler/catala_utils/suggestions.mli @@ -15,9 +15,20 @@ License for the specific language governing permissions and limitations under the License. *) -val suggestion_minimum_levenshtein_distance_association : - string list -> string -> string list -(**Returns a list of the closest words into {!name:candidates} to the keyword - {!name:keyword}*) +val levenshtein_distance : string -> string -> int +(** [levenshtein_distance w1 w2] computes the levenshtein distance separating + [w1] from [w2]. *) + +val best_candidates : string list -> string -> string list +(** [best_candidates suggestions word] returns the subset of elements in + [suggestions] that minimize the levenshtein distance to [word]. Several + candidates may be returned when they share the same distance. *) + +val sorted_candidates : + ?max_elements:int -> string list -> string -> string list +(** [sorted_candidates ?max_elements suggestions word] sorts the [suggestions] + list and retains at most [max_elements] elements (defaults to 5). The result is ordered + by increasing levenshtein distance to [word], i.e., the first elements are the + most similar. 
*) val format : Format.formatter -> string list -> unit diff --git a/compiler/desugared/from_surface.ml b/compiler/desugared/from_surface.ml index 01d4ff47..740bf0e2 100644 --- a/compiler/desugared/from_surface.ml +++ b/compiler/desugared/from_surface.ml @@ -138,8 +138,7 @@ let raise_error_cons_not_found (constructor : string Mark.pos) = let constructors = Ident.Map.keys ctxt.local.constructor_idmap in let closest_constructors = - Suggestions.suggestion_minimum_levenshtein_distance_association constructors - (Mark.remove constructor) + Suggestions.best_candidates constructors (Mark.remove constructor) in Message.error ~pos_msg:(fun ppf -> Format.fprintf ppf "Here is your code :") diff --git a/compiler/shared_ast/typing.ml b/compiler/shared_ast/typing.ml index fd6fd80f..dc9d84d5 100644 --- a/compiler/shared_ast/typing.ml +++ b/compiler/shared_ast/typing.ml @@ -618,7 +618,10 @@ and typecheck_expr_top_down : "Variable @{%s@} is not a declared output of scope %a." field A.ScopeName.format scope_out ~suggestion: - (List.map A.StructField.to_string (A.StructField.Map.keys str)) + (Suggestions.sorted_candidates + (List.map A.StructField.to_string + (A.StructField.Map.keys str)) + field) | None -> Message.error ~extra_pos: @@ -629,7 +632,10 @@ and typecheck_expr_top_down : "Field@ @{\"%s\"@}@ does@ not@ belong@ to@ structure@ \ @{\"%a\"@}." 
field A.StructName.format name - ~suggestion:(A.Ident.Map.keys ctx.ctx_struct_fields)) + ~suggestion: + (Suggestions.sorted_candidates + (A.Ident.Map.keys ctx.ctx_struct_fields) + field)) in try A.StructName.Map.find name candidate_structs with A.StructName.Map.Not_found _ -> diff --git a/compiler/surface/parser_driver.ml b/compiler/surface/parser_driver.ml index 4c5140e3..d944f041 100644 --- a/compiler/surface/parser_driver.ml +++ b/compiler/surface/parser_driver.ml @@ -105,15 +105,16 @@ module ParserAux (LocalisedLexer : Lexer_common.LocalisedLexer) = struct let sorted_candidate_tokens lexbuf token_list env = let acceptable_tokens = - List.filter - (fun (_, t) -> - I.acceptable (I.input_needed env) t (fst (lexing_positions lexbuf))) + List.filter_map + (fun ((_, t) as elt) -> + if I.acceptable (I.input_needed env) t (fst (lexing_positions lexbuf)) + then Some elt + else None) token_list in + let lexeme = Utf8.lexeme lexbuf in let similar_acceptable_tokens = - Suggestions.suggestion_minimum_levenshtein_distance_association - (List.map (fun (s, _) -> s) acceptable_tokens) - (Utf8.lexeme lexbuf) + Suggestions.best_candidates (List.map fst acceptable_tokens) lexeme in let module S = Set.Make (String) in let s_toks = S.of_list similar_acceptable_tokens in diff --git a/tests/default/bad/verbose_errors.catala_en b/tests/default/bad/verbose_errors.catala_en index b592ca7d..08217c7c 100644 --- a/tests/default/bad/verbose_errors.catala_en +++ b/tests/default/bad/verbose_errors.catala_en @@ -50,11 +50,8 @@ $ catala test-scope A │ │ ‾ ├─ Article │ -│ Maybe you wanted to write : "field0", or "field1", or "field2", -│ or "field3", or "field4", or "field5", or "field6", or "field7", -│ or "field8", or "field9", or "field10", or "field11", or "field12", -│ or "field13", or "field14", or "field15", or "field16", or "field17", -│ or "field18", or "field19", or "o" ? +│ Maybe you wanted to write : "field0", or "field2", or "field10", +│ or "field1", or "field3" ? 
└─ #return code 123# ```