diff --git a/stdlib/sort.ml b/stdlib/sort.ml index 2df38cf58..ef5acdacd 100644 --- a/stdlib/sort.ml +++ b/stdlib/sort.ml @@ -46,50 +46,47 @@ let swap arr i j = unsafe_set arr i (unsafe_get arr j); unsafe_set arr j tmp -let array order arr = +let array cmp arr = let rec qsort lo hi = - if hi <= lo then () - else if hi - lo < 5 then begin - (* Use insertion sort *) - for i = lo + 1 to hi do - let val_i = unsafe_get arr i in - if order val_i (unsafe_get arr (i - 1)) then begin - unsafe_set arr i (unsafe_get arr (i - 1)); - let j = ref (i - 1) in - while !j >= 1 && order val_i (unsafe_get arr (!j - 1)) do - unsafe_set arr !j (unsafe_get arr (!j - 1)); - decr j - done; - unsafe_set arr !j val_i - end - done - end else begin + if hi - lo >= 6 then begin let mid = (lo + hi) lsr 1 in - (* Select median value from among LO, MID, and HI *) - let pivotpos = - let vlo = unsafe_get arr lo - and vhi = unsafe_get arr hi - and vmid = unsafe_get arr mid in - if order vlo vmid then - if order vmid vhi then mid - else if order vlo vhi then hi else lo - else - if order vhi vmid then mid - else if order vhi vlo then hi else lo in - swap arr pivotpos hi; - let pivot = unsafe_get arr hi in - let i = ref lo and j = ref hi in + (* Select median value from among LO, MID, and HI. Rearrange + LO and HI so the three values are sorted. This lowers the + probability of picking a pathological pivot. It also + avoids extra comparisons on i and j in the two tight "while" + loops below. *) + if cmp (unsafe_get arr mid) (unsafe_get arr lo) then swap arr mid lo; + if cmp (unsafe_get arr hi) (unsafe_get arr mid) then begin + swap arr mid hi; + if cmp (unsafe_get arr mid) (unsafe_get arr lo) then swap arr mid lo + end; + let pivot = unsafe_get arr mid in + let i = ref (lo + 1) and j = ref (hi - 1) in while !i < !j do - while !i < hi && order (unsafe_get arr !i) pivot do incr i done; - while !j > lo && order pivot (unsafe_get arr !j) do decr j done; - if !i < !j then swap arr !i !j + while not (cmp pivot (unsafe_get arr !i)) do incr i done; + while not (cmp (unsafe_get arr !j) pivot) do decr j done; + if !i < !j then swap arr !i !j; + incr i; decr j done; - swap arr !i hi; - (* Recurse on larger half first *) - if (!i - 1) - lo >= hi - (!i + 1) then begin - qsort lo (!i - 1); qsort (!i + 1) hi + (* Recursion on smaller half, tail-call on larger half *) + if !j - lo <= hi - !i then begin + qsort lo !j; qsort !i hi end else begin - qsort (!i + 1) hi; qsort lo (!i - 1) + qsort !i hi; qsort lo !j end end in - qsort 0 (Array.length arr - 1) + qsort 0 (Array.length arr - 1); + (* Finish sorting by insertion sort *) + for i = 1 to Array.length arr - 1 do + let val_i = (unsafe_get arr i) in + if not (cmp (unsafe_get arr (i - 1)) val_i) then begin + unsafe_set arr i (unsafe_get arr (i - 1)); + let j = ref (i - 1) in + while !j >= 1 && not (cmp (unsafe_get arr (!j - 1)) val_i) do + unsafe_set arr !j (unsafe_get arr (!j - 1)); + decr j + done; + unsafe_set arr !j val_i + end + done +