You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

215 lines
5.8 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package class002;
import java.util.Arrays;
import java.util.Comparator;
public class Code02_KthMinPair {
public static class Pair {
public int x;
public int y;
Pair(int x, int y) {
this.x = x;
this.y = y;
}
}
public static class PairComparator implements Comparator<Pair> {
@Override
public int compare(Pair arg0, Pair arg1) {
return arg0.x != arg1.x ? arg0.x - arg1.x : arg0.y - arg1.y;
}
}
// O(N^2 * log (N^2))的复杂度,你肯定过不了
// 返回的int[] 长度是2{3,1} int[2] = [3,1]
public static int[] kthMinPair1(int[] arr, int k) {
int N = arr.length;
if (k > N * N) {
return null;
}
Pair[] pairs = new Pair[N * N];
int index = 0;
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
pairs[index++] = new Pair(arr[i], arr[j]);
}
}
Arrays.sort(pairs, new PairComparator());
return new int[] { pairs[k - 1].x, pairs[k - 1].y };
}
// O(N*logN)的复杂度,你肯定过了
public static int[] kthMinPair2(int[] arr, int k) {
int N = arr.length;
if (k > N * N) {
return null;
}
// O(N*logN)
Arrays.sort(arr);
// 第K小的数值对第一维数字是什么 是arr中
int fristNum = arr[(k - 1) / N];
int lessFristNumSize = 0;// 数出比fristNum小的数有几个
int fristNumSize = 0; // 数出==fristNum的数有几个
// <= fristNum
for (int i = 0; i < N && arr[i] <= fristNum; i++) {
if (arr[i] < fristNum) {
lessFristNumSize++;
} else {
fristNumSize++;
}
}
int rest = k - (lessFristNumSize * N);
return new int[] { fristNum, arr[(rest - 1) / fristNumSize] };
}
// O(N)的复杂度,你肯定蒙了
public static int[] kthMinPair3(int[] arr, int k) {
int N = arr.length;
if (k > N * N) {
return null;
}
// 在无序数组中找到第K小的数返回值
// 第K小以1作为开始
int fristNum = getMinKthByBFPRT(arr, ((k - 1) / N) + 1);
int lessFristNumSize = 0;
int fristNumSize = 0;
for (int i = 0; i < N; i++) {
if (arr[i] < fristNum) {
lessFristNumSize++;
}
if (arr[i] == fristNum) {
fristNumSize++;
}
}
int rest = k - (lessFristNumSize * N);
return new int[] { fristNum, getMinKthByBFPRT(arr, ((rest - 1) / fristNumSize) + 1) };
}
// 利用bfprt算法求解是O(N)的过程
// 在arr上找到第K小的数并返回。K范围是[1,N],范围不是[0,N-1]
// 对你来讲,它可能永远不会被你想起,但确实本题最优解的算法原型
public static int getMinKthByBFPRT(int[] arr, int K) {
return select(arr, 0, arr.length - 1, K - 1);
}
public static int select(int[] arr, int begin, int end, int i) {
if (begin == end) {
return arr[begin];
}
int pivot = medianOfMedians(arr, begin, end);
int[] pivotRange = partition(arr, begin, end, pivot);
if (i >= pivotRange[0] && i <= pivotRange[1]) {
return arr[i];
} else if (i < pivotRange[0]) {
return select(arr, begin, pivotRange[0] - 1, i);
} else {
return select(arr, pivotRange[1] + 1, end, i);
}
}
public static int medianOfMedians(int[] arr, int begin, int end) {
int num = end - begin + 1;
int offset = num % 5 == 0 ? 0 : 1;
int[] mArr = new int[num / 5 + offset];
for (int i = 0; i < mArr.length; i++) {
int beginI = begin + i * 5;
int endI = beginI + 4;
mArr[i] = getMedian(arr, beginI, Math.min(end, endI));
}
return select(mArr, 0, mArr.length - 1, mArr.length / 2);
}
public static int[] partition(int[] arr, int begin, int end, int pivotValue) {
int small = begin - 1;
int cur = begin;
int big = end + 1;
while (cur != big) {
if (arr[cur] < pivotValue) {
swap(arr, ++small, cur++);
} else if (arr[cur] > pivotValue) {
swap(arr, cur, --big);
} else {
cur++;
}
}
int[] range = new int[2];
range[0] = small + 1;
range[1] = big - 1;
return range;
}
public static int getMedian(int[] arr, int begin, int end) {
insertionSort(arr, begin, end);
int sum = end + begin;
int mid = (sum / 2) + (sum % 2);
return arr[mid];
}
public static void insertionSort(int[] arr, int begin, int end) {
for (int i = begin + 1; i != end + 1; i++) {
for (int j = i; j != begin; j--) {
if (arr[j - 1] > arr[j]) {
swap(arr, j - 1, j);
} else {
break;
}
}
}
}
public static void swap(int[] arr, int index1, int index2) {
int tmp = arr[index1];
arr[index1] = arr[index2];
arr[index2] = tmp;
}
// 为了测试,生成值也随机,长度也随机的随机数组
public static int[] getRandomArray(int max, int len) {
int[] arr = new int[(int) (Math.random() * len) + 1];
for (int i = 0; i < arr.length; i++) {
arr[i] = (int) (Math.random() * max) - (int) (Math.random() * max);
}
return arr;
}
// 为了测试
public static int[] copyArray(int[] arr) {
if (arr == null) {
return null;
}
int[] res = new int[arr.length];
for (int i = 0; i < arr.length; i++) {
res[i] = arr[i];
}
return res;
}
// 随机测试了百万组,保证三种方法都是对的
public static void main(String[] args) {
int max = 100;
int len = 30;
int testTimes = 100000;
System.out.println("test bagin, time times : " + testTimes);
for (int i = 0; i < testTimes; i++) {
int[] arr = getRandomArray(max, len);
int[] arr1 = copyArray(arr);
int[] arr2 = copyArray(arr);
int[] arr3 = copyArray(arr);
int N = arr.length * arr.length;
int k = (int) (Math.random() * N) + 1;
int[] ans1 = kthMinPair1(arr1, k);
int[] ans2 = kthMinPair2(arr2, k);
int[] ans3 = kthMinPair3(arr3, k);
if (ans1[0] != ans2[0] || ans2[0] != ans3[0] || ans1[1] != ans2[1] || ans2[1] != ans3[1]) {
System.out.println("Oops!");
}
}
System.out.println("test finish");
}
}