You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

197 lines
4.8 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package class_2022_09_1_week;
// 来自学员问题
// 给你一个长度为n的数组并询问q次
// 每次询问区间[l,r]之间是否存在小于等于k个数的和大于等于x
// 每条查询返回true或者false
// 1 <= n, q <= 10^5
// k <= 10
// 1 <= x <= 10^8
import java.util.PriorityQueue;
public class Code04_QueryTopKSum {
/**
 * Segment tree answering "sum of the k largest values in arr[l..r]" queries.
 *
 * <p>Every node stores the k largest values of its range in descending order, padded with 0
 * when the range has fewer than k elements. A query merges the O(log n) covering nodes, each
 * merge costing O(k).
 *
 * <p>Not thread-safe: every query writes into a shared scratch buffer.
 */
public static class SegmentTree {
    // Number of elements; positions are 1-based (1..n) inside the tree.
    private final int n;
    // How many of the largest values each node keeps.
    private final int k;
    // max[rt]: the k largest values of node rt's range, descending, padded with 0.
    private final int[][] max;
    // Scratch space for collect(): query[rt] holds node rt's share of the current query.
    private final int[][] query;

    /**
     * Builds the tree over the current contents of arr.
     *
     * @param arr the values; later changes to the array are not reflected in the tree
     * @param K   how many of the largest values a query considers; must be >= 0
     * @throws IllegalArgumentException if K is negative
     */
    public SegmentTree(int[] arr, int K) {
        if (K < 0) {
            throw new IllegalArgumentException("K must be >= 0, got " + K);
        }
        n = arr.length;
        k = K;
        max = new int[(n + 1) << 2][k];
        query = new int[(n + 1) << 2][k];
        if (n > 0) {
            build(arr, 1, n, 1);
        }
    }

    /**
     * Returns the sum of the k largest values in arr[l..r] (0-based, inclusive); when the range
     * holds fewer than k values, all of them are summed. Exact when all values are positive;
     * with values <= 0 the zero padding competes with them, so the result can differ from a
     * plain top-k sum.
     *
     * @throws IllegalArgumentException unless 0 <= l <= r < n
     */
    public int topKSum(int l, int r) {
        collectRange(l, r);
        int sum = 0;
        for (int num : query[1]) {
            sum += num;
        }
        return sum;
    }

    /**
     * Answers the original question: do at most k values of arr[l..r] sum to at least x?
     *
     * <p>Only positive values can help reach a positive x, so only the positive entries of the
     * top-k list are added; the sum is kept in a long so it cannot overflow. For x <= 0 this
     * returns true (selecting no values yields 0).
     *
     * @throws IllegalArgumentException unless 0 <= l <= r < n
     */
    public boolean canReach(int l, int r, long x) {
        collectRange(l, r);
        long sum = 0;
        for (int num : query[1]) {
            if (num > 0) {
                sum += num;
            }
        }
        return sum >= x;
    }

    // Validates [l, r] (0-based) and leaves the top-k list of arr[l..r] in query[1].
    private void collectRange(int l, int r) {
        if (l < 0 || r >= n || l > r) {
            throw new IllegalArgumentException(
                    "invalid range [" + l + ", " + r + "] for length " + n);
        }
        collect(l + 1, r + 1, 1, n, 1);
    }

    // Fills max[] for node rt, which covers 1-based positions l..r of arr.
    private void build(int[] arr, int l, int r, int rt) {
        if (l == r) {
            if (k > 0) {
                max[rt][0] = arr[l - 1];
            }
            return;
        }
        int mid = (l + r) >> 1;
        build(arr, l, mid, rt << 1);
        build(arr, mid + 1, r, rt << 1 | 1);
        merge(max[rt], max[rt << 1], max[rt << 1 | 1]);
    }

    /**
     * Writes into father the k largest entries of left and right, which are expected to be
     * sorted descending. A null side contributes nothing (at most one side is null).
     * Because p1 + p2 == i < k, neither index can run past the end of its array.
     */
    private void merge(int[] father, int[] left, int[] right) {
        int p1 = 0;
        int p2 = 0;
        for (int i = 0; i < k; i++) {
            if (left == null) {
                father[i] = right[p2++];
            } else if (right == null) {
                father[i] = left[p1++];
            } else {
                father[i] = left[p1] >= right[p2] ? left[p1++] : right[p2++];
            }
        }
    }

    // Writes into query[rt] the top-k list of the intersection of the query range [L, R] with
    // node rt's range [l, r]; the two ranges must intersect.
    private void collect(int L, int R, int l, int r, int rt) {
        if (L <= l && r <= R) {
            System.arraycopy(max[rt], 0, query[rt], 0, k);
            return;
        }
        int mid = (l + r) >> 1;
        int[] left = null;
        int[] right = null;
        if (L <= mid) {
            collect(L, R, l, mid, rt << 1);
            left = query[rt << 1];
        }
        if (R > mid) {
            collect(L, R, mid + 1, r, rt << 1 | 1);
            right = query[rt << 1 | 1];
        }
        merge(query[rt], left, right);
    }
}
// Brute-force reference structure.
// Used only to verify SegmentTree.
/**
 * Brute-force reference: pushes the whole range into a max-heap and polls the k largest values.
 */
public static class Right {
    public int[] arr;
    public int k;

    /**
     * @param nums the values; copied, so later changes to nums are not seen
     * @param K    how many of the largest values topKSum adds up
     */
    public Right(int[] nums, int K) {
        k = K;
        arr = nums.clone();
    }

    /**
     * Returns the sum of the k largest values in arr[l..r] (0-based, inclusive), or of all of
     * them when the range holds fewer than k values.
     */
    public int topKSum(int l, int r) {
        // Integer.compare instead of b - a: the subtraction overflows for far-apart values
        // (e.g. Integer.MIN_VALUE vs a positive number) and corrupts the heap order.
        PriorityQueue<Integer> heap = new PriorityQueue<>((a, b) -> Integer.compare(b, a));
        for (int i = l; i <= r; i++) {
            heap.add(arr[i]);
        }
        int sum = 0;
        for (int i = 0; i < k && !heap.isEmpty(); i++) {
            sum += heap.poll();
        }
        return sum;
    }
}
// Test helper, used only for verification.
/**
 * Builds an array of length n whose entries are drawn via Math.random() as 1 + floor(r * v),
 * i.e. values in [1, v] when v >= 1.
 */
public static int[] randomArray(int n, int v) {
    int[] values = new int[n];
    int idx = 0;
    while (idx < values.length) {
        values[idx++] = 1 + (int) (Math.random() * v);
    }
    return values;
}
// Verification entry point: randomized comparison of SegmentTree against the brute-force
// Right, followed by a timing run on 100000 elements and 100000 queries.
public static void main(String[] args) {
    int N = 100;
    int K = 10;
    int V = 100;
    int testTimes = 5000;
    int queryTimes = 100;
    System.out.println("测试开始");
    // Stop at the first mismatch and print the failing input, so the case can be reproduced
    // instead of flooding the console with bare numbers.
    verify:
    for (int i = 0; i < testTimes; i++) {
        int n = (int) (Math.random() * N) + 1;
        int k = (int) (Math.random() * K) + 1;
        int[] arr = randomArray(n, V);
        SegmentTree st = new SegmentTree(arr, k);
        Right right = new Right(arr, k);
        for (int j = 0; j < queryTimes; j++) {
            int a = (int) (Math.random() * n);
            int b = (int) (Math.random() * n);
            int l = Math.min(a, b);
            int r = Math.max(a, b);
            int ans1 = st.topKSum(l, r);
            int ans2 = right.topKSum(l, r);
            if (ans1 != ans2) {
                System.out.println("出错了!");
                System.out.println(ans1);
                System.out.println(ans2);
                System.out.println("k = " + k + ", l = " + l + ", r = " + r);
                System.out.println("arr = " + java.util.Arrays.toString(arr));
                break verify;
            }
        }
    }
    System.out.println("测试结束");
    System.out.println("性能测试开始");
    int n = 100000;
    int k = 10;
    int[] arr = randomArray(n, n);
    int[][] queries = new int[n][2];
    for (int i = 0; i < n; i++) {
        int a = (int) (Math.random() * n);
        int b = (int) (Math.random() * n);
        queries[i][0] = Math.min(a, b);
        queries[i][1] = Math.max(a, b);
    }
    System.out.println("数据规模 : " + n);
    System.out.println("数值规模 : " + n);
    System.out.println("查询规模 : " + n);
    System.out.println("k规模 : " + k);
    // nanoTime is monotonic; currentTimeMillis follows the wall clock and can jump.
    long start = System.nanoTime();
    SegmentTree st = new SegmentTree(arr, k);
    long end1 = System.nanoTime();
    for (int[] q : queries) {
        st.topKSum(q[0], q[1]);
    }
    long end2 = System.nanoTime();
    System.out.println("初始化运行时间 : " + ((end1 - start) / 1_000_000) + " 毫秒");
    System.out.println("执行查询运行时间 : " + ((end2 - end1) / 1_000_000) + " 毫秒");
    System.out.println("总共运行时间 : " + ((end2 - start) / 1_000_000) + " 毫秒");
    System.out.println("性能测试结束");
}
}