逃离地球的博客

三维偏序学习笔记

2020-07-03 · 7 min read
学习笔记

本来早就想写了,但文化课要考试,所以就拖到现在了😵

题目链接

题目大意

nn 个元素,第 ii 个元素 pip_i 有三个属性 (ai,bi,ci)(a_i,b_i,c_i),设 f(i)f(i) 表示满足 ajaia_j \leq a_ibjbib_j \leq b_icjcic_j \leq c_ijij \ne ijj 的数量。求每一个 f(i)f(i).

题目分析

先考虑更简单的问题:

对于一维偏序,只需排序后询问在数列中的位置即可。

对于二维偏序,首先将元素以 aa 为关键字排序。依次遍历元素,建立数组 tt,当遍历到第 ii 个元素时,设 tkt_k 表示满足 ajaia_j\leq a_ibj=kb_j=kjj 的个数,那么显然有 f(i)=j=1bitjf(i)=\sum_{j=1}^{b_i}t_j。每次遍历完元素 pip_i,由于 aia_i 一定小于等于后面遍历的元素的 aa 值,所以只需把 tbit_{b_i} 加一即可。观察到 tt 数组涉及到单点修改和区间查询两个操作,可以用树状数组实现。(由于会出现 aa 值相等的两个元素,所以在对所有元素排序时,还要以 bb 为第二关键字排序)

对于三维偏序,也是先将元素以 aa 为关键字排序。然后将排好序的元素进行分治,设当前分治的区间为 [L,R][L,R],那么对于区间中两个满足偏序条件的元素 pip_ipjp_j (i<j)(i<j) 来说,有三种位置情况:

  1. Li<jmidL\leq i<j\leq mid

  2. midi<jRmid\leq i<j\leq R

  3. LimidjRL\leq i\leq mid\leq j\leq R

暂时只用考虑情况三,因为只要对于每个子区间都统计情况三,那么情况一和情况二自然会在 [L,mid][L,mid][mid+1,R][mid+1,R] 及其子区间中考虑到。

那么现在问题就变为了,对于 i[L,mid]i\in [L,mid],统计有几个 j[mid+1,R]j\in[mid+1,R],满足 bibjb_i\le b_jcicjc_i\le c_j. 和二维偏序问题比较类似。

把区间 [L,mid][L,mid] 和区间 [mid+1,R][mid+1,R] 分别以 bb 为关键字排序。建立两个指针 ptr1 和 ptr2,ptr1 最初指向 LL,ptr2 最初指向 mid+1mid+1. 每次先将 ptr1 右移至满足 bptr1bptr2b_{ptr1}\le b_{ptr2} 的最后一个位置,且将经过的所有元素像二维偏序的处理方法一样插入树状数组,再用二维偏序的处理方法,把 f(ptr2)f(ptr2) 加上该数值,并将 ptr2 右移一位。直到遍历完所有元素。

然后再分治处理 [L,mid][L,mid][mid+1,R][mid+1,R] 两个子区间即可。

时间复杂度为 O(nlog2n)\mathcal{O}(n\log^2n).

代码实现

1. 一些细节问题
  • 由于可能出现相同的的两个元素,可以把序列去重,然后给每一个元素一个权值。

  • 在清零树状数组时不要用 memset ,而是逐一单点修改清零,否则会导致时间复杂度错误。

2. 代码
#include <bits/stdc++.h>
using namespace std;

#define tiii tuple<int, int, int>
#define mt make_tuple

int read() {
    int ret = 0, f = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9') f = (ch == '-') ? -f : f, ch = getchar();
    while (ch >= '0' && ch <= '9') ret = ret * 10 + ch - '0', ch = getchar();
    return ret * f;
}

int n, k, tot = 0, ans[100005];

int t[200005];
int lowbit(int x) { return x & (-x); }
void modify(int x, int y) {
    for (int i = x; i <= k; i += lowbit(i)) t[i] += y;
}
int query(int x) {
    int sum = 0;
    for (int i = x; i > 0; i -= lowbit(i)) sum += t[i];
    return sum;
}

struct Node {
    int a, b, c, cnt, ans;
} p[100005];

map<tiii, int> q;

bool cmp1(Node a, Node b) {
    if (a.a == b.a) {
        if (a.b == b.b) return a.c < b.c;
        return a.b < b.b;
    }
    return a.a < b.a;
}

bool cmp2(Node a, Node b) {
    if (a.b == b.b) return a.c < b.c;
    return a.b < b.b;
}

void cdq(int l, int r) {
    if (l == r) return;
    int mid = (l + r) / 2;
    cdq(l, mid), cdq(mid + 1, r);
    sort(p + l, p + mid+1, cmp2), sort(p + mid + 1, p + r+1, cmp2);
    
    int ptr1 = l, ptr2 = mid + 1;
    while (ptr1 != mid + 1 || ptr2 != r + 1) {
        if ((p[ptr1].b <= p[ptr2].b && ptr1 != mid + 1) || ptr2 == r + 1)
            modify(p[ptr1].c, p[ptr1].cnt), ++ptr1;
        else
            p[ptr2].ans += query(p[ptr2].c), ++ptr2;
    }
    for (int i = l; i <= mid; ++i) modify(p[i].c, -p[i].cnt);
}

signed main() {
    n = read(), k = read();
    for (int i = 1; i <= n; ++i) {
        int a = read(), b = read(), c = read();
        tiii l = mt(a, b, c);
        if (!q[l])
            p[++tot] = (Node){a, b, c, 1, 0}, q[l] = tot;
        else
            p[q[l]].cnt++;
        // 去重
    }
    sort(p + 1, p + tot + 1, cmp1);

    cdq(1, tot);

    for (int i = 1; i <= n; ++i)
        ans[p[i].ans + p[i].cnt - 1]+=p[i].cnt;

    for (int i = 0; i < n; ++i) printf("%d\n", ans[i]);
    return 0;
}
本站总访问量