1 条题解

  • 1
    @ 2023-2-18 16:44:37
    #include <bits/stdc++.h>
    using namespace std;
    // Capacity for node-indexed arrays; node ids are 1..n, the extra slots are slack.
    // NOTE(review): assumes n <= 30000 — confirm against the problem constraints.
    const int MAXN = 30000 + 5;
    // Magnitude of the "minus infinity" sentinel used by max queries (-INF).
    // NOTE(review): presumably chosen so that -INF is below every node weight — confirm.
    const int INF = 30001;

    int n;               // number of nodes in the tree
    vector<int> e[MAXN]; // undirected adjacency lists, indexed by node id
    int w[MAXN];         // initial weight of each node

    // Per-node results of dfs_build:
    int fa[MAXN];   // parent (0 for the root)
    int dep[MAXN];  // depth (root has depth 1)
    int siz[MAXN];  // subtree size
    int hson[MAXN]; // heavy child = child with the largest subtree, 0 if none
    // First heavy-light pass: DFS from u, whose parent is `fat`.
    // Fills fa[], dep[] and siz[] for every descendant of u and records hson[u],
    // the first child with the strictly largest subtree (0 when u has no children).
    // The caller must set dep[u] beforehand; child depths are derived from it.
    // Relies on siz[0] == 0, so that any real child beats the "no heavy child" marker.
    void dfs_build(int u, int fat)
    {
        siz[u] = 1;
        hson[u] = 0;
        for (int child : e[u])
        {
            if (child == fat)
                continue; // do not walk back up the edge we came from
            fa[child] = u;
            dep[child] = dep[u] + 1;
            dfs_build(child, u);
            siz[u] += siz[child];
            if (siz[hson[u]] < siz[child])
                hson[u] = child;
        }
    }
    // Results of dfs_div:
    int tot;        // number of positions assigned so far (equals n after dfs_div)
    int top[MAXN];  // head node of the heavy chain containing each node
    int dfn[MAXN];  // position of each node in the heavy-child-first DFS order
    int rnk[MAXN];  // node stored at each DFS position (inverse of dfn)
    // Second heavy-light pass: numbers nodes in DFS order, always visiting the heavy
    // child first. The heavy child therefore gets dfn[u] + 1, and each heavy chain
    // occupies a contiguous dfn range.
    // Requires hson[] from dfs_build and top[u] already set by the caller.
    // `parent` is u's parent in the DFS (0 for the root).
    void dfs_div(int u, int parent)
    {
        ++tot;
        dfn[u] = tot;
        rnk[tot] = u;

        const int heavy = hson[u];
        if (heavy == 0)
            return; // u has no children

        top[heavy] = top[u]; // the heavy child continues u's chain
        dfs_div(heavy, u);

        for (int child : e[u])
        {
            if (child != parent && child != heavy)
            {
                top[child] = child; // each light child starts a new chain
                dfs_div(child, u);
            }
        }
    }
    // Segment tree over DFS positions. Node o covers [l, r], and its children 2*o and
    // 2*o+1 cover the two halves. Each node stores the sum and the maximum of
    // w[rnk[i]] over its range. The global instance `st` is used with root o = 1
    // and range [1, tot].
    struct SegTree
    {
        int sum[MAXN * 4];
        int maxx[MAXN * 4];

        // Recompute node o's aggregates from its two children.
        void pull(int o)
        {
            const int lc = o * 2;
            const int rc = o * 2 + 1;
            sum[o] = sum[lc] + sum[rc];
            maxx[o] = max(maxx[lc], maxx[rc]);
        }

        // Initialise the subtree of node o (range [l, r]) from the node weights in DFS order.
        void build(int o, int l, int r)
        {
            if (l != r)
            {
                const int mid = (l + r) >> 1;
                build(o * 2, l, mid);
                build(o * 2 + 1, mid + 1, r);
                pull(o);
                return;
            }
            maxx[o] = w[rnk[l]];
            sum[o] = maxx[o];
        }

        // Maximum over positions in both [l, r] and [ql, qr]; -INF when nothing overlaps.
        int query_max(int o, int l, int r, int ql, int qr)
        {
            const bool disjoint = qr < l || r < ql;
            if (disjoint)
                return -INF;
            const bool covered = ql <= l && r <= qr;
            if (covered)
                return maxx[o];
            const int mid = (l + r) >> 1;
            const int lhs = query_max(o * 2, l, mid, ql, qr);
            const int rhs = query_max(o * 2 + 1, mid + 1, r, ql, qr);
            return max(lhs, rhs);
        }

        // Sum over positions in both [l, r] and [ql, qr]; 0 when nothing overlaps.
        int query_sum(int o, int l, int r, int ql, int qr)
        {
            const bool disjoint = qr < l || r < ql;
            if (disjoint)
                return 0;
            const bool covered = ql <= l && r <= qr;
            if (covered)
                return sum[o];
            const int mid = (l + r) >> 1;
            const int lhs = query_sum(o * 2, l, mid, ql, qr);
            const int rhs = query_sum(o * 2 + 1, mid + 1, r, ql, qr);
            return lhs + rhs;
        }

        // Point assignment: set position x to value t, then refresh the aggregates
        // on the way back up.
        void update(int o, int l, int r, int x, int t)
        {
            if (l != r)
            {
                const int mid = (l + r) >> 1;
                if (x > mid)
                    update(o * 2 + 1, mid + 1, r, x, t);
                else
                    update(o * 2, l, mid, x, t);
                pull(o);
                return;
            }
            sum[o] = t;
            maxx[o] = t;
        }
    } st;
    // 求u~v路径上权值的最大值
    int qmax(int u, int v)
    {
        int res = -30001;
        while (top[u] != top[v])
        {
            if (dep[top[u]] < dep[top[v]])
            {
                // v~top[v] 计入答案中
                res = max(res, st.query_max(1, 1, tot, dfn[top[v]], dfn[v]));
                v = fa[top[v]];
            }
            else
            {
                // u~top[u] 计入答案中
                res = max(res, st.query_max(1, 1, tot, dfn[top[u]], dfn[u]));
                u = fa[top[u]];
            }
        }
        if (dep[u] > dep[v])
            swap(u, v);
        res = max(res, st.query_max(1, 1, tot, dfn[u], dfn[v]));
        return res;
    }
    // 求u~v路径上权值的和
    int qsum(int u, int v)
    {
        int res = 0;
        while (top[u] != top[v])
        {
            if (dep[top[u]] < dep[top[v]])
            {
                // v~top[v] 计入答案中
                res = res + st.query_sum(1, 1, tot, dfn[top[v]], dfn[v]);
                v = fa[top[v]];
            }
            else
            {
                // u~top[u] 计入答案中
                res = res + st.query_sum(1, 1, tot, dfn[top[u]], dfn[u]);
                u = fa[top[u]];
            }
        }
        if (dep[u] > dep[v])
            swap(u, v);
        res = res + st.query_sum(1, 1, tot, dfn[u], dfn[v]);
        return res;
    }
    // Reads n, then the n - 1 edges and the n node weights, and builds the heavy-light
    // decomposition rooted at node 1 plus the segment tree. It then answers q commands:
    //   CHANGE x y  -> set the weight of node x to y
    //   QMAX x y    -> print the maximum weight on the path x..y
    //   QSUM x y    -> print the sum of weights on the path x..y
    // Unrecognised commands are read and silently ignored.
    int main()
    {
        ios::sync_with_stdio(false);
        cin.tie(nullptr);

        cin >> n;
        for (int k = 1; k < n; ++k)
        {
            int a, b;
            cin >> a >> b;
            e[a].push_back(b);
            e[b].push_back(a);
        }
        for (int k = 1; k <= n; ++k)
            cin >> w[k];

        // Root the tree at node 1 and run both decomposition passes.
        fa[1] = 0;
        dep[1] = 1;
        dfs_build(1, 0);

        tot = 0;
        top[1] = 1; // the root heads its own chain
        dfs_div(1, 0);

        st.build(1, 1, tot);

        int q;
        cin >> q;
        while (q--)
        {
            string op;
            int x, y;
            cin >> op >> x >> y;
            if (op == "CHANGE")
                st.update(1, 1, tot, dfn[x], y); // nodes are addressed by their DFS position
            else if (op == "QMAX")
                cout << qmax(x, y) << "\n";
            else if (op == "QSUM")
                cout << qsum(x, y) << "\n";
        }

        return 0;
    }
    
    • 1

    信息

    ID
    790
    时间
    1000ms
    内存
    512MiB
    难度
    6
    标签
    递交数
    52
    已通过
    16
    上传者