题目大意:
给你 n,x_0,a,b,p,x_i=(a\cdot x_{i−1}+b)\bmod p。给你 q 组询问 v,问你各个 x_i(i\in[0,n-1]) 中最小的 i 使得 x_i=v 。(1\leq n≤10^{18},\ 0\leq x_p,a,b
推柿子得到: a^n=\frac{v(a-1)+b}{x_0(a-1)+b} 处理掉各种特殊情况后,显然可以 BSGS。由于 10^3 组询问和 10^9 的质数处理范围,我们希望将复杂度均摊到 O(10^6) 级别。因此设 m=1000,y=\frac{p}{m} ,预处理时枚举 i\in[0,y-1] 内的 a^i 并用哈希表记录,每次询问只需要枚举 k\in[1,m] 的 a^{ky} ,复杂度被均摊。分析:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
long long n; int x0, a, b, p;
const int m = 1000;
int T;
int ksm(int a, int b, int c) {
int ret = 1;
for(; b; b >>= 1, a = 1LL * a * a % c) {
if(b & 1) ret = 1LL * ret * a % c;
}
return ret;
}
int ni(int a, int p) {
return ksm(a, p-2, p);
}
#define HASHMOD 233333
#define HASHSIZE 1000000
static struct HashTable
{
int head[HASHMOD];
int key[HASHSIZE + 10], val[HASHSIZE + 10], nxt[HASHSIZE + 10], cnt;
void clear()
{
cnt = 0;
memset(head, 0, sizeof(head));
}
bool count(const int k)
{
for (int i = head[k % HASHMOD]; i; i = nxt[i])
if (key[i] == k) return true;
return false;
}
int operator[](const int k)
{
for (int i = head[k % HASHMOD]; i; i = nxt[i])
if (key[i] == k) return val[i];
return 0;
}
void add(int k, int v)
{
int p = k % HASHMOD;
nxt[++cnt] = head[p]; key[cnt] = k;
val[cnt] = v; head[p] = cnt;
}
}tb;
int ff, yy;
void init(int a, int p) {
int f = 1, y = ceil(1.0 * p / m);
ff = ksm(a, y, p); yy = y;
for(int i = 0; i < y; i++) {
tb.add(f, i);
f = 1LL * f * a % p;
}
}
int bsgs(int a, int b, int p) {
if(b == 1) return 0;
int f = ff, y = yy, tmp = f;
// 仿照上面的第二种思路,f变成a^y*ni(b)并赋值给tmp
f = 1LL * f * ni(b, p) % p;
for(int i = 1; i <= m; i++) {
if(tb[f]) return i * y - tb[f];
f = 1LL * f * tmp % p;
}
return -1;
}
int main()
{
scanf("%d", &T);
while(T--) {
tb.clear();
scanf("%lld%d%d%d%d", &n, &x0, &a, &b, &p);
int q; scanf("%d", &q);
if(a != 0 && a != 1) init(a, p);
while(q--)
{
int v; scanf("%d", &v);
int ans;
if(a == 0) {
if(v == x0) puts("0");
else if(v == b) {
if(n == 1) puts("-1");
else puts("1");
}
else puts("-1");
}
else if(a == 1) {
if (b == 0) {
if(x0 == v) ans = 0;
else ans = -1;
}
else {
ans = 1LL * (v - x0 + p) * ni(b, p) % p;
if(ans >= n) ans = -1;
}
printf("%d\n", ans);
}
else {
int w = (b + 1LL * a * x0 % p - x0 + p) % p;
v = (1LL * v * (a-1) % p + b) % p;
if(w == 0) {
if(v == 0) puts("0");
else puts("-1");
}
else {
v = 1LL * v * ni(w,p) % p;
if(!v) ans = -1;
else ans = bsgs(a, v, p);
// 这里已保证a>=2,b>=1,可以放心bsgs
if(ans > n) puts("-1");
else printf("%d\n", ans);
}
}
}
}
return 0;
}
评论