AKS素性检测算法的C语言实现

按理说,这个算法提出之后,应该已经有很多种语言的实现版本。但就多项式的方幂运算而言,现有的C语言版本总让人觉得缺少点什么:以前在其他地方见过的实现,其多项式方幂的算法复杂度较高,因此自己也实现了一个版本。目前遇到的一个问题是,两个64bit的整数相乘会溢出。虽然运算是在模n(n也是64bit的)下进行的,但为了避免溢出,好好的乘法只能改用加法来实现。希望这部分( Mul64Mod )今后可以有所改进。

备注:2022年11月1日修订。根据文献[3]的提示,GNU G++17 可以使用 unsigned __int128 数据类型;不过我的编译器( g++.exe (MinGW-W64 x86_64-ucrt-posix-seh, built by Brecht Sanders) 12.2.0 )提示应当使用的类型名为 __uint128_t 。无论如何,模乘法的速度问题目前算是基本解决了。另外,我将当前版本保存了一份在网上:https://onlinegdb.com/dFtGoQgL6 或者 https://gist.github.com/tpu01yzx/172c65d3003bb3a09a941e69bc6c370b

#include <inttypes.h>
#include <math.h>
#include <memory.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* N is not referenced by any of the code shown in this file. */
#define N 100
/* Capacity of the factors[] / exponents[] arrays declared in SmallOrder and SmallPhi. */
#define MAX_FACTORS 32
/* Element count of the fixed-size coefficient arrays declared in MulPoly, PowerPoly and IsPrimeAKS. */
#define MAXR 320

/* Short aliases for the unsigned integer types used by every function below. */
typedef unsigned long long int uint64;
typedef unsigned int uint32;
typedef unsigned char uint8;

/*
 * Greatest common divisor of a and b, computed with the Euclidean
 * remainder iteration.  When b is 0 the value of a is returned as is,
 * so gcd(0, 0) yields 0.
 */
uint64 gcd(uint64 a, uint64 b) {
    while (b != 0) {
        const uint64 rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}

/*
 * Bit length of n: the number of right shifts needed to reduce n to zero.
 * Returns 0 for n == 0 and 64 when the top bit of n is set.
 */
uint8 BitCount(uint64 n) {
    uint8 bits;

    for (bits = 0; n != 0; n >>= 1) {
        ++bits;
    }
    return bits;
}

/*
 * Integer square root: returns floor(sqrt(x)) for any 64-bit x.
 *
 * The search is done entirely in 64-bit integer arithmetic.  Every candidate
 * is at most 2^32 - 1, so its square (at most 2^64 - 2^33 + 1) always fits in
 * a uint64 and the comparison against x is exact; no floating point is used.
 */
uint32 SquareRoot(uint64 x) {
    uint64 lo = 0;
    uint64 hi = 0xFFFFFFFFULL; /* floor(sqrt(2^64 - 1)) */

    /* Invariant: lo * lo <= x, and the answer is never greater than hi. */
    while (lo < hi) {
        /* Round the midpoint up so that "lo = mid" always makes progress. */
        const uint64 mid = lo + (hi - lo + 1) / 2;

        if (mid * mid <= x) {
            lo = mid;
        } else {
            hi = mid - 1;
        }
    }
    return (uint32)lo;
}

/*
 * Returns a raised to the k-th power using square-and-multiply.
 * The arithmetic is plain uint64 multiplication, so a result that does not
 * fit in 64 bits wraps around modulo 2^64.  Power(a, 0) is 1.
 */
uint64 Power(uint32 a, uint8 k) {
    uint64 result = 1;
    uint64 square = a;

    if (k == 0) {
        return result;
    }

    /* Consume every bit of k except the most significant one. */
    for (; k > 1; k >>= 1) {
        if ((k & 0x01) != 0) {
            result *= square;
        }
        square *= square;
    }

    /* The remaining top bit is always set, so multiply once more. */
    return result * square;
}

/*
 * Modular exponentiation: returns a^k mod `mod` by square-and-multiply.
 *
 * a    base (reduced modulo `mod` before use)
 * k    exponent; a full 32-bit value so that exponents of 256 and above are
 *      not truncated at the call site
 * mod  modulus, must be non-zero
 *
 * Intermediate products are held in 64 bits; both factors are below 2^32,
 * so they cannot overflow.
 */
uint32 PowerMod(uint32 a, uint32 k, uint32 mod) {
    uint64 ans = 1;
    uint64 a2;

    if (k == 0) {
        /* a^0 is 1, which reduces to 0 when the modulus is 1. */
        return (mod == 1) ? 0 : 1;
    }

    a2 = a % mod;
    /* Process every bit of k except the most significant one. */
    while (k > 1) {
        if (k & 0x01) {
            ans = (ans * a2) % mod;
        }
        a2 = (a2 * a2) % mod;
        k >>= 1;
    }
    /* The top bit of k is always set. */
    ans = (ans * a2) % mod;
    return (uint32)ans;
}

/*
 * Returns (a * b) mod `mod`.  The product is formed in the 128-bit
 * __uint128_t type so that it cannot overflow before the reduction.
 * `mod` must be non-zero.
 */
uint64 Mul64Mod(uint64 a, uint64 b, uint64 mod) {
    __uint128_t wide = a;

    wide *= b;
    wide %= mod;
    return (uint64)wide;
}

/*
 * Multiplies two polynomials modulo (x^r - 1, n) and stores the product in p1.
 *
 * p1, p2  coefficient arrays of length r (index = degree); every coefficient
 *         is expected to be below n.  p1 and p2 may be the same array, because
 *         the product is built in a scratch buffer and copied back at the end.
 * n       coefficient modulus, must be non-zero
 * r       number of coefficients
 *
 * The scratch buffer lives on the stack while r <= MAXR and on the heap
 * otherwise, so larger r no longer writes past a fixed-size array.  If the
 * heap allocation fails the program prints a message and exits.
 */
void MulPoly(uint64 *p1, uint64 *p2, uint64 n, uint32 r) {
    uint64 fixed[MAXR];
    uint64 *ans = fixed;
    uint32 i, j;

    if (r > MAXR) {
        ans = malloc(sizeof *ans * r);
        if (ans == NULL) {
            fprintf(stderr, "MulPoly: out of memory (r = %u)\n", r);
            exit(EXIT_FAILURE);
        }
    }

    for (i = 0; i < r; i++) {
        uint64 acc = 0;

        for (j = 0; j < r; j++) {
            /* Index of the p2 coefficient of degree (i - j) mod r. */
            const uint32 k = (j <= i) ? (i - j) : (r - (j - i));
            const uint64 term = Mul64Mod(p1[j], p2[k], n);

            /*
             * acc = (acc + term) mod n, written so that the sum is never
             * formed: acc + term could exceed 64 bits once n >= 2^63.
             */
            if (acc >= n - term) {
                acc -= n - term;
            } else {
                acc += term;
            }
        }
        ans[i] = acc;
    }

    memcpy(p1, ans, sizeof *ans * r);
    if (ans != fixed) {
        free(ans);
    }
}

/*
 * Raises the polynomial in coff to the n-th power modulo (x^r - 1, n) and
 * writes the result back into coff.
 *
 * coff  coefficient array of length r, updated in place
 * n     used both as the exponent and as the coefficient modulus; when n is 0
 *       the result is the constant polynomial 1
 * r     number of coefficients
 *
 * Two working polynomials are needed.  They live on the stack while
 * r <= MAXR and in one heap allocation otherwise, so larger r no longer
 * writes past the fixed-size arrays.  If the allocation fails the program
 * prints a message and exits.
 */
void PowerPoly(uint64 *coff, uint64 n, uint32 r) {
    uint64 fixed_acc[MAXR];
    uint64 fixed_base[MAXR];
    uint64 *heap = NULL;
    uint64 *acc = fixed_acc;   /* running product */
    uint64 *base = fixed_base; /* coff^(2^i) */
    const uint64 modulus = n;
    const size_t bytes = sizeof(uint64) * r;

    if (n == 0) {
        memset(coff, 0, bytes);
        coff[0] = 1ULL;
        return;
    }

    if (r > MAXR) {
        heap = malloc(2 * bytes);
        if (heap == NULL) {
            fprintf(stderr, "PowerPoly: out of memory (r = %u)\n", r);
            exit(EXIT_FAILURE);
        }
        acc = heap;
        base = heap + r;
    }

    memset(acc, 0, bytes);
    acc[0] = 1ULL;
    memcpy(base, coff, bytes);

    /* Square-and-multiply over every bit of n except the top one. */
    while (n > 1) {
        if (n & 0x01) {
            MulPoly(acc, base, modulus, r);
        }
        MulPoly(base, base, modulus, r);
        n >>= 1;
    }
    /* The top bit is always set; handling it here saves one squaring. */
    MulPoly(acc, base, modulus, r);

    memcpy(coff, acc, bytes);
    free(heap);
}

/*
 * Tests whether the polynomial p equals x^(n mod r) + a, with coefficients
 * taken modulo n.
 *
 * p  coefficient array of length r (index = degree)
 * a  expected constant term; the caller is expected to pass a value below n
 * n  coefficient modulus; n mod r also selects the position of the x term
 * r  number of coefficients, must be non-zero
 *
 * Returns 1 on a match and 0 otherwise.  When n mod r is 0 the power of x
 * falls on the constant term, so the expected constant is a + 1 (mod n)
 * rather than a.
 */
int CheckPoly(uint64 *p, uint64 a, uint64 n, uint32 r) {
    const uint32 shift = (uint32)(n % r);
    uint32 i;

    for (i = 0; i < r; i++) {
        uint64 expected = (i == 0) ? a : 0;

        if (i == shift) {
            /* Add the coefficient 1 of x^(n mod r), reducing modulo n. */
            expected = (expected + 1 == n) ? 0 : expected + 1;
        }
        if (p[i] != expected) {
            return 0;
        }
    }
    return 1;
}

/*
 * Compares base^exp with n without ever overflowing 64 bits.
 * Returns a negative value, zero, or a positive value when base^exp is
 * smaller than, equal to, or greater than n.
 */
static int PerfectRootCompare(uint64 base, uint8 exp, uint64 n) {
    uint64 acc = 1;

    while (exp-- > 0) {
        /* acc * base > n  <=>  acc > n / base (integer division). */
        if (base != 0 && acc > n / base) {
            return 1;
        }
        acc *= base;
    }
    if (acc == n) {
        return 0;
    }
    return (acc < n) ? -1 : 1;
}

/*
 * Looks for an integer m with m^a == n.
 *
 * a     the root degree
 * n     the number to test
 * logn  bit length of n, used only to narrow the search interval
 *
 * Returns m when an exact a-th root exists and 0 otherwise.  For a == 1 the
 * value n itself is returned, truncated to 32 bits by the return type.
 * a == 0 has no meaningful root and yields 0.
 *
 * The search interval is derived with integer shifts, the midpoint is
 * computed in 64 bits and the power is compared through an overflow-checked
 * helper, so the result does not depend on floating-point rounding or on
 * wrap-around.
 */
uint32 PerfectRoot(uint8 a, uint64 n, uint8 logn) {
    uint64 lo, hi;
    uint32 hi_bits;

    if (a == 0) return 0;
    if (a == 1) return (uint32)n;
    if (logn > 64) logn = 64; /* keeps the shift amounts below in range */

    /* 2^(logn-1) <= n < 2^logn  =>  2^floor((logn-1)/a) <= m < 2^ceil(logn/a) */
    lo = 1ULL << ((logn - 1) / a);
    hi_bits = (uint32)((logn + a - 1) / a);
    hi = (hi_bits >= 32) ? 0xFFFFFFFFULL : (1ULL << hi_bits);

    while (lo <= hi) {
        const uint64 mid = lo + (hi - lo) / 2;
        const int cmp = PerfectRootCompare(mid, a, n);

        if (cmp == 0) return (uint32)mid;
        if (cmp < 0) {
            lo = mid + 1;
        } else {
            hi = mid - 1; /* mid >= lo >= 1, so this cannot wrap */
        }
    }
    return 0;
}

/*
 * Returns 1 when n is an exact e-th power of an integer for some exponent e
 * with 2 <= e < bit length of n, and 0 otherwise.
 */
int IsPower(uint64 n) {
    const uint8 bits = BitCount(n);
    uint8 e = 2;

    while (e < bits) {
        if (PerfectRoot(e, n, bits) != 0) {
            return 1;
        }
        e++;
    }
    return 0;
}

/*
 * Factors r by trial division.
 *
 * factors[k] receives the k-th distinct prime factor in increasing order and
 * exponents[k] its multiplicity.  Returns the number of entries written
 * (0 for r == 0 or r == 1).  The caller supplies arrays that are large
 * enough; no bound is checked here.
 */
uint8 SmallFactors(uint32 r, uint32 *factors, uint32 *exponents) {
  uint8 count = 0;
  uint32 limit = SquareRoot(r);
  uint32 d = 2;

  while (d <= limit) {
    if (r % d == 0) {
      uint32 multiplicity = 0;

      /* Strip every copy of d from r. */
      do {
        r /= d;
        multiplicity++;
      } while (r % d == 0);

      factors[count] = d;
      exponents[count] = multiplicity;
      count++;

      /* r shrank, so the trial-division bound shrinks with it. */
      limit = SquareRoot(r);
    }
    d++;
  }

  /* Whatever is left above 1 is a single prime factor. */
  if (r > 1) {
    factors[count] = r;
    exponents[count] = 1;
    count++;
  }
  return count;
}

/*
 * Multiplicative order of n modulo r: the smallest e >= 1 with n^e == 1 (mod r).
 *
 * n  residue, expected to be coprime to r (otherwise no such e exists and
 *    the returned value is meaningless)
 * r  modulus
 *
 * The order divides phi(r).  Starting from phi(r), each prime factor f is
 * divided out for as long as the smaller exponent still maps n to 1.  The
 * candidate is tested *before* it replaces the current value, so a prime
 * that occurs exactly once in the true order is not lost.
 */
uint32 SmallOrder(uint32 n, uint32 r) {
  uint32 factors[MAX_FACTORS];
  uint32 exponents[MAX_FACTORS];
  uint32 count;
  uint32 order;
  uint32 i;

  /* Euler's totient of r from its prime factorisation. */
  order = 1;
  count = SmallFactors(r, factors, exponents);
  for (i = 0; i < count; i++) {
    if (exponents[i] > 1) {
      order *= Power(factors[i], exponents[i] - 1);
    }
    order *= factors[i] - 1;
  }

  /*
   * Reduce phi(r) to the order.
   * NOTE(review): the exponent handed to PowerMod can be as large as
   * phi(r) / 2; confirm that PowerMod's exponent parameter is wide enough
   * to receive it without truncation.
   */
  count = SmallFactors(order, factors, exponents);
  for (i = 0; i < count; i++) {
    while (order % factors[i] == 0 &&
           PowerMod(n, order / factors[i], r) == 1) {
      order /= factors[i];
    }
  }
  return order;
}


/*
 * Searches for the smallest r >= 2 that is coprime to n and for which
 * SmallOrder(n mod r, r) exceeds (bit length of n)^2.
 *
 * The search stops at max(3, (bit length of n)^5); when no candidate
 * qualifies, the value one past that limit is returned.
 */
uint32 FindR(uint64 n) {
    const uint8 bits = BitCount(n);
    const uint32 order_floor = Power(bits, 2);
    uint32 limit = Power(bits, 5);
    uint32 r = 2;

    if (limit < 3) {
        limit = 3;
    }

    while (r <= limit) {
        /* Candidates sharing a factor with n are skipped. */
        if (gcd(n, r) <= 1 && SmallOrder(n % r, r) > order_floor) {
            break;
        }
        r++;
    }
    return r;
}


/*
 * Euler's totient of r, computed from the prime factorisation returned by
 * SmallFactors as the product of p^(e-1) * (p - 1) over each prime power p^e.
 */
uint32 SmallPhi(uint32 r) {
    uint32 primes[MAX_FACTORS];
    uint32 powers[MAX_FACTORS];
    uint32 phi = 1;
    uint32 count = SmallFactors(r, primes, powers);
    uint32 idx = 0;

    while (idx < count) {
        if (powers[idx] > 1) {
            phi *= Power(primes[idx], powers[idx] - 1);
        }
        phi *= primes[idx] - 1;
        idx++;
    }
    return phi;
}

/*
 * AKS primality test.  Returns 1 when n is prime and 0 otherwise.
 *
 * Steps:
 *   0. 0 and 1 are not prime.
 *   1. Reject perfect powers.
 *   2. Pick r with FindR.
 *   3. Reject n when some i in [2, r] shares a proper factor with n.
 *   4. Accept n when n <= r (step 3 has then ruled out every proper divisor).
 *   5. For a = 1 .. maxa, verify (x + a)^n == x^n + a modulo (x^r - 1, n).
 *
 * The coefficient buffer lives on the stack while r <= MAXR and on the heap
 * otherwise, so larger r no longer writes past a fixed-size array.  If the
 * allocation fails the program prints a message and exits.
 */
int IsPrimeAKS(uint64 n) {
    uint64 fixed[MAXR];
    uint64 *poly = fixed;
    uint64 i;
    uint64 t;
    uint64 maxa;
    uint32 r;
    uint32 logn;
    size_t bytes;
    int result = 1;

    if (n < 2) return 0;
    if (IsPower(n)) return 0;

    r = FindR(n);

    for (i = 2; i <= r; i++) {
        t = gcd(n, i);
        if (t > 1 && t < n) return 0;
    }

    if (n <= r) return 1;

    /*
     * The AKS bound is floor(sqrt(phi(r)) * log2(n)).  The bit length is at
     * least log2(n) and floor(sqrt(phi)) + 1 exceeds sqrt(phi), so this
     * product never falls below that bound.
     */
    logn = BitCount(n);
    maxa = (uint64)logn * ((uint64)SquareRoot(SmallPhi(r)) + 1);
    if (maxa >= n) maxa = n - 1;

    bytes = sizeof(uint64) * r;
    if (r > MAXR) {
        poly = malloc(bytes);
        if (poly == NULL) {
            fprintf(stderr, "IsPrimeAKS: out of memory (r = %u)\n", r);
            exit(EXIT_FAILURE);
        }
    }

    for (i = 1; i <= maxa; i++) {
        /* poly = x + i */
        memset(poly, 0, bytes);
        poly[0] = i;
        poly[1] = 1;
        PowerPoly(poly, n, r);
        if (!CheckPoly(poly, i, n, r)) {
            result = 0;
            break;
        }
    }

    if (poly != fixed) {
        free(poly);
    }
    return result;
}


int main()
{  
    uint64 n;
    while(scanf("%"SCNu64, &n) != EOF) {
        printf("%s\n", IsPrimeAKS(n) ? "Yes" : "No");
    }
    return 0;
}


参考文献:

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注