/*
 * 按道理说，AKS 素性测试算法提出之后，应该已经有很多语言的实现版本。但就其中的多项式方幂运算而言，C 语言版本总感觉缺少点什么：以前在其他地方见过的实现，多项式方幂部分的算法复杂度较高，因此自己也实现了一个版本。目前遇到的一个问题是：两个 64 位整数相乘会溢出。虽然运算是在模 n（n 也是 64 位）下进行的，但为了避免溢出，原本简单的乘法不得不改用加法实现。希望在这方面能有所改进（见 Mul64Mod）。
 */
/*
 * 备注：2022 年 11 月 1 日修订。根据文献 [3] 的提示，GNU G++17 可以使用 unsigned __int128 数据类型（这是 GCC 的扩展，并非标准 C/C++ 类型），不过我的编译器（g++.exe (MinGW-W64 x86_64-ucrt-posix-seh, built by Brecht Sanders) 12.2.0）提示应使用的类型名为 __uint128_t。至少目前模乘法的速度问题算是基本解决了。另外，我将当前版本在网上保存了一份：https://onlinegdb.com/dFtGoQgL6 或者 https://gist.github.com/tpu01yzx/172c65d3003bb3a09a941e69bc6c370b
 */
#include <inttypes.h>
#include <math.h>
#include <memory.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#define N 100           /* appears unused by the code in this file */
#define MAX_FACTORS 32  /* capacity of the factor/exponent arrays filled by SmallFactors */
#define MAXR 320        /* maximum r for code that keeps polynomials in fixed arrays of this length */

/* Short aliases for the exact-width types from <stdint.h>.  Defining uint64 as
 * uint64_t (rather than unsigned long long) keeps it the very type that the
 * <inttypes.h> macros such as SCNu64 describe, on every platform. */
typedef uint64_t uint64;
typedef uint32_t uint32;
typedef uint8_t uint8;
/* Greatest common divisor of a and b by the iterative Euclidean algorithm.
 * gcd(a, 0) == a, so gcd(0, 0) == 0. */
uint64 gcd(uint64 a, uint64 b) {
    while (b != 0) {
        const uint64 remainder = a % b;
        a = b;
        b = remainder;
    }
    return a;
}
/* Number of significant bits of n: floor(log2(n)) + 1 for n > 0, and 0 for n == 0. */
uint8 BitCount(uint64 n) {
    uint8 bits;
    for (bits = 0; n != 0; n >>= 1) {
        ++bits;
    }
    return bits;
}
/* Integer square root: returns floor(sqrt(x)) for every 64-bit x.
 *
 * Below 2^32 the double-precision sqrt is accurate enough that truncating it
 * yields the floor.  Larger values use an integer binary search whose midpoint
 * is squared in 64-bit arithmetic: squaring a 32-bit value in 32-bit
 * arithmetic would overflow, and bounds such as 2^32 do not fit a uint32. */
uint32 SquareRoot(uint64 x) {
    if (x < (1ULL << 32)) {
        return (uint32)sqrt((uint32)x);
    }
    /* Invariant: lo*lo <= x < (hi+1)*(hi+1).  x >= 2^32 gives root >= 2^16,
     * and every 64-bit x has root <= 2^32 - 1. */
    uint64 lo = 1ULL << 16;
    uint64 hi = 0xFFFFFFFFULL;
    while (lo < hi) {
        const uint64 mid = lo + (hi - lo + 1) / 2;  /* round up so lo always advances */
        if (mid * mid <= x) {                        /* mid < 2^32: no overflow */
            lo = mid;
        } else {
            hi = mid - 1;
        }
    }
    return (uint32)lo;
}
/* a raised to the k-th power, computed by binary exponentiation.  Arithmetic is
 * in uint64, so results that do not fit wrap modulo 2^64.  Power(a, 0) == 1. */
uint64 Power(uint32 a, uint8 k) {
    uint64 result = 1;
    uint64 base = a;
    while (k != 0) {
        if (k & 1u) {
            result *= base;
        }
        k >>= 1;
        if (k != 0) {
            base *= base;  /* only square when another bit remains */
        }
    }
    return result;
}
/* Modular exponentiation: returns a^k mod `mod` (mod must be non-zero).
 *
 * The exponent is a full uint32 so that exponents above 255 (for example
 * candidate multiplicative orders up to phi(r)) are not silently truncated, as
 * they would be with a uint8 parameter.  Intermediate products fit in uint64
 * because both factors are below mod <= 2^32 - 1. */
uint32 PowerMod(uint32 a, uint32 k, uint32 mod) {
    uint64 result = 1 % mod;  /* 1 % mod keeps k == 0 consistent when mod == 1 */
    uint64 base = a % mod;
    while (k != 0) {
        if (k & 1u) {
            result = (result * base) % mod;
        }
        k >>= 1;
        if (k != 0) {
            base = (base * base) % mod;
        }
    }
    return (uint32)result;
}
/* Returns (a * b) mod `mod` without overflow (mod must be non-zero).
 *
 * GCC and Clang provide a 128-bit unsigned integer (signalled by
 * __SIZEOF_INT128__), which turns this into one widening multiply.  Other
 * compilers fall back to double-and-add; every intermediate value there stays
 * below `mod`, so no 64-bit operation can overflow. */
uint64 Mul64Mod(uint64 a, uint64 b, uint64 mod) {
#if defined(__SIZEOF_INT128__)
    return (uint64)((__uint128_t)a * b % mod);
#else
    uint64 result = 0;
    a %= mod;
    b %= mod;
    while (b != 0) {
        if (b & 1u) {
            /* result = (result + a) mod `mod`, written so the sum cannot wrap */
            result = (result >= mod - a) ? result - (mod - a) : result + a;
        }
        a = (a >= mod - a) ? a - (mod - a) : a + a;  /* a = 2a mod `mod` */
        b >>= 1;
    }
    return result;
#endif
}
/* In-place product in Z_n[X] / (X^r - 1): p1 <- p1 * p2.
 *
 * p1 and p2 each hold r coefficients (index i multiplies X^i).  They may be
 * the same array (PowerPoly squares this way): the product is built in a
 * scratch buffer and copied back only at the end.  The buffer lives on the
 * heap because r is chosen at run time by FindR and can exceed any fixed
 * array size.  Exits the program if that allocation fails. */
void MulPoly(uint64 *p1, uint64 *p2, uint64 n, uint32 r) {
    if (r == 0) {
        return;
    }
    uint64 *prod = calloc(r, sizeof *prod);
    if (prod == NULL) {
        fputs("MulPoly: out of memory\n", stderr);
        exit(EXIT_FAILURE);
    }
    for (uint32 i = 0; i < r; i++) {
        uint64 acc = 0;
        for (uint32 j = 0; j < r; j++) {
            /* X^j * X^k lands on X^i when j + k == i (mod r) */
            const uint32 k = (j <= i) ? i - j : i + (r - j);
            const uint64 term = Mul64Mod(p1[j], p2[k], n);  /* always < n */
            /* acc = (acc + term) mod n, without wrapping when n > 2^63 */
            acc = (term >= n - acc) ? term - (n - acc) : acc + term;
        }
        prod[i] = acc;
    }
    memcpy(p1, prod, sizeof *prod * r);
    free(prod);
}
/* Raises the r-coefficient polynomial coff to the n-th power in
 * Z_n[X] / (X^r - 1), in place, by right-to-left binary exponentiation.
 *
 * n is both the exponent and the coefficient modulus, as the AKS congruence
 * (X + a)^n = X^n + a (mod X^r - 1, n) requires.  n == 0 yields the constant
 * polynomial 1.  The two scratch polynomials are heap-allocated because r is
 * chosen at run time and can exceed MAXR.  Exits the program if that
 * allocation fails. */
void PowerPoly(uint64 *coff, uint64 n, uint32 r) {
    if (r == 0) {
        return;  /* nothing to hold; also avoids writing coff[0] */
    }
    if (n == 0) {
        memset(coff, 0, sizeof(uint64) * r);
        coff[0] = 1ULL;
        return;
    }
    const uint64 modulus = n;
    /* One allocation for both the running result and the repeatedly squared base. */
    uint64 *result = calloc(2 * (size_t)r, sizeof *result);
    if (result == NULL) {
        fputs("PowerPoly: out of memory\n", stderr);
        exit(EXIT_FAILURE);
    }
    uint64 *base = result + r;
    result[0] = 1ULL;
    memcpy(base, coff, sizeof *base * r);
    uint64 e = n;
    while (e > 1) {
        if (e & 1u) {
            MulPoly(result, base, modulus, r);
        }
        MulPoly(base, base, modulus, r);
        e >>= 1;
    }
    MulPoly(result, base, modulus, r);
    memcpy(coff, result, sizeof *result * r);
    free(result);
}
int CheckPoly(uint64 *p, uint64 a, uint64 n, uint32 r) {
uint64 n0 = n % r;
uint32 i;
if(p[0] != a) return 0;
for(i = 1; i < n0; i++) {
if(p[i] != 0) return 0;
}
if(p[n0] != 1ULL) return 0;
for(i = n0 + 1; i < r; i++) {
if(p[i] != 0) return 0;
}
return 1;
}
/* If n is a perfect a-th power (a >= 2), returns its integer a-th root;
 * otherwise returns 0.  logn must be BitCount(n), so 2^(logn-1) <= n < 2^logn
 * and the root lies between 2^((logn-1)/a) and 2^(logn/a).  For a == 1 the
 * function returns n itself (converted to uint32); a == 0 returns 0.
 *
 * Bounds and midpoint are kept in 64-bit arithmetic and clamped to the uint32
 * range: for 64-bit n and a == 2 the upper bound 2^32 does not fit a uint32,
 * and averaging two values near 2^32 in 32 bits would wrap.  mid^a is built
 * with an explicit overflow check so an overflowing power compares as
 * "too big" instead of wrapping. */
uint32 PerfectRoot(uint8 a, uint64 n, uint8 logn) {
    if (a == 0) return 0;  /* no 0-th root; also avoids dividing by zero below */
    if (a == 1) return n;
    const double max_root = 4294967295.0;  /* UINT32_MAX */
    const double lo_d = pow(2.0, (double)(logn - 1) / a);
    const double hi_d = pow(2.0, (double)logn / a);
    uint64 lo = (lo_d >= max_root) ? 0xFFFFFFFFULL : (uint64)lo_d;
    uint64 hi = (hi_d >= max_root) ? 0xFFFFFFFFULL : (uint64)hi_d;
    while (lo <= hi) {
        const uint64 mid = lo + (hi - lo) / 2;
        uint64 mp = 1;
        int too_big = 0;
        for (uint8 e = 0; e < a; e++) {
            if (mid != 0 && mp > UINT64_MAX / mid) {
                too_big = 1;
                break;
            }
            mp *= mid;
        }
        if (!too_big && mp == n) return (uint32)mid;
        if (!too_big && mp < n) {
            lo = mid + 1;  /* mid <= 2^32 - 1, so this cannot wrap */
        } else {
            hi = mid - 1;  /* mid >= 1 here: 0^a == 0 is never greater than n */
        }
    }
    return 0;
}
/* Returns 1 if PerfectRoot finds an integer b-th root of n for some exponent
 * 2 <= b < BitCount(n), i.e. n is a perfect power; returns 0 otherwise. */
int IsPower(uint64 n) {
    const uint8 bits = BitCount(n);
    uint8 exponent = 2;
    while (exponent < bits) {
        if (PerfectRoot(exponent, n, bits) != 0) {
            return 1;
        }
        exponent++;
    }
    return 0;
}
/* Trial-division factorisation of r.
 * Stores the distinct prime factors of r in increasing order in factors[] and
 * their multiplicities in exponents[], and returns how many were stored.
 * r == 0 and r == 1 produce no factors.  Both arrays need room for every
 * distinct prime of r (callers use MAX_FACTORS entries). */
uint8 SmallFactors(uint32 r, uint32 *factors, uint32 *exponents) {
    uint8 count = 0;
    uint32 limit = SquareRoot(r);
    for (uint32 d = 2; d <= limit; ++d) {
        if (r % d != 0) {
            continue;
        }
        uint32 multiplicity = 0;
        do {
            r /= d;
            ++multiplicity;
        } while (r % d == 0);
        factors[count] = d;
        exponents[count] = multiplicity;
        ++count;
        limit = SquareRoot(r);  /* the cofactor shrank, so the search bound does too */
    }
    if (r > 1) {
        /* whatever survives trial division up to its square root is prime */
        factors[count] = r;
        exponents[count] = 1;
        ++count;
    }
    return count;
}
uint32 SmallOrder(uint32 n, uint32 r) {
uint32 i;
uint32 factors[MAX_FACTORS];
uint32 exponents[MAX_FACTORS];
uint32 p;
uint32 ans;
ans = 1;
p = SmallFactors(r, factors, exponents);
for(i = 0; i < p; i++) {
if(exponents[i] > 1) {
ans *= Power(factors[i], exponents[i] - 1);
}
ans *= factors[i] - 1;
}
p = SmallFactors(ans, factors, exponents);
for(i = 0; i < p; i++) {
while(ans % factors[i] == 0) {
if(PowerMod(n, ans, r) == 1) {
ans /= factors[i];
} else {
break;
}
}
if(ans % factors[i] == 0) {
ans *= factors[i];
}
}
return ans;
}
/* Chooses the AKS modulus r: the smallest r >= 2 that is coprime to n and for
 * which SmallOrder(n mod r, r) exceeds BitCount(n)^2.  The search runs up to
 * max(3, BitCount(n)^5); if no r qualifies, that bound plus one is returned. */
uint32 FindR(uint64 n) {
    const uint8 bits = BitCount(n);
    uint32 limit = Power(bits, 5);
    const uint32 order_bound = Power(bits, 2);
    if (limit < 3) {
        limit = 3;
    }
    uint32 r = 2;
    while (r <= limit) {
        if (gcd(n, r) == 1 && SmallOrder(n % r, r) > order_bound) {
            break;
        }
        ++r;
    }
    return r;
}
/* Euler's totient of r from its prime factorisation:
 * phi(r) = prod p^(e-1) * (p - 1) over the prime powers p^e of r.
 * r <= 1 has no prime factors, so the empty product 1 is returned. */
uint32 SmallPhi(uint32 r) {
    uint32 primes[MAX_FACTORS];
    uint32 powers[MAX_FACTORS];
    const uint32 count = SmallFactors(r, primes, powers);
    uint32 phi = 1;
    for (uint32 k = 0; k < count; ++k) {
        const uint32 prime = primes[k];
        if (powers[k] > 1) {
            phi *= Power(prime, powers[k] - 1);
        }
        phi *= prime - 1;
    }
    return phi;
}
/* AKS primality test: returns 1 when n passes every step below (prime) and 0
 * as soon as one step proves it composite.
 *
 *   1. n < 2 is not prime; perfect powers are composite.
 *   2. Choose r with FindR.
 *   3. Any i in [2, r] with 1 < gcd(i, n) < n proves n composite.
 *   4. If n <= r, n is prime.
 *   5. For a = 1 .. BitCount(n) * ceil(sqrt(phi(r))) (capped at n - 1), check
 *      (X + a)^n = X^n + a (mod X^r - 1, n); any failure proves n composite.
 *
 * The step-5 bound is never below sqrt(phi(r)) * log2(n): BitCount(n) >= log2(n)
 * and the square root is rounded up.  The polynomial buffer is heap-allocated
 * because r can exceed MAXR.  Exits the program if that allocation fails. */
int IsPrimeAKS(uint64 n) {
    if (n < 2) {
        return 0;  /* 0 and 1 are not prime */
    }
    if (IsPower(n)) {
        return 0;
    }
    const uint32 r = FindR(n);
    for (uint64 i = 2; i <= r; i++) {
        const uint64 t = gcd(n, i);
        if (t > 1 && t < n) {
            return 0;
        }
    }
    if (n <= r) {
        return 1;
    }
    const uint64 logn = BitCount(n);
    const uint32 phi = SmallPhi(r);
    uint64 root = SquareRoot(phi);
    if (root * root < phi) {
        root++;  /* ceil(sqrt(phi)) */
    }
    uint64 maxa = logn * root;
    if (maxa >= n) {
        maxa = n - 1;
    }
    uint64 *poly = calloc(r, sizeof *poly);
    if (poly == NULL) {
        fputs("IsPrimeAKS: out of memory\n", stderr);
        exit(EXIT_FAILURE);
    }
    int prime = 1;
    for (uint64 a = 1; a <= maxa && prime; a++) {
        memset(poly, 0, sizeof *poly * r);
        poly[0] = a;  /* X + a; r >= 2 because FindR starts at 2 */
        poly[1] = 1;
        PowerPoly(poly, n, r);
        if (!CheckPoly(poly, a, n, r)) {
            prime = 0;
        }
    }
    free(poly);
    return prime;
}
/* Reads unsigned 64-bit integers from standard input (whitespace-separated)
 * and prints "Yes" if IsPrimeAKS reports the number prime, "No" otherwise.
 * Stops at end of input or at the first token that is not a number; checking
 * for == 1 rather than != EOF prevents an endless loop on such a token. */
int main(void)
{
    uint64_t value;  /* SCNu64 is specified for uint64_t exactly */
    while (scanf("%" SCNu64, &value) == 1) {
        printf("%s\n", IsPrimeAKS(value) ? "Yes" : "No");
    }
    return 0;
}