add testing for pivot selection

This commit is contained in:
Brendan Hansknecht 2024-07-28 00:03:46 -07:00
parent e722faaf58
commit 6d7d9e4e57
No known key found for this signature in database
GPG Key ID: 0EA784685083E75B

View File

@ -457,7 +457,7 @@ fn flux_partition(
comptime data_is_owned: bool,
inc_n_data: IncN,
) void {
var generic: i32 = 0;
var generic = false;
var pivot_ptr = pivot;
var x_ptr = x;
@ -472,9 +472,9 @@ fn flux_partition(
if (len <= 2048) {
median_of_nine(x_ptr, len, cmp, cmp_data, element_width, copy, data_is_owned, inc_n_data, pivot_ptr);
} else {
median_of_cbrt(array, swap, x_ptr, len, cmp, cmp_data, element_width, copy, data_is_owned, inc_n_data, &generic, pivot_ptr);
median_of_cube_root(array, swap, x_ptr, len, cmp, cmp_data, element_width, copy, data_is_owned, inc_n_data, &generic, pivot_ptr);
if (generic != 0) {
if (generic) {
if (x_ptr == swap) {
@memcpy(array[0..(len * element_width)], swap[0..(len * element_width)]);
}
@ -630,7 +630,11 @@ fn flux_reverse_partition(
// ================ Pivot Selection ===========================================
// Used for selecting the quicksort pivot for various sized arrays.
fn median_of_cbrt(
/// Returns the median of an array taking roughly cube root samples.
/// Only used for super large arrays, assumes the minimum cube root is 32.
/// Out is set to the median.
/// Generic is set to true if all elements are the same.
fn median_of_cube_root(
array: [*]u8,
swap: [*]u8,
x_ptr: [*]u8,
@ -641,7 +645,7 @@ fn median_of_cbrt(
copy: CopyFn,
comptime data_is_owned: bool,
inc_n_data: IncN,
generic: *i32,
generic: *bool,
out: [*]u8,
) void {
var cbrt: usize = 32;
@ -649,28 +653,25 @@ fn median_of_cbrt(
const div = len / cbrt;
// I assume using the pointer as an int is to add randomness here?
// Using a pointer to div as an int is to get a random offset from 0 to div.
var arr_ptr = x_ptr + @intFromPtr(&div) / 16 % div;
var swap_ptr = if (x_ptr == array) swap else array;
for (0..cbrt) |cnt| {
copy(swap_ptr + cnt * element_width, arr_ptr);
arr_ptr += div;
arr_ptr += div * element_width;
}
cbrt /= 2;
quadsort_swap(swap_ptr, cbrt, swap_ptr + cbrt * 2 * element_width, cbrt, cmp, cmp_data, element_width, copy, data_is_owned, inc_n_data);
quadsort_swap(swap_ptr + cbrt * element_width, cbrt, swap_ptr + cbrt * 2 * element_width, cbrt, cmp, cmp_data, element_width, copy, data_is_owned, inc_n_data);
// 2 guaranteed compares
if (data_is_owned) {
inc_n_data(cmp_data, 2);
}
generic.* = @intFromBool(compare(cmp, cmp_data, swap_ptr + (cbrt * 2 - 1) * element_width, swap_ptr) != GT and compare(cmp, cmp_data, swap_ptr + (cbrt - 1) * element_width, swap_ptr) != GT);
generic.* = compare_inc(cmp, cmp_data, swap_ptr + (cbrt * 2 - 1) * element_width, swap_ptr, data_is_owned, inc_n_data) != GT and compare_inc(cmp, cmp_data, swap_ptr + (cbrt - 1) * element_width, swap_ptr, data_is_owned, inc_n_data) != GT;
binary_median(swap_ptr, swap_ptr + cbrt * element_width, cbrt, cmp, cmp_data, element_width, copy, data_is_owned, inc_n_data, out);
}
/// Returns the median of 9 evenly distributed elements from a list.
fn median_of_nine(
array: [*]u8,
len: usize,
@ -711,10 +712,12 @@ fn median_of_nine(
const y = compare(cmp, cmp_data, swap_ptr + 0 * element_width, swap_ptr + 2 * element_width) == GT;
const z = compare(cmp, cmp_data, swap_ptr + 1 * element_width, swap_ptr + 2 * element_width) == GT;
const index: usize = @intFromBool(x == y) + (@intFromBool(y) ^ @intFromBool(z));
const index: usize = @as(usize, @intFromBool(x == y)) + (@intFromBool(y) ^ @intFromBool(z));
copy(out, swap_ptr + index * element_width);
}
/// Ensures the middle two elements of the array are the middle two elements by sorting.
/// Does not care about the rest of the elements and can overwrite them.
fn trim_four(
initial_ptr_a: [*]u8,
cmp: CompareFn,
@ -752,17 +755,19 @@ fn trim_four(
}
{
const lte = compare(cmp, cmp_data, ptr_a, ptr_a + 2 * element_width) != GT;
const x = if (lte) element_width else 0;
const x = if (lte) 2 * element_width else 0;
copy(ptr_a + 2 * element_width, ptr_a + x);
ptr_a += 1;
ptr_a += element_width;
}
{
const gt = compare(cmp, cmp_data, ptr_a, ptr_a + 2 * element_width) == GT;
const x = if (gt) element_width else 0;
const x = if (gt) 2 * element_width else 0;
copy(ptr_a, ptr_a + x);
}
}
/// Attempts to find the median of 2 binary arrays of len.
/// Set out to the larger median from the two lists.
fn binary_median(
initial_ptr_a: [*]u8,
initial_ptr_b: [*]u8,
@ -777,16 +782,14 @@ fn binary_median(
) void {
var len = initial_len;
if (data_is_owned) {
// We need to increment log2 of n times.
// We can get that by counting leading zeros and of (top - 1).
// Needs to be `-1` so values that are powers of 2 don't sort up a bin.
// Then just add 1 back to the final result.
const log2 = @bitSizeOf(usize) - @clz(len - 1) + 1;
// We need to increment log2 of len times.
const log2 = @bitSizeOf(usize) - @clz(len);
inc_n_data(cmp_data, log2);
}
var ptr_a = initial_ptr_a;
var ptr_b = initial_ptr_b;
while (len / 2 != 0) : (len /= 2) {
len /= 2;
while (len != 0) : (len /= 2) {
if (compare(cmp, cmp_data, ptr_a, ptr_b) != GT) {
ptr_a += len * element_width;
} else {
@ -797,6 +800,138 @@ fn binary_median(
copy(out, from);
}
test "median_of_cube_root" {
    // Handed to the function as `cmp_data`; every case expects it to be 0 again
    // once the call returns.
    // NOTE(review): presumably `test_inc_n_data` adds to it and
    // `test_i64_compare_refcounted` subtracts per compare -- confirm against
    // those helpers.
    var test_count: i64 = 0;
    var out: i64 = 0;
    var generic = false;

    // Scratch buffer the function copies its samples into.
    var swap: [32]i64 = undefined;
    const swap_ptr = @as([*]u8, @ptrCast(&swap[0]));

    {
        var arr: [32]i64 = undefined;
        const arr_ptr = @as([*]u8, @ptrCast(&arr[0]));

        // Ascending odd values followed by ascending even values.
        arr = [32]i64{
            //
            1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31,
            //
            2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32,
        };
        median_of_cube_root(arr_ptr, swap_ptr, arr_ptr, 32, &test_i64_compare_refcounted, @ptrCast(&test_count), @sizeOf(i64), &test_i64_copy, true, &test_inc_n_data, &generic, @ptrCast(&out));
        try testing.expectEqual(@as(i64, 0), test_count);
        try testing.expectEqual(@as(i64, 17), out);
        try testing.expectEqual(false, generic);

        // Every element identical: the function must report `generic`.
        for (0..32) |i| {
            arr[i] = 7;
        }
        median_of_cube_root(arr_ptr, swap_ptr, arr_ptr, 32, &test_i64_compare_refcounted, @ptrCast(&test_count), @sizeOf(i64), &test_i64_copy, true, &test_inc_n_data, &generic, @ptrCast(&out));
        try testing.expectEqual(@as(i64, 0), test_count);
        try testing.expectEqual(@as(i64, 7), out);
        try testing.expectEqual(true, generic);

        // Two distinct values alternating (7, 8, 7, 8, ...): must not be `generic`.
        for (0..32) |i| {
            arr[i] = 7 + @as(i64, @intCast(i % 2));
        }
        median_of_cube_root(arr_ptr, swap_ptr, arr_ptr, 32, &test_i64_compare_refcounted, @ptrCast(&test_count), @sizeOf(i64), &test_i64_copy, true, &test_inc_n_data, &generic, @ptrCast(&out));
        try testing.expectEqual(@as(i64, 0), test_count);
        try testing.expectEqual(@as(i64, 8), out);
        try testing.expectEqual(false, generic);
    }
}
test "median_of_nine" {
    // Handed to the function as `cmp_data`; every case expects it to be 0 again
    // once the call returns.
    var test_count: i64 = 0;
    var out: i64 = 0;

    {
        var arr: [9]i64 = undefined;
        const arr_ptr = @as([*]u8, @ptrCast(&arr[0]));

        // NOTE(review): `len` is passed as 10 although `arr` holds 9 elements --
        // confirm this is intentional.

        arr = [9]i64{ 1, 2, 3, 4, 5, 6, 7, 8, 9 };
        median_of_nine(arr_ptr, 10, &test_i64_compare_refcounted, @ptrCast(&test_count), @sizeOf(i64), &test_i64_copy, true, &test_inc_n_data, @ptrCast(&out));
        try testing.expectEqual(@as(i64, 0), test_count);
        // Note: median is not guaranteed to be exact. In this case:
        // [2, 3], [6, 7] -> [3, 6] -> [3, 6, 9] -> 6
        try testing.expectEqual(@as(i64, 6), out);

        arr = [9]i64{ 1, 3, 5, 7, 9, 2, 4, 6, 8 };
        median_of_nine(arr_ptr, 10, &test_i64_compare_refcounted, @ptrCast(&test_count), @sizeOf(i64), &test_i64_copy, true, &test_inc_n_data, @ptrCast(&out));
        try testing.expectEqual(@as(i64, 0), test_count);
        // Note: median is not guaranteed to be exact. In this case:
        // [3, 5], [4, 6] -> [4, 5] -> [4, 5, 8] -> 5
        try testing.expectEqual(@as(i64, 5), out);

        arr = [9]i64{ 2, 3, 9, 4, 5, 7, 8, 6, 1 };
        median_of_nine(arr_ptr, 10, &test_i64_compare_refcounted, @ptrCast(&test_count), @sizeOf(i64), &test_i64_copy, true, &test_inc_n_data, @ptrCast(&out));
        try testing.expectEqual(@as(i64, 0), test_count);
        // Note: median is not guaranteed to be exact. In this case:
        // [3, 4], [5, 6] -> [4, 5] -> [1, 4, 5] -> 4
        try testing.expectEqual(@as(i64, 4), out);
    }
}
test "trim_four" {
    // Handed to the function as `cmp_data`; every case expects it to be 0 again
    // once the call returns.
    var test_count: i64 = 0;

    var arr: [4]i64 = undefined;
    const arr_ptr = @as([*]u8, @ptrCast(&arr[0]));

    // Only indices 1 and 2 are meaningful after the call: they hold the two
    // middle values (in either order). The outer slots may be overwritten, so
    // the full expected arrays below pin the current implementation's output.

    // Already sorted: nothing changes.
    arr = [4]i64{ 1, 2, 3, 4 };
    trim_four(arr_ptr, &test_i64_compare_refcounted, @ptrCast(&test_count), @sizeOf(i64), &test_i64_copy, true, &test_inc_n_data);
    try testing.expectEqual(@as(i64, 0), test_count);
    try testing.expectEqual([4]i64{ 1, 2, 3, 4 }, arr);

    arr = [4]i64{ 2, 3, 1, 4 };
    trim_four(arr_ptr, &test_i64_compare_refcounted, @ptrCast(&test_count), @sizeOf(i64), &test_i64_copy, true, &test_inc_n_data);
    try testing.expectEqual(@as(i64, 0), test_count);
    try testing.expectEqual([4]i64{ 2, 3, 2, 4 }, arr);

    // Fully reversed input.
    arr = [4]i64{ 4, 3, 2, 1 };
    trim_four(arr_ptr, &test_i64_compare_refcounted, @ptrCast(&test_count), @sizeOf(i64), &test_i64_copy, true, &test_inc_n_data);
    try testing.expectEqual(@as(i64, 0), test_count);
    try testing.expectEqual([4]i64{ 3, 2, 3, 2 }, arr);
}
test "binary_median" {
    // Handed to the function as `cmp_data`; every case expects it to be 0 again
    // once the call returns.
    var test_count: i64 = 0;
    var out: i64 = 0;

    // Two halves of 5 elements each (length that is not a power of two).
    {
        var arr: [10]i64 = undefined;
        const arr_ptr = @as([*]u8, @ptrCast(&arr[0]));

        arr = [10]i64{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
        binary_median(arr_ptr, arr_ptr + 5 * @sizeOf(i64), 5, &test_i64_compare_refcounted, @ptrCast(&test_count), @sizeOf(i64), &test_i64_copy, true, &test_inc_n_data, @ptrCast(&out));
        try testing.expectEqual(@as(i64, 0), test_count);
        try testing.expectEqual(@as(i64, 6), out);

        arr = [10]i64{ 1, 3, 5, 7, 9, 2, 4, 6, 8, 10 };
        binary_median(arr_ptr, arr_ptr + 5 * @sizeOf(i64), 5, &test_i64_compare_refcounted, @ptrCast(&test_count), @sizeOf(i64), &test_i64_copy, true, &test_inc_n_data, @ptrCast(&out));
        try testing.expectEqual(@as(i64, 0), test_count);
        try testing.expectEqual(@as(i64, 5), out);
    }

    // Two halves of 8 elements each (power-of-two length).
    {
        var arr: [16]i64 = undefined;
        const arr_ptr = @as([*]u8, @ptrCast(&arr[0]));

        // All of the first half below all of the second half.
        arr = [16]i64{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 };
        binary_median(arr_ptr, arr_ptr + 8 * @sizeOf(i64), 8, &test_i64_compare_refcounted, @ptrCast(&test_count), @sizeOf(i64), &test_i64_copy, true, &test_inc_n_data, @ptrCast(&out));
        try testing.expectEqual(@as(i64, 0), test_count);
        try testing.expectEqual(@as(i64, 9), out);

        // Interleaved values: odds in the first half, evens in the second.
        arr = [16]i64{ 1, 3, 5, 7, 9, 11, 13, 15, 2, 4, 6, 8, 10, 12, 14, 16 };
        binary_median(arr_ptr, arr_ptr + 8 * @sizeOf(i64), 8, &test_i64_compare_refcounted, @ptrCast(&test_count), @sizeOf(i64), &test_i64_copy, true, &test_inc_n_data, @ptrCast(&out));
        try testing.expectEqual(@as(i64, 0), test_count);
        try testing.expectEqual(@as(i64, 9), out);

        // All of the first half above all of the second half.
        arr = [16]i64{ 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8 };
        binary_median(arr_ptr, arr_ptr + 8 * @sizeOf(i64), 8, &test_i64_compare_refcounted, @ptrCast(&test_count), @sizeOf(i64), &test_i64_copy, true, &test_inc_n_data, @ptrCast(&out));
        try testing.expectEqual(@as(i64, 0), test_count);
        try testing.expectEqual(@as(i64, 9), out);
    }
}
// ================ Quadsort ==================================================
// The high level quadsort functions.