From 1cd6e4451f04af4871675d0eecd48b3bac262b67 Mon Sep 17 00:00:00 2001 From: Gabriel Scherer Date: Sat, 14 Mar 2020 15:09:30 +0100 Subject: [PATCH] Map.filter_map and Set.filter_map --- Changes | 3 +++ middle_end/flambda/build_export_info.ml | 5 ++--- middle_end/flambda/inline_and_simplify_aux.ml | 4 ++-- middle_end/flambda/unbox_specialised_args.ml | 4 ++-- stdlib/map.ml | 13 +++++++++++ stdlib/map.mli | 20 +++++++++++++++++ stdlib/moreLabels.mli | 2 ++ stdlib/set.ml | 22 +++++++++++++++++++ stdlib/set.mli | 16 ++++++++++++++ testsuite/tests/lib-set/testmap.ml | 5 +++++ testsuite/tests/lib-set/testset.ml | 8 +++++++ .../typing-implicit_unpack/implicit_unpack.ml | 2 ++ testsuite/tests/typing-modules/aliases.ml | 4 ++++ testsuite/tests/typing-modules/pr7818.ml | 1 + .../short-paths.compilers.reference | 1 + utils/identifiable.ml | 7 ------ utils/identifiable.mli | 1 - 17 files changed, 103 insertions(+), 15 deletions(-) diff --git a/Changes b/Changes index f39ee0cec..ad0c182bb 100644 --- a/Changes +++ b/Changes @@ -112,6 +112,9 @@ Working version - #7110: Added Printf.ikbprintf and Printf.ibprintf (Muskan Garg, review by Gabriel Scherer and Florian Angeletti) +- #9365: Set.filter_map and Map.filter_map + (Gabriel Scherer, review by Stephen Dolan and Nicolás Ojeda Bär) + ### Other libraries: - #9106: Register printer for Unix_error in win32unix, as in unix. diff --git a/middle_end/flambda/build_export_info.ml b/middle_end/flambda/build_export_info.ml index 67fea2db6..2025feddc 100644 --- a/middle_end/flambda/build_export_info.ml +++ b/middle_end/flambda/build_export_info.ml @@ -688,9 +688,8 @@ let build_transient ~(backend : (module Backend_intf.S)) ~root_symbol:(Compilenv.current_unit_symbol ()) in let sets_of_closures = - Set_of_closures_id.Map.filter_map - function_declarations_map - ~f:(fun key (fun_decls : Simple_value_approx.function_declarations) -> + function_declarations_map |> Set_of_closures_id.Map.filter_map + (fun key (fun_decls : Simple_value_approx.function_declarations) -> if Set_of_closures_id.Set.mem key relevant_set_of_closures then Some fun_decls else if begin diff --git a/middle_end/flambda/inline_and_simplify_aux.ml b/middle_end/flambda/inline_and_simplify_aux.ml index bb725e8c6..07ae4d006 100644 --- a/middle_end/flambda/inline_and_simplify_aux.ml +++ b/middle_end/flambda/inline_and_simplify_aux.ml @@ -573,8 +573,8 @@ let prepare_to_simplify_set_of_closures ~env set_of_closures.free_vars in let specialised_args = - Variable.Map.filter_map set_of_closures.specialised_args - ~f:(fun param (spec_to : Flambda.specialised_to) -> + set_of_closures.specialised_args |> Variable.Map.filter_map + (fun param (spec_to : Flambda.specialised_to) -> let keep = match only_for_function_decl with | None -> true diff --git a/middle_end/flambda/unbox_specialised_args.ml b/middle_end/flambda/unbox_specialised_args.ml index 70eb87601..20d69c1d6 100644 --- a/middle_end/flambda/unbox_specialised_args.ml +++ b/middle_end/flambda/unbox_specialised_args.ml @@ -33,8 +33,8 @@ module Transform = struct what_to_specialise else let projections_by_function = - Variable.Map.filter_map set_of_closures.function_decls.funs - ~f:(fun _fun_var (function_decl : Flambda.function_declaration) -> + set_of_closures.function_decls.funs |> Variable.Map.filter_map + (fun _fun_var (function_decl : Flambda.function_declaration) -> if function_decl.stub then None else Some (Extract_projections.from_function_decl ~env diff --git a/stdlib/map.ml b/stdlib/map.ml index 0883ba109..479f2646e 100644 --- a/stdlib/map.ml +++ b/stdlib/map.ml @@ -40,6 +40,7 @@ module type S = val for_all: (key -> 'a -> bool) -> 'a t -> bool val exists: (key -> 'a -> bool) -> 'a t -> bool val filter: (key -> 'a -> bool) -> 'a t -> 'a t + val filter_map: (key -> 'a -> 'b option) -> 'a t -> 'b t val partition: (key -> 'a -> bool) -> 'a t -> 'a t * 'a t val cardinal: 'a t -> int val bindings: 'a t -> (key * 'a) list @@ -425,6 +426,18 @@ module Make(Ord: OrderedType) = struct if pvd then if l==l' && r==r' then m else join l' v d r' else concat l' r' + let rec filter_map f = function + Empty -> Empty + | Node {l; v; d; r} -> + (* call [f] in the expected left-to-right order *) + let l' = filter_map f l in + let fvd = f v d in + let r' = filter_map f r in + begin match fvd with + | Some d' -> join l' v d' r' + | None -> concat l' r' + end + let rec partition p = function Empty -> (Empty, Empty) | Node {l; v; d; r} -> diff --git a/stdlib/map.mli b/stdlib/map.mli index 2dc955abb..6e238fc06 100644 --- a/stdlib/map.mli +++ b/stdlib/map.mli @@ -179,6 +179,26 @@ module type S = @before 4.03 Physical equality was not ensured. *) + val filter_map: (key -> 'a -> 'b option) -> 'a t -> 'b t + (** [filter_map f m] applies the function [f] to every binding of + [m], and builds a map from the results. For each binding + [(k, v)] in the input map: + - if [f k v] is [None] then [k] is not in the result, + - if [f k v] is [Some v'] then the binding [(k, v')] + is in the output map. + + For example, the following function on maps whose values are lists + {[ + filter_map + (fun _k li -> match li with [] -> None | _::tl -> Some tl) + m + ]} + drops all bindings of [m] whose value is an empty list, and pops + the first element of each value that is non-empty. + + @since 4.11.0 + *) + val partition: (key -> 'a -> bool) -> 'a t -> 'a t * 'a t (** [partition p m] returns a pair of maps [(m1, m2)], where [m1] contains all the bindings of [s] that satisfy the diff --git a/stdlib/moreLabels.mli b/stdlib/moreLabels.mli index 08bc0f4d9..eae749c71 100644 --- a/stdlib/moreLabels.mli +++ b/stdlib/moreLabels.mli @@ -152,6 +152,7 @@ module Map : sig val for_all: f:(key -> 'a -> bool) -> 'a t -> bool val exists: f:(key -> 'a -> bool) -> 'a t -> bool val filter: f:(key -> 'a -> bool) -> 'a t -> 'a t + val filter_map: f:(key -> 'a -> 'b option) -> 'a t -> 'b t val partition: f:(key -> 'a -> bool) -> 'a t -> 'a t * 'a t val cardinal: 'a t -> int val bindings: 'a t -> (key * 'a) list @@ -205,6 +206,7 @@ module Set : sig val for_all : f:(elt -> bool) -> t -> bool val exists : f:(elt -> bool) -> t -> bool val filter : f:(elt -> bool) -> t -> t + val filter_map : f:(elt -> elt option) -> t -> t val partition : f:(elt -> bool) -> t -> t * t val cardinal : t -> int val elements : t -> elt list diff --git a/stdlib/set.ml b/stdlib/set.ml index 6c8fdce83..d8b8a4595 100644 --- a/stdlib/set.ml +++ b/stdlib/set.ml @@ -44,6 +44,7 @@ module type S = val for_all: (elt -> bool) -> t -> bool val exists: (elt -> bool) -> t -> bool val filter: (elt -> bool) -> t -> t + val filter_map: (elt -> elt option) -> t -> t val partition: (elt -> bool) -> t -> t * t val cardinal: t -> int val elements: t -> elt list @@ -530,6 +531,27 @@ module Make(Ord: OrderedType) = if l == l' && v == v' && r == r' then t else try_join l' v' r' + let try_concat t1 t2 = + match (t1, t2) with + (Empty, t) -> t + | (t, Empty) -> t + | (_, _) -> try_join t1 (min_elt t2) (remove_min_elt t2) + + let rec filter_map f = function + | Empty -> Empty + | Node{l; v; r} as t -> + (* enforce left-to-right evaluation order *) + let l' = filter_map f l in + let v' = f v in + let r' = filter_map f r in + begin match v' with + | Some v' -> + if l == l' && v == v' && r == r' then t + else try_join l' v' r' + | None -> + try_concat l' r' + end + let of_sorted_list l = let rec sub n l = match n, l with diff --git a/stdlib/set.mli b/stdlib/set.mli index dd7128094..91e392386 100644 --- a/stdlib/set.mli +++ b/stdlib/set.mli @@ -154,6 +154,22 @@ module type S = physically equal to [s]). @before 4.03 Physical equality was not ensured.*) + val filter_map: (elt -> elt option) -> t -> t + (** [filter_map f s] returns the set of all [v] such that + [f x = Some v] for some element [x] of [s]. + + For example, + {[filter_map (fun n -> if n mod 2 = 0 then Some (n / 2) else None) s]} + is the set of halves of the even elements of [s]. + + If no element of [s] is changed or dropped by [f] (if + [f x = Some x] for each element [x]), then + [s] is returned unchanged: the result of the function + is then physically equal to [s]. + + @since 4.11.0 + *) + val partition: (elt -> bool) -> t -> t * t (** [partition p s] returns a pair of sets [(s1, s2)], where [s1] is the set of all the elements of [s] that satisfy the diff --git a/testsuite/tests/lib-set/testmap.ml b/testsuite/tests/lib-set/testmap.ml index 0be0410b8..500f00b0c 100644 --- a/testsuite/tests/lib-set/testmap.ml +++ b/testsuite/tests/lib-set/testmap.ml @@ -75,6 +75,11 @@ let test x v s1 s2 = (let p x y = x >= 3 && x <= 6 in M.bindings(M.filter p s1) = List.filter (uncurry p) (M.bindings s1)); + checkbool "filter_map" + (let f x y = if x >= 3 && x <= 6 then Some (2 * x) else None in + let f_on_pair (x, y) = Option.map (fun v -> (x, v)) (f x y) in + M.bindings(M.filter_map f s1) = List.filter_map f_on_pair (M.bindings s1)); + checkbool "partition" (let p x y = x >= 3 && x <= 6 in let (st,sf) = M.partition p s1 diff --git a/testsuite/tests/lib-set/testset.ml b/testsuite/tests/lib-set/testset.ml index b998875ea..36d450eb1 100644 --- a/testsuite/tests/lib-set/testset.ml +++ b/testsuite/tests/lib-set/testset.ml @@ -89,6 +89,14 @@ let test x s1 s2 = (let p x = x >= 3 && x <= 6 in S.elements(S.filter p s1) = List.filter p (S.elements s1)); + checkbool "filter_map" + (let f x = if x >= 3 && x <= 6 then Some (2 * x) else None in + S.elements(S.filter_map f s1) = List.filter_map f (S.elements s1)); + + checkbool "filter_map(==)" + (let f x = Some x in + S.filter_map f s1 == s1); + checkbool "partition" (let p x = x >= 3 && x <= 6 in let (st,sf) = S.partition p s1 diff --git a/testsuite/tests/typing-implicit_unpack/implicit_unpack.ml b/testsuite/tests/typing-implicit_unpack/implicit_unpack.ml index bd256f2c0..04334d668 100644 --- a/testsuite/tests/typing-implicit_unpack/implicit_unpack.ml +++ b/testsuite/tests/typing-implicit_unpack/implicit_unpack.ml @@ -321,6 +321,7 @@ module type MapT = val for_all : (key -> 'a -> bool) -> 'a t -> bool val exists : (key -> 'a -> bool) -> 'a t -> bool val filter : (key -> 'a -> bool) -> 'a t -> 'a t + val filter_map : (key -> 'a -> 'b option) -> 'a t -> 'b t val partition : (key -> 'a -> bool) -> 'a t -> 'a t * 'a t val cardinal : 'a t -> int val bindings : 'a t -> (key * 'a) list @@ -372,6 +373,7 @@ module SSMap : val for_all : (key -> 'a -> bool) -> 'a t -> bool val exists : (key -> 'a -> bool) -> 'a t -> bool val filter : (key -> 'a -> bool) -> 'a t -> 'a t + val filter_map : (key -> 'a -> 'b option) -> 'a t -> 'b t val partition : (key -> 'a -> bool) -> 'a t -> 'a t * 'a t val cardinal : 'a t -> int val bindings : 'a t -> (key * 'a) list diff --git a/testsuite/tests/typing-modules/aliases.ml b/testsuite/tests/typing-modules/aliases.ml index 2f2cfd243..aac8c2a02 100644 --- a/testsuite/tests/typing-modules/aliases.ml +++ b/testsuite/tests/typing-modules/aliases.ml @@ -298,6 +298,7 @@ module StringSet : val for_all : (elt -> bool) -> t -> bool val exists : (elt -> bool) -> t -> bool val filter : (elt -> bool) -> t -> t + val filter_map : (elt -> elt option) -> t -> t val partition : (elt -> bool) -> t -> t * t val cardinal : t -> int val elements : t -> elt list @@ -343,6 +344,7 @@ module SSet : val for_all : (elt -> bool) -> t -> bool val exists : (elt -> bool) -> t -> bool val filter : (elt -> bool) -> t -> t + val filter_map : (elt -> elt option) -> t -> t val partition : (elt -> bool) -> t -> t * t val cardinal : t -> int val elements : t -> elt list @@ -420,6 +422,7 @@ module A : val for_all : (elt -> bool) -> t -> bool val exists : (elt -> bool) -> t -> bool val filter : (elt -> bool) -> t -> t + val filter_map : (elt -> elt option) -> t -> t val partition : (elt -> bool) -> t -> t * t val cardinal : t -> int val elements : t -> elt list @@ -532,6 +535,7 @@ module SInt : val for_all : (elt -> bool) -> t -> bool val exists : (elt -> bool) -> t -> bool val filter : (elt -> bool) -> t -> t + val filter_map : (elt -> elt option) -> t -> t val partition : (elt -> bool) -> t -> t * t val cardinal : t -> int val elements : t -> elt list diff --git a/testsuite/tests/typing-modules/pr7818.ml b/testsuite/tests/typing-modules/pr7818.ml index 0fafb5816..f80f7e7df 100644 --- a/testsuite/tests/typing-modules/pr7818.ml +++ b/testsuite/tests/typing-modules/pr7818.ml @@ -254,6 +254,7 @@ module MkT : val for_all : (elt -> bool) -> t -> bool val exists : (elt -> bool) -> t -> bool val filter : (elt -> bool) -> t -> t + val filter_map : (elt -> elt option) -> t -> t val partition : (elt -> bool) -> t -> t * t val cardinal : t -> int val elements : t -> elt list diff --git a/testsuite/tests/typing-short-paths/short-paths.compilers.reference b/testsuite/tests/typing-short-paths/short-paths.compilers.reference index 47f15b3d9..1619e340f 100644 --- a/testsuite/tests/typing-short-paths/short-paths.compilers.reference +++ b/testsuite/tests/typing-short-paths/short-paths.compilers.reference @@ -34,6 +34,7 @@ module Core : val for_all : (key -> 'a -> bool) -> 'a t -> bool val exists : (key -> 'a -> bool) -> 'a t -> bool val filter : (key -> 'a -> bool) -> 'a t -> 'a t + val filter_map : (key -> 'a -> 'b option) -> 'a t -> 'b t val partition : (key -> 'a -> bool) -> 'a t -> 'a t * 'a t val cardinal : 'a t -> key val bindings : 'a t -> (key * 'a) list diff --git a/utils/identifiable.ml b/utils/identifiable.ml index e82390ad0..9bbfb6573 100644 --- a/utils/identifiable.ml +++ b/utils/identifiable.ml @@ -43,7 +43,6 @@ module type Map = sig with type key = T.t and type 'a t = 'a Map.Make (T).t - val filter_map : 'a t -> f:(key -> 'a -> 'b option) -> 'b t val of_list : (key * 'a) list -> 'a t val disjoint_union : @@ -102,12 +101,6 @@ end module Make_map (T : Thing) = struct include Map.Make (T) - let filter_map t ~f = - fold (fun id v map -> - match f id v with - | None -> map - | Some r -> add id r map) t empty - let of_list l = List.fold_left (fun map (id, v) -> add id v map) empty l diff --git a/utils/identifiable.mli b/utils/identifiable.mli index 4e2607115..0da5a6619 100644 --- a/utils/identifiable.mli +++ b/utils/identifiable.mli @@ -52,7 +52,6 @@ module type Map = sig with type key = T.t and type 'a t = 'a Map.Make (T).t - val filter_map : 'a t -> f:(key -> 'a -> 'b option) -> 'b t val of_list : (key * 'a) list -> 'a t (** [disjoint_union m1 m2] contains all bindings from [m1] and