树上三元组数量

2022-05-06

题目来源:第十九届重庆大学程序设计竞赛 I 题。

题意

给一棵 个节点的树,定义节点之间的距离为它们之间的最短路径上边的数量。 要求计算满足两两之间距离相等的节点三元组的数量(不考虑顺序)。

思路

考虑三个节点的最近公共祖先 。 显然,三个节点不可能同时位于 的同一棵子树上——这与最近公共祖先的定义矛盾。 换而言之,至少有一个节点独占了 的一棵子树,假设它与 的距离为

假设 的第 棵子树为 ; 它对答案有贡献的有如下两种情况:

  1. 贡献了三个节点中的两个,而 的前 棵子树贡献了另外一个节点;

  2. 贡献了三个节点中的一个,而 的前 棵子树贡献了剩下两个节点。

因此,最终答案 为:

上式中, 表示节点 的前 棵子树提供两个节点,并且第三个节点距离 的方案数量; 表示节点 的前 棵子树中距离 的节点数量。

显然, 的计算非常容易:

时,有初始值

接下来考虑如何计算 。 除了继承自 的部分之外, 的贡献共有两部分:

  1. 的前 棵子树分别提供了一个节点,此时 就是三个节点的中心,三个节点与 的距离均为 ,彼此之间的距离为

  2. 独自提供了两个节点,此时第三个节点距离 的距离为

综上所述:

优化

上述算法的所有转移都是 的,因此时间复杂度和空间复杂度均等于状态数量。 每个节点都一一对应一组 ,而每组 对总复杂度的贡献为对应 的最大值,等于 的高度。 因此总复杂度等于所有节点的高度之和,当整棵树退化成一条链时,取到上界

但是,可以考虑如下特殊情况:

注意到这次转移实际上相当于一次整体“平移”,实际上可以使用一些指针技巧在 而不是子树的高度的时间内一次性地计算它们全部。 进一步,如果把高度最大的子树安排在 的位置,这种优化的收益就可以最大化。

考虑这种情况下,每个节点对总复杂度的贡献:

  1. 该节点其父节点的所有子树中高度最大的,此时它对总复杂度的贡献为

  2. 该节点不是父节点的所有子树中高度最大的,此时它对总复杂度的贡献为子树的高度,但这同时意味着至少有相同数量的节点属于前一种情况,因此每个节点的均摊复杂度仍然是 的。

综上所述,算法的时间和空间复杂度为

实现

注意 的第三个维度()可以就地更新。

#include <bits/stdc++.h>

const int N = 1e5 + 8;

static int n, dep[N], son[N];
static long long answer;
static long long allocated[100 * N], *allo_head = allocated;
static long long *dp_f[N], *dp_g[N];
static std::vector<int> edges[N];

int solve_depths(int u, int f)
{
    dep[u] = dep[f] + 1;
    son[u] = u;
    for (int v : edges[u])
        if (v != f && solve_depths(v, u) > dep[son[u]])
            son[u] = son[v];
    for (int v : edges[u])
        if (u == 1 || son[v] != son[u])
        {
            int size = dep[son[v]] - dep[u] + 8;
            dp_f[son[v]] = allo_head;
            allo_head += size * 3;
            dp_g[son[v]] = allo_head;
            allo_head += size * 1;
        }
    return dep[son[u]];
}

void solve_answer(int u, int f)
{
    for (int v : edges[u])
        if (v != f)
        {
            solve_answer(v, u);
            if (son[v] == son[u])
            {
                dp_f[u] = dp_f[v] + 1;
                dp_g[u] = dp_g[v] - 1;
            }
        }

    auto &fu = dp_f[u];
    auto &gu = dp_g[u];

    answer += fu[0];
    gu[0] = 1;
    for (int v : edges[u])
    {
        if (v != f && son[v] != son[u])
        {
            int size = dep[son[v]] - dep[u];
            auto &fv = dp_f[v];
            auto &gv = dp_g[v];

            for (int i = 0; i < size; i++)
                answer += fv[i + 1] * gu[i] + fu[i + 1] * gv[i];
            for (int i = 0; i < size; i++)
            {
                fu[i] += fv[i + 1];
                fu[i + 1] += gv[i] * gu[i + 1];
                gu[i + 1] += gv[i];
            }
        }
    }
}

int main()
{
    scanf("%d", &n);
    for (int i = 0, u, v; i < n - 1; i++)
    {
        scanf("%d%d", &u, &v);
        edges[u].push_back(v);
        edges[v].push_back(u);
    }

    solve_depths(1, 0);
    solve_answer(1, 0);
    printf("%lld\n", answer);

    return 0;
}