由题可知这是一棵树,因此求每一条边经过的次数可以通过树上差分解决。而现在只有部分(有向)边需要花费,因此就需要找到一种能够记录单向花费的信息。考虑到一条边连接的两个点因在树上而深度不同,所以可以分为叶子指向父节点与父节点指向叶子两种边。形象化地,第一种可以称为上行边,第二种称为下行边。
首先通过树上差分,将信息分别储存在上行边的起点与下行边的中点。
1 2 3 4 5 6
| for (int i = 1;i <= k;++i) { int lca = getLCA (a[i],a[i - 1]); ++up[a[i - 1]];++down[a[i]]; --up[lca];--down[lca]; }
|
然后通过一次 dfs 和差分数组还原每一条边经过的次数。最后统计答案,对于一条 $v \to u$ 需要花费的边,若是上行边,那么通过的次数记录在 $v$ 上,否则是 $u$ 上。设某条边经过了 $x$ 次,则总花费为 $1 + 2 + \cdots + 2^{x - 1} = 2^x - 1$,所以直接预处理 $2$ 的幂次即可。
总代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
| #include <iostream> #include <cstdio> #include <algorithm> #include <cmath> #include <cstring> #include <vector> #include <map> #define init(x) memset (x,0,sizeof (x)) #define ll long long #define ull unsigned long long #define INF 0x3f3f3f3f using namespace std; const int MAX = 1e5 + 5; const int MOD = 1e9 + 7; inline int read (); int n,k,a[MAX * 10],f[MAX][20],dep[MAX],up[MAX],down[MAX]; vector <int> ve[MAX]; map <pair <int,int>,int> co; ll ans,pw[MAX * 10]; int getLCA (int x,int y); void dfs (int u,int fa); int main () { n = read (); for (int i = 1;i < n;++i) { int x = read (),y = read (),ty = read (); ve[x].push_back (y);co[{x,y}] = 0; ve[y].push_back (x);co[{y,x}] = ty; } k = read (); for (int i = 1;i <= k;++i) a[i] = read (); dfs (1,0); a[0] = pw[0] = 1; for (int i = 1;i <= k;++i) pw[i] = pw[i - 1] * 2 % MOD; for (int i = 1;i <= k;++i) { int lca = getLCA (a[i],a[i - 1]); ++up[a[i - 1]];++down[a[i]]; --up[lca];--down[lca]; } dfs (1,0); for (int u = 1;u <= n;++u) { for (auto v : ve[u]) { if (!co[{v,u}]) continue; if (dep[u] < dep[v]) ans = (ans + pw[up[v]] - 1 + MOD) % MOD; else ans = (ans + pw[down[u]] - 1 + MOD) % MOD; } } printf ("%lld\n",ans); return 0; } inline int read () { int s = 0;int f = 1; char ch = getchar (); while ((ch < '0' || ch > '9') && ch != EOF) { if (ch == '-') f = -1; ch = getchar (); } while (ch >= '0' && ch <= '9') { s = s * 10 + ch - '0'; ch = getchar (); } return s * f; } void dfs (int u,int fa) { f[u][0] = fa;dep[u] = dep[fa] + 1; for (int i = 1;i <= 18;++i) f[u][i] = f[f[u][i - 1]][i - 1]; for (auto v : ve[u]) { if (v == fa) continue; dfs (v,u); up[u] += up[v];down[u] += down[v]; } } int getLCA (int x,int y) { if (dep[x] < dep[y]) swap (x,y); for (int i = 18;~i;--i) { if (dep[f[x][i]] >= dep[y]) x = f[x][i]; if (x == y) return x; } for (int i = 18;~i;--i) if (f[x][i] != f[y][i]) x = f[x][i],y = f[y][i]; return f[x][0]; }
|