「HDU-6059」Kanade's trio (Trie树)

「HDU-6059」Kanade’s trio
给定一个序列,求序列中满足i<j<k且(A[i] xor A[j])<(A[j] xor A[k])的三元组的数量

题意

给定一个长度为$n(∑n≤5∗10^5)$的序列,求序列中满足$i<j<k$且$(A[i]\ xor\ A[j])<(A[j]\ xor\ A[k])$的三元组的数量,其中$0≤A[i]<2^{30}$.

解法

对于每个数进行二进制拆分,依次插入字典树处理。

首先考虑两个数$A[i],A[k]$,其二进制位前若干位相等,第$p$位不同。此时若要选取满足条件的数$A[j]$,则$A[j]$的前$p-1$位取任意值都不影响结果,而对于第$p$位,要使$(A[i]\ xor\ A[j])<(A[j]\ xor\ A[k])$,有如下两种情况:

  • $A[i]_p=1, A[k]_p=0​$,此时$A[j]_p​$取值为$1​$
  • $A[i]_p=0, A[k]_p=1$,此时$A[j]_p$取值为$0$

即$A[j]_p$的取值为$A[j]_p=A[i]_p$

具体到插入字典树的过程中,对于$A[k]$插入过程中的每一位,判断是否存在二进制位为$!A[k]_p$的字典树节点$pos=trie[now][!A[k]_p]$,如果有,在计数过程中考虑如下两种情况:

  1. $i,j$均为结点$pos$子树下的节点:

    此时方案数为$C\binom{2}{sz[pos]} = sz[pos]×(sz[pos]-1)/2$,其中$sz[pos]$表示结点$pos$的子树大小,该选取方案可以保证$i,j$有序;

  2. $i$为结点$pos$子树下的节点,$j$为不在结点$pos$下,且第$p$位二进制数与$i$相同的节点:

    利用数组$cnt[p][2]​$统计第$p​$个二进制位上为0/1的数的数量,当前方案数即为$(cnt[p][!A[k]_p]-sz[pos])×sz[pos]​$,但注意这种计数方案没有保证$i>j​$的方案数;

    我们考虑一个数$A[i]$,在当前位$p$插入字典树时,有$cnt[p][A[i]_p]-sz[pos]$个数在选取时满足$j<i$的情况,对该值求和即为在统计该节点$pos$时需要减去的方案数。

    故第二种情况的方案数为$(cnt[p][!A[k]_p]-sz[pos])×sz[pos]-sum[pos]$,其中$sum$为$pos$位插入过程中$cnt[p][A[i]_p]-sz[pos]$的和;

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include <bits/stdc++.h>

using namespace std;
typedef long long ll;

const int maxn = 31 *(5e5 + 10);

int trie[maxn][2], tot;
int sz[maxn], ext[maxn];
int cnt[31][2];

ll ans;

void init()
{
memset(trie, 0, sizeof trie);
memset(sz, 0, sizeof sz);
memset(ext, 0, sizeof ext);
memset(cnt, 0, sizeof cnt);
ans = tot = 0;
}

void insert_ch(int x)
{
int root = 0;
for(int i = 29; i >= 0; i--)
{
int id = (x >> i) & 1;
if(!trie[root][id]) trie[root][id] = ++tot;
cnt[i][id]++;
if(trie[root][id ^ 1])
{
int now = trie[root][id ^ 1];
ans += 1ll * sz[now] * (sz[now] - 1) / 2;
ans += 1ll * (cnt[i][1 ^ id] - sz[now]) * sz[now] - ext[now];
}
root = trie[root][id];
sz[root]++;
ext[root] += cnt[i][id] - sz[root];
}
}

int main()
{
int t, n, x;
scanf("%d", &t);
while(t--)
{
init();
scanf("%d", &n);
for(int i = 0; i < n; i++)
{
scanf("%d", &x);
insert_ch(x);
}
printf("%lld\n", ans);
}
return 0;
}