diff --git a/stdlib/map.ml b/stdlib/map.ml index 54eaf4fe8..b5d9a08ca 100644 --- a/stdlib/map.ml +++ b/stdlib/map.ml @@ -270,12 +270,13 @@ module Make(Ord: OrderedType) = struct let rec filter p = function Empty -> Empty - | Node(l, v, d, r, _) -> + | Node(l, v, d, r, _) as t -> (* call [p] in the expected left-to-right order *) let l' = filter p l in let pvd = p v d in let r' = filter p r in - if pvd then join l' v d r' else concat l' r' + if pvd then if l==l' && r==r' then t else join l' v d r' + else concat l' r' let rec partition p = function Empty -> (Empty, Empty) diff --git a/stdlib/map.mli b/stdlib/map.mli index b97a43b54..4a6ced5c8 100644 --- a/stdlib/map.mli +++ b/stdlib/map.mli @@ -134,7 +134,9 @@ module type S = val filter: (key -> 'a -> bool) -> 'a t -> 'a t (** [filter p m] returns the map with all the bindings in [m] - that satisfy predicate [p]. + that satisfy predicate [p]. If [p] satisfies every binding in [m], + [m] is returned unchanged (the result of the function is then + physically equal to [m]) @since 3.12.0 *) diff --git a/testsuite/tests/lib-set/testmap.ml b/testsuite/tests/lib-set/testmap.ml index 5d5972bc8..8792ae7f0 100644 --- a/testsuite/tests/lib-set/testmap.ml +++ b/testsuite/tests/lib-set/testmap.ml @@ -147,3 +147,11 @@ let () = assert (!m2 == !m1); assert(a2 -. a1 = a1 -. a0) + +let () = + (* check that filtering a map where all bindings are satisfied by + the given predicate returns the original map *) + let m1 = ref M.empty in + for i = 1 to 10 do m1 := M.add i (float i) !m1 done; + let m2 = M.filter (fun e _ -> e >= 0) !m1 in + assert (m2 == !m1)