let menu = ["Home", "Algorithms", "CodeHub", "VNOI Statistics"];

Overview

VMSALARY - Cây tiền lương

solution.md

Mấu chốt của bài này là với mỗi đỉnh u, ta cần tìm xem trong cây con có gốc u có bao nhiêu đỉnh có giá trị nhỏ hơn u.

Để làm được điều trên, ta làm phẳng cây kết hợp với BIT. Cụ thể như sau:

Làm phẳng cây

Duyệt DFS trên cây, khi đi vào đỉnh, ta thêm đỉnh đó vào cuối mảng và sau khi đã DFS xong tất cả các con của nó, ta cũng thêm đỉnh đó vào cuối mảng. Như vậy ta thấy, trên mảng được tạo thành tất cả các con của đỉnh \(u\) đều nằm giữa 2 vị trí mà \(u\) xuất hiện.

Ví dụ, cây trong đề bài là:

  1
 / \
2   3
   / \
  4   5

Mảng khi làm phẳng cây là:

1 2 2 3 4 4 5 5 3 1

Khi cài đặt ,ta lưu chỉ 2 số mà đỉnh xuất hiện trên mảng:

int count = 0;
int L[100000], R[100000];
void dfs(int u) {
    L[u] = ++count;
    for (int v: con[u]) dfs[v];
    R[u] = ++count;
}

L lưu vị trí đầu tiên đỉnh xuất hiện và R lưu vị trí cuối cùng.

Sử dụng BIT để truy vấn

Sau khi đã làm phẳng cây, như đã nói ở trên thì với mỗi đỉnh \(u\), ta có một đoạn liên tiếp từ \(L[u]\) đến \(R[u]\) chỉ gồm các đỉnh nằm trong cây con gốc \(u\). Đến đây bài toán trở nên đơn giản.

Ban đầu có một mảng \(A\) gồm \(2N\) phần tử \(0\), ta duyệt các đỉnh theo tiền lương tăng dần, ở mỗi bước ta tăng \(A[L[v]]\) với \(v\) là đỉnh có tiền lương nhỏ hơn đỉnh đang xét. Sau đó truy vấn tính tổng từ \(A[L[u]]\) đến \(A[R[u]]\) để tìm số đỉnh trong cây con gốc \(u\) có lương nhỏ hơn \(u\), tăng kết quả lên một lượng \(c\times(c-1) / 2\).

main.cpp
Open in Github Download
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;

typedef vector<vector<int>> dsk;

struct bit {
    int n;
    vector<int> f;
    bit(int n): n(n), f(n+1,0) {}
    void up(int i) {
        for (; i<=n; i += i&-i) f[i]++;
    }
    int get(int i) {
        int r = 0;
        for (; i>0; i -= i&-i) r += f[i];
        return r;
    }
    int get(int l, int r) { return get(r) - get(l-1); }
};

int top = 0;
int L[100000], R[100000];

void dfs(int u, const dsk &con) {
    L[u] = ++top;
    for (int v: con[u]) dfs(v, con);
    R[u] = ++top;
}

int main() {
    ios::sync_with_stdio(false); cin.tie(0);

    int n; cin >> n;
    dsk con(n);
    vector<pair<int, int>> C(n);
    cin >> C[0].first;
    for (int i=1; i<n; i++) {
        int cha;
        cin >> cha >> C[i].first;
        con[cha-1].push_back(i);
    }
    dfs(0, con);

    for (int i=0; i<n; i++) C[i].second = i;
    sort(C.begin(), C.end());

    bit bit(2*n);
    long long res = 0;
    for (int i=1, k=0; i<n; i++) {
        for (; C[k].first < C[i].first; k++) bit.up(L[C[k].second]);
        int u = C[i].second;
        long long count = bit.get(L[u], R[u]);
        res += count * (count-1) / 2;
    }

    cout << res;
    return 0;
}
Comments