杭电多校第六场 1006 A Very Easy Graph Problem 题解

题目大意

给你一个n个点的无向图,每个点分为0和1两类,每条边的长度为$\large2^i$,求所有不同类型的点对最短路之和。

解题思路

由于每条边的长度为$2^i$,所以前面所有的边长度加起来都没有这条边长。

所以在第i条边加进来的时候,之前联通的点必然是最短路,所以不用管已经联通的点了,剩下的可以用并查集来构成一棵树。

接下来就是对树进行dfs。

定义节点类型,有ty,dp数组和sum。

ty代表该节点的类型。

dp[0] [0]表示以该节点为根节点的子树中所有ty=0的节点到根节点的距离和。

dp[0] [1]表示以该节点为根节点的子树中所有ty=0的节点数量。

dp[1] [0]表示以该节点为根节点的子树中所有ty=1的节点到根节点的距离和。

dp[1] [1]表示以该节点为根节点的子树中所有ty=1的节点数量。

sum为以该节点为根节点的子树中所有不同类型的点对距离和。

dp数组的转移为

1
2
3
4
5
6
7
8
for (int i = head[x]; ~i; i = edge[i].next) {
int v = edge[i].to;
ll w = edge[i].val;
if (v == f) continue; dfs(v, x);
node[x].dp[0][0] = (node[x].dp[0][0] + (node[v].dp[0][0] + w * node[v].dp[0][1] % mod) % mod) % mod;
node[x].dp[1][0] = (node[x].dp[1][0] + (node[v].dp[1][0] + w * node[v].dp[1][1] % mod) % mod) % mod;
node[x].dp[0][1] += node[v].dp[0][1], node[x].dp[1][1] += node[v].dp[1][1];
}

距离的统计为,到同类型到子节点的距离加上其他所有的同类型的点乘以这条边的距离。

数量的统计则直接相加就可以了。

sum的转移为

1
2
3
4
5
6
7
8
for (int i = head[x]; ~i; i = edge[i].next) {
int v = edge[i].to;
ll w = edge[i].val;
if (v == f) continue;
node[x].sum = (node[x].sum + node[v].sum) % mod;
node[x].sum = (node[x].sum + (node[x].dp[0][1] - node[v].dp[0][1]) * (node[v].dp[1][0] + w * node[v].dp[1][1]) % mod) % mod;
node[x].sum = (node[x].sum + (node[x].dp[1][1] - node[v].dp[1][1]) * (node[v].dp[0][0] + w * node[v].dp[0][1]) % mod) % mod;
}

sum的转移比较麻烦。

首先加上子节点的距离。

之后为除去以该子节点为根节点的树中的ty为0的数量乘上以该子节点为根节点的ty为1的点到x的距离。

后面的则为ty=1的转移。

刚开始卡了我好久,认为ty=0乘上ty=1的距离还少了一段ty=0到根节点的距离。

后来仔细想了想,才发现,这里的ty=0到ty=1与后面的ty=1到ty=0两段合起来,就是完整的距离了。

好妙啊。。

其实这题还有其他的做法,像点分治,最小生成树等。

但我没学点分治,就用的这个了。

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include <bits/stdc++.h>
#include <iostream>
#include <cstdio>
#include <queue>
#include <cstring>
#include <cmath>
#include <stack>
#include <map>
#include <string>
#include <vector>
#include <algorithm>
#include <sstream>
#include <unordered_map>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;

const int N = 1e5 + 5, mod = 1e9+7;

int cnt, head[N], fa[N];
struct {
int to, next;
ll val;
}edge[N<<1];
struct {
int ty;
ll dp[2][2], sum;
}node[N];

void build(int u, int v, ll val) {
edge[++cnt].to = v, edge[cnt].val = val;
edge[cnt].next = head[u], head[u] = cnt;
}

ll qpow(ll a, ll b) {
ll ans = 1;
while (b) {
if (b & 1) ans = (ans * a % mod) % mod;
a = (a * a) % mod, b >>= 1;
}
return ans;
}

int find(int son) {
int root = son, t;
while (fa[root] != root) root = fa[root];
while (son != root) {
t = fa[son];
fa[son] = root;
son = t;
}
return root;
}

void dfs(int x, int f) {
node[x].dp[node[x].ty][1] = 1;
for (int i = head[x]; ~i; i = edge[i].next) {
int v = edge[i].to;
ll w = edge[i].val;
if (v == f) continue; dfs(v, x);
node[x].dp[0][0] = (node[x].dp[0][0] + (node[v].dp[0][0] + w * node[v].dp[0][1] % mod) % mod) % mod;
node[x].dp[1][0] = (node[x].dp[1][0] + (node[v].dp[1][0] + w * node[v].dp[1][1] % mod) % mod) % mod;
node[x].dp[0][1] += node[v].dp[0][1], node[x].dp[1][1] += node[v].dp[1][1];
}
for (int i = head[x]; ~i; i = edge[i].next) {
int v = edge[i].to;
ll w = edge[i].val;
if (v == f) continue;
node[x].sum = (node[x].sum + node[v].sum) % mod;
node[x].sum = (node[x].sum + (node[x].dp[0][1] - node[v].dp[0][1]) * (node[v].dp[1][0] + w * node[v].dp[1][1]) % mod) % mod;
node[x].sum = (node[x].sum + (node[x].dp[1][1] - node[v].dp[1][1]) * (node[v].dp[0][0] + w * node[v].dp[0][1]) % mod) % mod;
}
}

int main() {
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
int t; cin >> t;
while (t--) {
int n, m;
cin >> n >> m;
cnt = 0;
for (int i = 1; i <= n; i++) {
head[i] = -1; fa[i] = i; cin >> node[i].ty;
node[i].sum = 0; memset(node[i].dp, 0, sizeof(node[i].dp));
}
for (int i = 1; i <= m; i++) {
int u, v; cin >> u >> v;
if (u > v) swap(u, v);
int rt1 = find(u), rt2 = find(v);
if (rt1 == rt2) continue;
ll w = qpow(2, i);
build(u, v, w), build(v, u, w), fa[rt2] = rt1;
}
dfs(1, 0);
cout << node[1].sum << endl;
}
return 0;
}
  • 版权声明: 本博客所有文章除特别声明外,著作权归作者所有。转载请注明出处!
  • © 2015-2021 sakurakarma
  • Powered by Hexo Theme Ayer
  • PV: UV:

请我喝杯咖啡吧~

支付宝
微信