You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

417 lines
9.3 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package class21;
import java.util.HashMap;
public class TreeChainPartition {
public static class TreeChain {
// 时间戳 0 1 2 3 4
private int tim;
// 节点个数是n节点编号是1~n
private int n;
// 谁是头
private int h;
// 朴素树结构
private int[][] tree;
// 权重数组 原始的0节点权重是6 -> val[1] = 6
private int[] val;
// father数组一个平移因为标号要+1
private int[] fa;
// 深度数组!
private int[] dep;
// son[i] = 0 i这个节点没有儿子
// son[i] != 0 j i这个节点重儿子是j
private int[] son;
// siz[i] i这个节点为头的子树有多少个节点
private int[] siz;
// top[i] = j i这个节点所在的重链头是j
private int[] top;
// dfn[i] = j i这个节点在dfs序中是第j个
private int[] dfn;
// 如果原来的节点a权重是10
// 如果a节点在dfs序中是第5个节点, tnw[5] = 10
private int[] tnw;
// 线段树在tnw上玩连续的区间查询或者更新
private SegmentTree seg;
public TreeChain(int[] father, int[] values) {
// 原始的树 tree弄好了可以从i这个点找到下级的直接孩子
// 上面的一大堆结构准备好了空间values -> val
// 找到头部点
initTree(father, values);
// fa;
// dep;
// son;
// siz;
dfs1(h, 0);
// top;
// dfn;
// tnw;
dfs2(h, h);
seg = new SegmentTree(tnw);
seg.build(1, n, 1);
}
private void initTree(int[] father, int[] values) {
tim = 0;
n = father.length + 1;
tree = new int[n][];
val = new int[n];
fa = new int[n];
dep = new int[n];
son = new int[n];
siz = new int[n];
top = new int[n];
dfn = new int[n];
tnw = new int[n--];
int[] cnum = new int[n];
for (int i = 0; i < n; i++) {
val[i + 1] = values[i];
}
for (int i = 0; i < n; i++) {
if (father[i] == i) {
h = i + 1;
} else {
cnum[father[i]]++;
}
}
tree[0] = new int[0];
for (int i = 0; i < n; i++) {
tree[i + 1] = new int[cnum[i]];
}
for (int i = 0; i < n; i++) {
if (i + 1 != h) {
tree[father[i] + 1][--cnum[father[i]]] = i + 1;
}
}
}
// u 当前节点
// f u的父节点
private void dfs1(int u, int f) {
fa[u] = f;
dep[u] = dep[f] + 1;
siz[u] = 1;
int maxSize = -1;
for (int v : tree[u]) { // 遍历u节点所有的直接孩子
dfs1(v, u);
siz[u] += siz[v];
if (siz[v] > maxSize) {
maxSize = siz[v];
son[u] = v;
}
}
}
// u当前节点
// t是u所在重链的头部
private void dfs2(int u, int t) {
dfn[u] = ++tim;
top[u] = t;
tnw[tim] = val[u];
if (son[u] != 0) { // 如果u有儿子 siz[u] > 1
dfs2(son[u], t);
for (int v : tree[u]) {
if (v != son[u]) {
dfs2(v, v);
}
}
}
}
// head为头的子树上所有节点值+value
// 因为节点经过平移所以head(原始节点) -> head(平移节点)
public void addSubtree(int head, int value) {
// 原始点编号 -> 平移编号
head++;
// 平移编号 -> dfs编号 dfn[head]
seg.add(dfn[head], dfn[head] + siz[head] - 1, value, 1, n, 1);
}
public int querySubtree(int head) {
head++;
return seg.query(dfn[head], dfn[head] + siz[head] - 1, 1, n, 1);
}
public void addChain(int a, int b, int v) {
a++;
b++;
while (top[a] != top[b]) {
if (dep[top[a]] > dep[top[b]]) {
seg.add(dfn[top[a]], dfn[a], v, 1, n, 1);
a = fa[top[a]];
} else {
seg.add(dfn[top[b]], dfn[b], v, 1, n, 1);
b = fa[top[b]];
}
}
if (dep[a] > dep[b]) {
seg.add(dfn[b], dfn[a], v, 1, n, 1);
} else {
seg.add(dfn[a], dfn[b], v, 1, n, 1);
}
}
public int queryChain(int a, int b) {
a++;
b++;
int ans = 0;
while (top[a] != top[b]) {
if (dep[top[a]] > dep[top[b]]) {
ans += seg.query(dfn[top[a]], dfn[a], 1, n, 1);
a = fa[top[a]];
} else {
ans += seg.query(dfn[top[b]], dfn[b], 1, n, 1);
b = fa[top[b]];
}
}
if (dep[a] > dep[b]) {
ans += seg.query(dfn[b], dfn[a], 1, n, 1);
} else {
ans += seg.query(dfn[a], dfn[b], 1, n, 1);
}
return ans;
}
}
public static class SegmentTree {
private int MAXN;
private int[] arr;
private int[] sum;
private int[] lazy;
public SegmentTree(int[] origin) {
MAXN = origin.length;
arr = origin;
sum = new int[MAXN << 2];
lazy = new int[MAXN << 2];
}
private void pushUp(int rt) {
sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
}
private void pushDown(int rt, int ln, int rn) {
if (lazy[rt] != 0) {
lazy[rt << 1] += lazy[rt];
sum[rt << 1] += lazy[rt] * ln;
lazy[rt << 1 | 1] += lazy[rt];
sum[rt << 1 | 1] += lazy[rt] * rn;
lazy[rt] = 0;
}
}
public void build(int l, int r, int rt) {
if (l == r) {
sum[rt] = arr[l];
return;
}
int mid = (l + r) >> 1;
build(l, mid, rt << 1);
build(mid + 1, r, rt << 1 | 1);
pushUp(rt);
}
public void add(int L, int R, int C, int l, int r, int rt) {
if (L <= l && r <= R) {
sum[rt] += C * (r - l + 1);
lazy[rt] += C;
return;
}
int mid = (l + r) >> 1;
pushDown(rt, mid - l + 1, r - mid);
if (L <= mid) {
add(L, R, C, l, mid, rt << 1);
}
if (R > mid) {
add(L, R, C, mid + 1, r, rt << 1 | 1);
}
pushUp(rt);
}
public int query(int L, int R, int l, int r, int rt) {
if (L <= l && r <= R) {
return sum[rt];
}
int mid = (l + r) >> 1;
pushDown(rt, mid - l + 1, r - mid);
int ans = 0;
if (L <= mid) {
ans += query(L, R, l, mid, rt << 1);
}
if (R > mid) {
ans += query(L, R, mid + 1, r, rt << 1 | 1);
}
return ans;
}
}
// 为了测试,这个结构是暴力但正确的方法
public static class Right {
private int n;
private int[][] tree;
private int[] fa;
private int[] val;
private HashMap<Integer, Integer> path;
public Right(int[] father, int[] value) {
n = father.length;
tree = new int[n][];
fa = new int[n];
val = new int[n];
for (int i = 0; i < n; i++) {
fa[i] = father[i];
val[i] = value[i];
}
int[] help = new int[n];
int h = 0;
for (int i = 0; i < n; i++) {
if (father[i] == i) {
h = i;
} else {
help[father[i]]++;
}
}
for (int i = 0; i < n; i++) {
tree[i] = new int[help[i]];
}
for (int i = 0; i < n; i++) {
if (i != h) {
tree[father[i]][--help[father[i]]] = i;
}
}
path = new HashMap<>();
}
public void addSubtree(int head, int value) {
val[head] += value;
for (int next : tree[head]) {
addSubtree(next, value);
}
}
public int querySubtree(int head) {
int ans = val[head];
for (int next : tree[head]) {
ans += querySubtree(next);
}
return ans;
}
public void addChain(int a, int b, int v) {
path.clear();
path.put(a, null);
while (a != fa[a]) {
path.put(fa[a], a);
a = fa[a];
}
while (!path.containsKey(b)) {
val[b] += v;
b = fa[b];
}
val[b] += v;
while (path.get(b) != null) {
b = path.get(b);
val[b] += v;
}
}
public int queryChain(int a, int b) {
path.clear();
path.put(a, null);
while (a != fa[a]) {
path.put(fa[a], a);
a = fa[a];
}
int ans = 0;
while (!path.containsKey(b)) {
ans += val[b];
b = fa[b];
}
ans += val[b];
while (path.get(b) != null) {
b = path.get(b);
ans += val[b];
}
return ans;
}
}
// 为了测试
// 随机生成N个节点树可能是多叉树并且一定不是森林
// 输入参数N要大于0
public static int[] generateFatherArray(int N) {
int[] order = new int[N];
for (int i = 0; i < N; i++) {
order[i] = i;
}
for (int i = N - 1; i >= 0; i--) {
swap(order, i, (int) (Math.random() * (i + 1)));
}
int[] ans = new int[N];
ans[order[0]] = order[0];
for (int i = 1; i < N; i++) {
ans[order[i]] = order[(int) (Math.random() * i)];
}
return ans;
}
// 为了测试
public static void swap(int[] arr, int i, int j) {
int tmp = arr[i];
arr[i] = arr[j];
arr[j] = tmp;
}
// 为了测试
public static int[] generateValueArray(int N, int V) {
int[] ans = new int[N];
for (int i = 0; i < N; i++) {
ans[i] = (int) (Math.random() * V) + 1;
}
return ans;
}
// 对数器
public static void main(String[] args) {
int N = 50000;
int V = 100000;
int[] father = generateFatherArray(N);
int[] values = generateValueArray(N, V);
TreeChain tc = new TreeChain(father, values);
Right right = new Right(father, values);
int testTime = 1000000;
System.out.println("测试开始");
for (int i = 0; i < testTime; i++) {
double decision = Math.random();
if (decision < 0.25) {
int head = (int) (Math.random() * N);
int value = (int) (Math.random() * V);
tc.addSubtree(head, value);
right.addSubtree(head, value);
} else if (decision < 0.5) {
int head = (int) (Math.random() * N);
if (tc.querySubtree(head) != right.querySubtree(head)) {
System.out.println("出错了!");
}
} else if (decision < 0.75) {
int a = (int) (Math.random() * N);
int b = (int) (Math.random() * N);
int value = (int) (Math.random() * V);
tc.addChain(a, b, value);
right.addChain(a, b, value);
} else {
int a = (int) (Math.random() * N);
int b = (int) (Math.random() * N);
if (tc.queryChain(a, b) != right.queryChain(a, b)) {
System.out.println("出错了!");
}
}
}
System.out.println("测试结束");
}
}