Map.filter_map and Set.filter_map

master
Gabriel Scherer 2020-03-14 15:09:30 +01:00
parent e10e2bfa99
commit 1cd6e4451f
17 changed files with 103 additions and 15 deletions

View File

@ -112,6 +112,9 @@ Working version
- #7110: Added Printf.ikbprintf and Printf.ibprintf
(Muskan Garg, review by Gabriel Scherer and Florian Angeletti)
- #9365: Set.filter_map and Map.filter_map
(Gabriel Scherer, review by Stephen Dolan and Nicolás Ojeda Bär)
### Other libraries:
- #9106: Register printer for Unix_error in win32unix, as in unix.

View File

@ -688,9 +688,8 @@ let build_transient ~(backend : (module Backend_intf.S))
~root_symbol:(Compilenv.current_unit_symbol ())
in
let sets_of_closures =
Set_of_closures_id.Map.filter_map
function_declarations_map
~f:(fun key (fun_decls : Simple_value_approx.function_declarations) ->
function_declarations_map |> Set_of_closures_id.Map.filter_map
(fun key (fun_decls : Simple_value_approx.function_declarations) ->
if Set_of_closures_id.Set.mem key relevant_set_of_closures then
Some fun_decls
else if begin

View File

@ -573,8 +573,8 @@ let prepare_to_simplify_set_of_closures ~env
set_of_closures.free_vars
in
let specialised_args =
Variable.Map.filter_map set_of_closures.specialised_args
~f:(fun param (spec_to : Flambda.specialised_to) ->
set_of_closures.specialised_args |> Variable.Map.filter_map
(fun param (spec_to : Flambda.specialised_to) ->
let keep =
match only_for_function_decl with
| None -> true

View File

@ -33,8 +33,8 @@ module Transform = struct
what_to_specialise
else
let projections_by_function =
Variable.Map.filter_map set_of_closures.function_decls.funs
~f:(fun _fun_var (function_decl : Flambda.function_declaration) ->
set_of_closures.function_decls.funs |> Variable.Map.filter_map
(fun _fun_var (function_decl : Flambda.function_declaration) ->
if function_decl.stub then None
else
Some (Extract_projections.from_function_decl ~env

View File

@ -40,6 +40,7 @@ module type S =
val for_all: (key -> 'a -> bool) -> 'a t -> bool
val exists: (key -> 'a -> bool) -> 'a t -> bool
val filter: (key -> 'a -> bool) -> 'a t -> 'a t
val filter_map: (key -> 'a -> 'b option) -> 'a t -> 'b t
val partition: (key -> 'a -> bool) -> 'a t -> 'a t * 'a t
val cardinal: 'a t -> int
val bindings: 'a t -> (key * 'a) list
@ -425,6 +426,18 @@ module Make(Ord: OrderedType) = struct
if pvd then if l==l' && r==r' then m else join l' v d r'
else concat l' r'
let rec filter_map f = function
Empty -> Empty
| Node {l; v; d; r} ->
(* call [f] in the expected left-to-right order *)
let l' = filter_map f l in
let fvd = f v d in
let r' = filter_map f r in
begin match fvd with
| Some d' -> join l' v d' r'
| None -> concat l' r'
end
let rec partition p = function
Empty -> (Empty, Empty)
| Node {l; v; d; r} ->

View File

@ -179,6 +179,26 @@ module type S =
@before 4.03 Physical equality was not ensured.
*)
val filter_map: (key -> 'a -> 'b option) -> 'a t -> 'b t
(** [filter_map f m] applies the function [f] to every binding of
[m], and builds a map from the results. For each binding
[(k, v)] in the input map:
- if [f k v] is [None] then [k] is not in the result,
- if [f k v] is [Some v'] then the binding [(k, v')]
is in the output map.
For example, the following function on maps whose values are lists
{[
filter_map
(fun _k li -> match li with [] -> None | _::tl -> Some tl)
m
]}
drops all bindings of [m] whose value is an empty list, and pops
the first element of each value that is non-empty.
@since 4.11.0
*)
val partition: (key -> 'a -> bool) -> 'a t -> 'a t * 'a t
(** [partition p m] returns a pair of maps [(m1, m2)], where
[m1] contains all the bindings of [s] that satisfy the

View File

@ -152,6 +152,7 @@ module Map : sig
val for_all: f:(key -> 'a -> bool) -> 'a t -> bool
val exists: f:(key -> 'a -> bool) -> 'a t -> bool
val filter: f:(key -> 'a -> bool) -> 'a t -> 'a t
val filter_map: f:(key -> 'a -> 'b option) -> 'a t -> 'b t
val partition: f:(key -> 'a -> bool) -> 'a t -> 'a t * 'a t
val cardinal: 'a t -> int
val bindings: 'a t -> (key * 'a) list
@ -205,6 +206,7 @@ module Set : sig
val for_all : f:(elt -> bool) -> t -> bool
val exists : f:(elt -> bool) -> t -> bool
val filter : f:(elt -> bool) -> t -> t
val filter_map : f:(elt -> elt option) -> t -> t
val partition : f:(elt -> bool) -> t -> t * t
val cardinal : t -> int
val elements : t -> elt list

View File

@ -44,6 +44,7 @@ module type S =
val for_all: (elt -> bool) -> t -> bool
val exists: (elt -> bool) -> t -> bool
val filter: (elt -> bool) -> t -> t
val filter_map: (elt -> elt option) -> t -> t
val partition: (elt -> bool) -> t -> t * t
val cardinal: t -> int
val elements: t -> elt list
@ -530,6 +531,27 @@ module Make(Ord: OrderedType) =
if l == l' && v == v' && r == r' then t
else try_join l' v' r'
let try_concat t1 t2 =
match (t1, t2) with
(Empty, t) -> t
| (t, Empty) -> t
| (_, _) -> try_join t1 (min_elt t2) (remove_min_elt t2)
let rec filter_map f = function
| Empty -> Empty
| Node{l; v; r} as t ->
(* enforce left-to-right evaluation order *)
let l' = filter_map f l in
let v' = f v in
let r' = filter_map f r in
begin match v' with
| Some v' ->
if l == l' && v == v' && r == r' then t
else try_join l' v' r'
| None ->
try_concat l' r'
end
let of_sorted_list l =
let rec sub n l =
match n, l with

View File

@ -154,6 +154,22 @@ module type S =
physically equal to [s]).
@before 4.03 Physical equality was not ensured.*)
val filter_map: (elt -> elt option) -> t -> t
(** [filter_map f s] returns the set of all [v] such that
[f x = Some v] for some element [x] of [s].
For example,
{[filter_map (fun n -> if n mod 2 = 0 then Some (n / 2) else None) s]}
is the set of halves of the even elements of [s].
If no element of [s] is changed or dropped by [f] (if
[f x = Some x] for each element [x]), then
[s] is returned unchanged: the result of the function
is then physically equal to [s].
@since 4.11.0
*)
val partition: (elt -> bool) -> t -> t * t
(** [partition p s] returns a pair of sets [(s1, s2)], where
[s1] is the set of all the elements of [s] that satisfy the

View File

@ -75,6 +75,11 @@ let test x v s1 s2 =
(let p x y = x >= 3 && x <= 6 in
M.bindings(M.filter p s1) = List.filter (uncurry p) (M.bindings s1));
checkbool "filter_map"
(let f x y = if x >= 3 && x <= 6 then Some (2 * x) else None in
let f_on_pair (x, y) = Option.map (fun v -> (x, v)) (f x y) in
M.bindings(M.filter_map f s1) = List.filter_map f_on_pair (M.bindings s1));
checkbool "partition"
(let p x y = x >= 3 && x <= 6 in
let (st,sf) = M.partition p s1

View File

@ -89,6 +89,14 @@ let test x s1 s2 =
(let p x = x >= 3 && x <= 6 in
S.elements(S.filter p s1) = List.filter p (S.elements s1));
checkbool "filter_map"
(let f x = if x >= 3 && x <= 6 then Some (2 * x) else None in
S.elements(S.filter_map f s1) = List.filter_map f (S.elements s1));
checkbool "filter_map(==)"
(let f x = Some x in
S.filter_map f s1 == s1);
checkbool "partition"
(let p x = x >= 3 && x <= 6 in
let (st,sf) = S.partition p s1

View File

@ -321,6 +321,7 @@ module type MapT =
val for_all : (key -> 'a -> bool) -> 'a t -> bool
val exists : (key -> 'a -> bool) -> 'a t -> bool
val filter : (key -> 'a -> bool) -> 'a t -> 'a t
val filter_map : (key -> 'a -> 'b option) -> 'a t -> 'b t
val partition : (key -> 'a -> bool) -> 'a t -> 'a t * 'a t
val cardinal : 'a t -> int
val bindings : 'a t -> (key * 'a) list
@ -372,6 +373,7 @@ module SSMap :
val for_all : (key -> 'a -> bool) -> 'a t -> bool
val exists : (key -> 'a -> bool) -> 'a t -> bool
val filter : (key -> 'a -> bool) -> 'a t -> 'a t
val filter_map : (key -> 'a -> 'b option) -> 'a t -> 'b t
val partition : (key -> 'a -> bool) -> 'a t -> 'a t * 'a t
val cardinal : 'a t -> int
val bindings : 'a t -> (key * 'a) list

View File

@ -298,6 +298,7 @@ module StringSet :
val for_all : (elt -> bool) -> t -> bool
val exists : (elt -> bool) -> t -> bool
val filter : (elt -> bool) -> t -> t
val filter_map : (elt -> elt option) -> t -> t
val partition : (elt -> bool) -> t -> t * t
val cardinal : t -> int
val elements : t -> elt list
@ -343,6 +344,7 @@ module SSet :
val for_all : (elt -> bool) -> t -> bool
val exists : (elt -> bool) -> t -> bool
val filter : (elt -> bool) -> t -> t
val filter_map : (elt -> elt option) -> t -> t
val partition : (elt -> bool) -> t -> t * t
val cardinal : t -> int
val elements : t -> elt list
@ -420,6 +422,7 @@ module A :
val for_all : (elt -> bool) -> t -> bool
val exists : (elt -> bool) -> t -> bool
val filter : (elt -> bool) -> t -> t
val filter_map : (elt -> elt option) -> t -> t
val partition : (elt -> bool) -> t -> t * t
val cardinal : t -> int
val elements : t -> elt list
@ -532,6 +535,7 @@ module SInt :
val for_all : (elt -> bool) -> t -> bool
val exists : (elt -> bool) -> t -> bool
val filter : (elt -> bool) -> t -> t
val filter_map : (elt -> elt option) -> t -> t
val partition : (elt -> bool) -> t -> t * t
val cardinal : t -> int
val elements : t -> elt list

View File

@ -254,6 +254,7 @@ module MkT :
val for_all : (elt -> bool) -> t -> bool
val exists : (elt -> bool) -> t -> bool
val filter : (elt -> bool) -> t -> t
val filter_map : (elt -> elt option) -> t -> t
val partition : (elt -> bool) -> t -> t * t
val cardinal : t -> int
val elements : t -> elt list

View File

@ -34,6 +34,7 @@ module Core :
val for_all : (key -> 'a -> bool) -> 'a t -> bool
val exists : (key -> 'a -> bool) -> 'a t -> bool
val filter : (key -> 'a -> bool) -> 'a t -> 'a t
val filter_map : (key -> 'a -> 'b option) -> 'a t -> 'b t
val partition : (key -> 'a -> bool) -> 'a t -> 'a t * 'a t
val cardinal : 'a t -> key
val bindings : 'a t -> (key * 'a) list

View File

@ -43,7 +43,6 @@ module type Map = sig
with type key = T.t
and type 'a t = 'a Map.Make (T).t
val filter_map : 'a t -> f:(key -> 'a -> 'b option) -> 'b t
val of_list : (key * 'a) list -> 'a t
val disjoint_union :
@ -102,12 +101,6 @@ end
module Make_map (T : Thing) = struct
include Map.Make (T)
let filter_map t ~f =
fold (fun id v map ->
match f id v with
| None -> map
| Some r -> add id r map) t empty
let of_list l =
List.fold_left (fun map (id, v) -> add id v map) empty l

View File

@ -52,7 +52,6 @@ module type Map = sig
with type key = T.t
and type 'a t = 'a Map.Make (T).t
val filter_map : 'a t -> f:(key -> 'a -> 'b option) -> 'b t
val of_list : (key * 'a) list -> 'a t
(** [disjoint_union m1 m2] contains all bindings from [m1] and