Skip to content

分块训练

CF2043G

简要题意:

给你一个长为 \(n\) 的序列,\(q\) 次询问,询问是单点修改或查询区间有多少对数相同。

思路:

按照块长为 \(B\) 分块,注意到我们只关心数的类型而不关心相对位置,所以可以认为每块是一个多重集,修改操作转化为插入和删除某个元素。 这也方便我们初始化:执行 \(n\) 次插入操作即可。

先思考如何解决如下询问:

\(l = (i-1)*B+1,r=j*B\)\(l\) 恰为某个块的左端点,\(r\) 恰为某个块的右端点。

不妨设这种情况下的答案为 \(ans[i][j]\)

对于向第 \(x\) 块中添加一个元素 \(val\),会对满足 \(i\leq x\)\(j\geq x\)\(ans[i][j]\) 产生如下影响:

\[ans[i][j] = ans[i][j] + cnt(i,j,val) = ans[i][j] + cnt(i,x,val)+cnt(x+1,j,val)\]

其中 \(cnt(i,j,val)\) 表示第 \(i\) 块到第 \(j\) 块值 \(val\) 的个数。

这里的第二个等号是个巧妙的转化,他将式中 \(l\)\(r\) 分离,让我们可以分开来处理 2 部分。

对于 \(cnt(i,x,val)\) 这一部分,考虑创建 \(lans[i][x]\),然后直接让 \(lans[i][x]+=cnt(i,x,val)\) ,表示最左边块为 \(i\) 则最右边块 \(\geq x\) 才会受到这部分影响。

对于 \(cnt(x+1,j,val)\) 这一部分,考虑创建 \(rans[j][x]\),然后直接让 \(rans[j][x]+=cnt(x+1,j,val)\) 。表示最右边块为 \(j\) 则最左边块 \(\leq x\) 才会受到这部分影响。

\(ans[i][j]=\sum_{k=l}^{k=r} lans[i][k]+rans[j][k]\)

然后是如何求解 \(cnt(i,j,val)\) ,只需要对每个 \(val\) 维护一个关于块的前缀和数组即可。

\(pre[val][i]=\sum_{j=1}^{i*B}[a[j]=val]\)\(cnt(i,j,val)=pre[val][j]-pre[val][i-1]\)

对于修改操作暴力更新 \(lans,rans,pre\) ,时间复杂度为 \(O(n/B)\)

对于查询暴力枚举 \(lans,rans\) ,时间复杂度为 \(O(n/B)\)

然后考虑一般查询的情形:其实就是添加了 \(O(B)\) 个零散的值,维护一个桶,暴力扫一遍零散的值可以求出零散值之间的贡献,同时通过 \(cnt(i,j,val)\) 也可以很容易求出零散值和整块的贡献。

时间复杂度为 \(O(n^2/B+q(n/B+B))\)。取 \(B=\sqrt n\) 。即为 \(O(n\sqrt n+q\sqrt n)\)

代码
#include<bits/stdc++.h>
#define ll long long

using namespace std;

const int N = 1e5 + 10, M = 510;

int pre[N][M], n, B, a[N], tot, q;
int ln[M], rn[M], pos[N], cnt[N];
ll l[M][M], r[M][M];

int ask(int l, int r, int x) {
    return pre[x][r] - pre[x][l - 1];
}

ll query(int lq, int rq)
{
    ll res = 0;
    for (int i = lq; i <= rq; i++)res += l[lq][i] + r[rq][i];
    return res;
}

void add(int i, int x)
{
    for (int j = i; j; j--)l[j][i] += ask(j, i, x);
    for (int j = i + 1; j <= tot; j++)r[j][i] += ask(i + 1, j, x);
    for (int j = i; j <= tot; j++)pre[x][j]++;
}

void del(int i, int x)
{
    for (int j = i; j <= tot; j++)pre[x][j]--;
    for (int j = i; j; j--)l[j][i] -= ask(j, i, x);
    for (int j = i + 1; j <= tot; j++)r[j][i] -= ask(i + 1, j, x);
}

void solve()
{
    cin >> n;
    B = sqrt(n), tot = (n - 1) / B + 1;

    for (int i = 1; i <= tot; i++)ln[i] = rn[i - 1] + 1, rn[i] = i * B;
    rn[tot] = n;

    for (int i = 1; i <= tot; i++)
        for (int j = ln[i]; j <= rn[i]; j++)pos[j] = i;
    for (int i = 1; i <= n; i++)cin >> a[i], add(pos[i], a[i]);

    cin >> q;

    ll last = 0;
    while (q--)
    {
        int op, lq, rq; cin >> op >> lq >> rq;
        lq = (lq + last) % n + 1;
        rq = (rq + last) % n + 1;
        if (op == 1) {
            del(pos[lq], a[lq]);
            a[lq] = rq;
            add(pos[lq], a[lq]);
            continue;
        }
        if (lq > rq)swap(lq, rq);
        //cout << lq << ' ' << rq << endl;
        ll res = 0;
        if (pos[lq] + 1 >= pos[rq])
        {
            for (int i = lq; i <= rq; i++)res += cnt[a[i]], cnt[a[i]]++;
            for (int i = lq; i <= rq; i++)cnt[a[i]]--;
        }
        else {
            //cout << "!\n";
            res = query(pos[lq] + 1, pos[rq] - 1);
            for (int i = lq; i <= rn[pos[lq]]; i++)res += cnt[a[i]], cnt[a[i]]++, res += ask(pos[lq] + 1, pos[rq] - 1, a[i]);//cout<<i<<"#\n";
            for (int i = ln[pos[rq]]; i <= rq; i++)res += cnt[a[i]], cnt[a[i]]++, res += ask(pos[lq] + 1, pos[rq] - 1, a[i]);//cout<<i<<"#\n";

            for (int i = lq; i <= rn[pos[lq]]; i++)cnt[a[i]]--;
            for (int i = ln[pos[rq]]; i <= rq; i++)cnt[a[i]]--;

        }
        last = 1ll * (rq - lq + 1) * (rq - lq) / 2 - res;
        cout << last << ' ';
    }
}

int main()
{
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    int t = 1;
    while (t--)solve();
    return 0;
}