AK: Guarantee a maximum stack depth for dual_pivot_quick_sort

When the two chosen pivots happen to be the smallest and largest
elements of the array, three partitions will be created, two of
size 0 and one of size n-2. If this happens on each recursive call
to dual_pivot_quick_sort, the stack depth will reach approximately n/2.

To avoid the stack from deepening, iteration can be used for the
largest of the three partitions. This ensures the stack depth
will only increase for partitions of size n/2 or smaller, which
results in a maximum stack depth of log(n).
This commit is contained in:
Mart G 2021-04-28 17:41:43 +02:00 committed by Andreas Kling
parent 67b0d04315
commit c9f3cc6dcc
Notes: sideshowbarker 2024-07-18 19:00:08 +09:00

View File

@ -18,62 +18,75 @@ namespace AK {
template<typename Collection, typename LessThan>
void dual_pivot_quick_sort(Collection& col, int start, int end, LessThan less_than)
{
int size = end - start + 1;
if (size <= 1) {
return;
}
if (size > 3) {
int third = size / 3;
if (less_than(col[start + third], col[end - third])) {
swap(col[start + third], col[start]);
swap(col[end - third], col[end]);
} else {
swap(col[start + third], col[end]);
swap(col[end - third], col[start]);
}
} else {
if (!less_than(col[start], col[end])) {
swap(col[start], col[end]);
}
}
int j = start + 1;
int k = start + 1;
int g = end - 1;
auto&& left_pivot = col[start];
auto&& right_pivot = col[end];
while (k <= g) {
if (less_than(col[k], left_pivot)) {
swap(col[k], col[j]);
j++;
} else if (!less_than(col[k], right_pivot)) {
while (!less_than(col[g], right_pivot) && k < g) {
g--;
while (start < end) {
int size = end - start + 1;
if (size > 3) {
int third = size / 3;
if (less_than(col[start + third], col[end - third])) {
swap(col[start + third], col[start]);
swap(col[end - third], col[end]);
} else {
swap(col[start + third], col[end]);
swap(col[end - third], col[start]);
}
swap(col[k], col[g]);
g--;
} else {
if (!less_than(col[start], col[end])) {
swap(col[start], col[end]);
}
}
int j = start + 1;
int k = start + 1;
int g = end - 1;
auto&& left_pivot = col[start];
auto&& right_pivot = col[end];
while (k <= g) {
if (less_than(col[k], left_pivot)) {
swap(col[k], col[j]);
j++;
} else if (!less_than(col[k], right_pivot)) {
while (!less_than(col[g], right_pivot) && k < g) {
g--;
}
swap(col[k], col[g]);
g--;
if (less_than(col[k], left_pivot)) {
swap(col[k], col[j]);
j++;
}
}
k++;
}
j--;
g++;
swap(col[start], col[j]);
swap(col[end], col[g]);
int left_pointer = j;
int right_pointer = g;
int left_size = left_pointer - start;
int middle_size = right_pointer - (left_pointer + 1);
int right_size = (end + 1) - (right_pointer + 1);
if (left_size >= middle_size && left_size >= right_size) {
dual_pivot_quick_sort(col, left_pointer + 1, right_pointer - 1, less_than);
dual_pivot_quick_sort(col, right_pointer + 1, end, less_than);
end = left_pointer - 1;
} else if (middle_size >= right_size) {
dual_pivot_quick_sort(col, start, left_pointer - 1, less_than);
dual_pivot_quick_sort(col, right_pointer + 1, end, less_than);
start = left_pointer + 1;
end = right_pointer - 1;
} else {
dual_pivot_quick_sort(col, start, left_pointer - 1, less_than);
dual_pivot_quick_sort(col, left_pointer + 1, right_pointer - 1, less_than);
start = right_pointer + 1;
}
k++;
}
j--;
g++;
swap(col[start], col[j]);
swap(col[end], col[g]);
int left_pointer = j;
int right_pointer = g;
dual_pivot_quick_sort(col, start, left_pointer - 1, less_than);
dual_pivot_quick_sort(col, left_pointer + 1, right_pointer - 1, less_than);
dual_pivot_quick_sort(col, right_pointer + 1, end, less_than);
}
template<typename Iterator, typename LessThan>