#include <stdio.h>
#include <stdint.h>
#include <inttypes.h>
#include <memory.h>
#include <math.h>
#define N 100
#define MAX_FACTORS 32
#define MAXR 320
typedef unsigned long long int uint64;
typedef unsigned int uint32;
typedef unsigned char uint8;
/* Greatest common divisor of a and b by Euclid's algorithm; gcd(x, 0) == x. */
uint64 gcd(uint64 a, uint64 b) {
    return b == 0 ? a : gcd(b, a % b);
}
/* Number of significant bits in n: floor(log2 n) + 1 for n > 0, and 0 for n == 0. */
uint8 BitCount(uint64 n) {
    uint8 bits;
    for (bits = 0; n != 0; n >>= 1) {
        bits++;
    }
    return bits;
}
/*
 * Integer square root: returns floor(sqrt(x)) for every 64-bit x.
 *
 * Binary search over [0, min(x, 2^32 - 1)]. Every probe is at most 2^32 - 1,
 * so mid * mid is computed in 64 bits without overflow. This avoids both
 * floating-point rounding and the 32-bit squaring overflow that occurs for
 * x >= 2^32.
 */
uint32 SquareRoot(uint64 x) {
    uint64 lo = 0;
    uint64 hi = x < 0xFFFFFFFFULL ? x : 0xFFFFFFFFULL;
    /* Invariant: lo * lo <= x and floor(sqrt(x)) lies in [lo, hi]. */
    while (lo < hi) {
        uint64 mid = lo + (hi - lo + 1) / 2;
        if (mid * mid <= x) {
            lo = mid;
        } else {
            hi = mid - 1;
        }
    }
    return (uint32)lo;
}
/*
 * a^k by binary exponentiation. The caller must keep a^k below 2^64;
 * otherwise the result wraps modulo 2^64 like any unsigned product.
 */
uint64 Power(uint32 a, uint8 k) {
    uint64 result = 1;
    uint64 base = a;
    for (; k != 0; k >>= 1) {
        if (k & 1u) {
            result *= base;
        }
        if (k > 1) {
            base *= base;   /* square only while higher bits remain */
        }
    }
    return result;
}
/*
 * Computes a^k mod `mod` by square-and-multiply. Requires mod >= 1.
 *
 * The exponent is a full uint32. SmallOrder passes uint32 exponents such as
 * phi(r), and a uint8 parameter silently truncated them. Intermediate values
 * are < mod < 2^32, so every product fits in 64 bits.
 */
uint32 PowerMod(uint32 a, uint32 k, uint32 mod) {
    uint64 result = 1 % mod;    /* a^0 reduced mod `mod` (0 when mod == 1) */
    uint64 base = a % mod;
    while (k != 0) {
        if (k & 1u) {
            result = result * base % mod;
        }
        base = base * base % mod;
        k >>= 1;
    }
    return (uint32)result;
}
/*
 * (a * b) % mod without 64-bit overflow, using the GCC/Clang 128-bit integer
 * extension for the full product. Requires mod != 0.
 */
uint64 Mul64Mod(uint64 a, uint64 b, uint64 mod) {
    __uint128_t wide = (__uint128_t)a * b;
    return (uint64)(wide % mod);
}
/*
 * In-place product p1 <- p1 * p2 in Z_n[X] / (X^r - 1).
 *
 * Coefficient i of the product is the cyclic convolution
 *   sum_j p1[j] * p2[(i - j) mod r]   (mod n).
 * p1 and p2 may be the same array: the result is built in a local buffer and
 * copied back at the end. Requires n > 0 and r <= MAXR (local buffer size).
 */
void MulPoly(uint64 *p1, uint64 *p2, uint64 n, uint32 r) {
    uint64 ans[MAXR];
    uint32 i, j;
    for (i = 0; i < r; i++) {
        uint64 acc = 0;
        for (j = 0; j < r; j++) {
            uint32 k = (j <= i) ? i - j : r + i - j;   /* j + k == i (mod r) */
            uint64 term = Mul64Mod(p1[j], p2[k], n);
            /* acc and term are both < n, but acc + term can exceed 2^64 when
               n > 2^63; compare against n - term before adding instead. */
            acc = (acc >= n - term) ? acc - (n - term) : acc + term;
        }
        ans[i] = acc;
    }
    memcpy(p1, ans, sizeof(uint64) * r);
}
/*
 * In-place power coff <- coff^n in Z_n[X] / (X^r - 1): the exponent is also
 * the coefficient modulus, as AKS needs. n == 0 yields the constant
 * polynomial 1. Requires r <= MAXR.
 */
void PowerPoly(uint64 *coff, uint64 n, uint32 r) {
    uint64 exponent = n;   /* consumed bit by bit; n itself stays the modulus */
    uint64 result[MAXR];
    uint64 base[MAXR];

    if (exponent == 0) {
        memset(coff, 0, sizeof(uint64) * r);
        coff[0] = 1ULL;
        return;
    }
    memset(result, 0, sizeof(uint64) * r);
    result[0] = 1ULL;
    memcpy(base, coff, sizeof(uint64) * r);
    for (; exponent > 1; exponent >>= 1) {
        if (exponent & 1ULL) {
            MulPoly(result, base, n, r);
        }
        MulPoly(base, base, n, r);
    }
    /* The top bit of the exponent is always set. */
    MulPoly(result, base, n, r);
    memcpy(coff, result, sizeof(uint64) * r);
}
/*
 * Returns 1 if the r coefficients of p equal X^(n mod r) + a, else 0.
 * Intended for n % r != 0; when n % r == 0 it only accepts a == 1 with
 * p[0] == 1 and all other coefficients zero.
 */
int CheckPoly(uint64 *p, uint64 a, uint64 n, uint32 r) {
    uint64 one_at = n % r;   /* index of the X^n term after reducing by X^r - 1 */
    uint32 idx;
    if (p[0] != a || p[one_at] != 1ULL) {
        return 0;
    }
    for (idx = 1; idx < r; idx++) {
        if (idx != one_at && p[idx] != 0) {
            return 0;
        }
    }
    return 1;
}
/*
 * Integer a-th root of n: returns m with m^a == n, or 0 if n is not a perfect
 * a-th power (or n == 0, or a == 0). For a == 1 it returns n truncated to
 * 32 bits. logn must be BitCount(n), i.e. 2^(logn-1) <= n < 2^logn.
 *
 * Everything is 64-bit integer arithmetic. Bounds come from exact shifts
 * rather than pow(), whose value 2^32 does not fit in uint32. m^a is built
 * one factor at a time and the loop stops as soon as it would exceed n, so
 * nothing overflows even for n close to 2^64.
 */
uint32 PerfectRoot(uint8 a, uint64 n, uint8 logn) {
    uint64 lo, hi, mid, acc;
    unsigned lo_shift, hi_shift;
    uint8 k;
    int too_big;

    if (a == 1) return (uint32)n;
    if (a == 0 || n == 0) return 0;
    /* n >= 2^(logn-1)  =>  root >= 2^floor((logn-1)/a);
       n <  2^logn      =>  root <  2^ceil(logn/a).
       For a >= 2 and n < 2^64 the root is below 2^32. */
    lo_shift = (logn > 0) ? (unsigned)(logn - 1) / a : 0;
    if (lo_shift > 31) lo_shift = 31;
    hi_shift = ((unsigned)logn + a - 1) / a;
    lo = 1ULL << lo_shift;
    hi = hi_shift >= 32 ? 0xFFFFFFFFULL : (1ULL << hi_shift);
    while (lo <= hi) {
        mid = lo + (hi - lo) / 2;   /* mid >= 1 */
        acc = 1;
        too_big = 0;
        for (k = 0; k < a; k++) {
            if (acc > n / mid) {    /* acc * mid would exceed n */
                too_big = 1;
                break;
            }
            acc *= mid;
        }
        if (!too_big && acc == n) return (uint32)mid;
        if (too_big || acc > n) {
            hi = mid - 1;
        } else {
            lo = mid + 1;
        }
    }
    return 0;
}
/*
 * Returns 1 if n == m^k for some m >= 2 and k >= 2, else 0.
 * Exponents k >= BitCount(n) cannot work, because m >= 2 forces 2^k <= n.
 */
int IsPower(uint64 n) {
    uint8 bits = BitCount(n);
    uint8 exponent;
    for (exponent = 2; exponent < bits; ++exponent) {
        if (PerfectRoot(exponent, n, bits) != 0) {
            return 1;
        }
    }
    return 0;
}
/*
 * Factors r by trial division. Writes the distinct primes in increasing order
 * to factors[] and their multiplicities to exponents[], and returns how many
 * were written. r = 0 and r = 1 produce no factors.
 */
uint8 SmallFactors(uint32 r, uint32 *factors, uint32 *exponents) {
    uint8 count = 0;
    uint32 d;
    /* d <= r / d is d*d <= r, i.e. d <= floor(sqrt(r)), using the current r. */
    for (d = 2; d <= r / d; ++d) {
        if (r % d != 0) {
            continue;
        }
        factors[count] = d;
        exponents[count] = 0;
        do {
            r /= d;
            exponents[count]++;
        } while (r % d == 0);
        count++;
    }
    /* Whatever remains above 1 is a single prime factor. */
    if (r > 1) {
        factors[count] = r;
        exponents[count] = 1;
        count++;
    }
    return count;
}
/*
 * base^exp mod `mod` for SmallOrder, taking a full 32-bit exponent (ans / q
 * can exceed 255). Requires mod >= 1. Operands stay below 2^32, so the
 * products fit in 64 bits.
 */
static uint32 SmallOrderPowMod(uint32 base, uint32 exp, uint32 mod) {
    uint64 result = 1 % mod;
    uint64 b = base % mod;
    while (exp != 0) {
        if (exp & 1u) {
            result = result * b % mod;
        }
        b = b * b % mod;
        exp >>= 1;
    }
    return (uint32)result;
}

/*
 * Multiplicative order of n modulo r: the smallest k >= 1 with n^k == 1 (mod r).
 * Requires gcd(n, r) == 1; FindR only calls it for such r.
 *
 * The order divides phi(r). Starting from phi(r), each prime q of phi(r) is
 * divided out while n^(ans/q) == 1 still holds. Testing ans/q before dividing
 * guarantees ans never drops below the order.
 */
uint32 SmallOrder(uint32 n, uint32 r) {
    uint32 i;
    uint32 factors[MAX_FACTORS];
    uint32 exponents[MAX_FACTORS];
    uint32 p;
    uint32 ans = 1;

    /* ans = phi(r) = prod q^(e-1) * (q - 1) over r = prod q^e. */
    p = SmallFactors(r, factors, exponents);
    for (i = 0; i < p; i++) {
        if (exponents[i] > 1) {
            ans *= Power(factors[i], exponents[i] - 1);
        }
        ans *= factors[i] - 1;
    }

    /* Strip prime factors of phi(r) that are not needed for n^ans == 1. */
    p = SmallFactors(ans, factors, exponents);
    for (i = 0; i < p; i++) {
        while (ans % factors[i] == 0 &&
               SmallOrderPowMod(n, ans / factors[i], r) == 1) {
            ans /= factors[i];
        }
    }
    return ans;
}
/*
 * AKS step 2: returns the smallest r in [2, max(3, BitCount(n)^5)] with
 * gcd(n, r) == 1 and ord_r(n) > BitCount(n)^2.
 * If no r in that range qualifies, the upper limit + 1 is returned.
 */
uint32 FindR(uint64 n) {
    uint8 bits = BitCount(n);
    uint32 order_bound = Power(bits, 2);
    uint32 r_limit = Power(bits, 5);
    uint32 r;
    if (r_limit < 3) {
        r_limit = 3;
    }
    for (r = 2; r <= r_limit; r++) {
        /* The order is only defined when r is coprime to n. */
        if (gcd(n, r) <= 1 && SmallOrder(n % r, r) > order_bound) {
            break;
        }
    }
    return r;
}
/*
 * Euler's totient phi(r) = prod q^(e-1) * (q - 1) over the factorisation
 * r = prod q^e. Returns 1 for r = 0 and r = 1.
 */
uint32 SmallPhi(uint32 r) {
    uint32 primes[MAX_FACTORS];
    uint32 powers[MAX_FACTORS];
    uint32 count = SmallFactors(r, primes, powers);
    uint32 phi = 1;
    uint32 idx, e;
    for (idx = 0; idx < count; ++idx) {
        for (e = 1; e < powers[idx]; ++e) {
            phi *= primes[idx];
        }
        phi *= primes[idx] - 1;
    }
    return phi;
}
/*
 * AKS primality test. Returns 1 if n is prime and 0 if it is composite.
 * Returns -1 if the required r exceeds MAXR, the capacity of the polynomial
 * buffers; the test cannot be completed in that case.
 */
int IsPrimeAKS(uint64 n) {
    uint64 i = 0;
    uint32 r = 0;
    uint64 t = 0;
    uint32 logn = 0;
    uint64 maxa = 0;
    uint64 poly[MAXR];

    /* 0 and 1 are not prime; without this they fall through to step 4. */
    if (n < 2) return 0;
    /* Step 1: perfect powers are composite. */
    if (IsPower(n)) return 0;
    /* Step 2: smallest r with ord_r(n) > log^2 n. */
    r = FindR(n);
    /* poly[] here, and the buffers in MulPoly/PowerPoly, hold MAXR coefficients. */
    if (r > MAXR) {
        fprintf(stderr, "ERROR: r is greater than MAXR (%d)\n", MAXR);
        return -1;
    }
    /* Step 3: a nontrivial common factor with some a <= r proves compositeness. */
    for (i = 2; i <= r; i++) {
        t = gcd(n, i);
        if (t > 1 && t < n) return 0;
    }
    /* Step 4 */
    if (n <= r) return 1;
    /* Step 5: check (X + a)^n == X^n + a for a = 1..floor(sqrt(phi(r)) * log2 n).
       (floor(sqrt(phi)) + 1) * BitCount(n) is always >= that bound; the product
       floor(sqrt(phi)) * BitCount(n) can fall below it. */
    logn = BitCount(n);
    maxa = (uint64)logn * (SquareRoot(SmallPhi(r)) + 1ULL);
    if (maxa >= n) maxa = n - 1;
    for (i = 1; i <= maxa; i++) {
        memset(poly, 0, sizeof(uint64) * r);
        poly[0] = i;
        poly[1] = 1;
        PowerPoly(poly, n, r);
        if (!CheckPoly(poly, i, n, r)) return 0;
    }
    /* Step 6 */
    return 1;
}
/*
 * Reads unsigned integers from stdin and prints "Yes" or "No" for each,
 * according to IsPrimeAKS. Prints "Unknown" if IsPrimeAKS reports it cannot
 * decide (negative result).
 */
int main()
{
    uint64 n;
    int verdict;
    /* Stop at EOF *or* at the first non-numeric token: scanf returns 0 on a
       mismatch, so testing "!= EOF" would loop forever. "%llu" matches the
       unsigned long long behind uint64 (SCNu64 targets uint64_t, which may
       be a different type). */
    while (scanf("%llu", &n) == 1) {
        verdict = IsPrimeAKS(n);
        if (verdict < 0) {
            printf("Unknown\n");
        } else {
            printf("%s\n", verdict ? "Yes" : "No");
        }
    }
    return 0;
}
#include <stdio.h>
#include <stdint.h>
#include <memory.h>
#include <math.h>
/* Verbosity thresholds for debug output; code prints when debug_level >= threshold. */
const int DEBUG_LEVEL_INFO = 1;   /* step-by-step progress messages in IsPrimeAKS */
const int DEBUG_LEVEL_VALUE = 2;  /* also dumps r and intermediate polynomials */
const int DEBUG_LEVEL_CHECK = 3;  /* NOTE(review): not referenced in the code shown */
/* Active verbosity; 0 disables all debug output. */
const int debug_level = 0;
#define N 100
#define MAX_FACTORS 32
#define MAXR 320
typedef unsigned long long int uint64;
typedef unsigned int uint32;
typedef unsigned char uint8;
/* Euclid's algorithm: greatest common divisor of a and b, with gcd(x, 0) == x. */
uint64 gcd(uint64 a, uint64 b) {
    for (;;) {
        uint64 rem;
        if (b == 0) {
            return a;
        }
        rem = a % b;
        a = b;
        b = rem;
    }
}
/*
 * Number of significant bits in n: floor(log2 n) + 1 for n > 0, and 0 for n == 0.
 * This equals ceil(log2(n + 1)). It is NOT ceil(log2 n): for a power of two it
 * is one larger (BitCount(4) == 3).
 */
uint8 BitCount(uint64 n) {
    uint8 bits = 0;
    while (n != 0) {
        bits++;
        n >>= 1;
    }
    return bits;
}
/*
 * Integer square root: returns floor(sqrt(x)) for every 64-bit x.
 *
 * Binary search over [0, min(x, 2^32 - 1)]. Every probe is at most 2^32 - 1,
 * so mid * mid is computed in 64 bits without overflow. This avoids both
 * floating-point rounding and the 32-bit squaring overflow that occurs for
 * x >= 2^32.
 */
uint32 SquareRoot(uint64 x) {
    uint64 lo = 0;
    uint64 hi = x < 0xFFFFFFFFULL ? x : 0xFFFFFFFFULL;
    /* Invariant: lo * lo <= x and floor(sqrt(x)) lies in [lo, hi]. */
    while (lo < hi) {
        uint64 mid = lo + (hi - lo + 1) / 2;
        if (mid * mid <= x) {
            lo = mid;
        } else {
            hi = mid - 1;
        }
    }
    return (uint32)lo;
}
/*
 * a^k by binary exponentiation. The caller must ensure a^k < 2^64;
 * otherwise the result silently wraps modulo 2^64.
 */
uint64 Power(uint32 a, uint8 k) {
    uint64 result = 1;
    uint64 base = a;
    for (; k != 0; k >>= 1) {
        if (k & 1u) {
            result *= base;
        }
        if (k > 1) {
            base *= base;   /* square only while higher bits remain */
        }
    }
    return result;
}
/*
 * Computes a^k mod `mod` by square-and-multiply. Requires mod >= 1.
 *
 * Each set bit of k multiplies the result by the current base exactly once.
 * There is no extra multiplication after the loop; that step previously
 * produced a^k * a^(2^bits) (e.g. 2^1 mod 100 gave 8). The exponent is a full
 * uint32, so SmallOrder's uint32 exponents are no longer truncated to 8 bits.
 */
uint32 PowerMod(uint32 a, uint32 k, uint32 mod) {
    uint64 result = 1 % mod;    /* a^0 reduced mod `mod` (0 when mod == 1) */
    uint64 base = a % mod;
    while (k != 0) {
        if (k & 1u) {
            result = result * base % mod;
        }
        base = base * base % mod;   /* base < 2^32, so the square fits in 64 bits */
        k >>= 1;
    }
    return (uint32)result;
}
/* Prints the r coefficients of p (constant term first), space-separated, then a newline. */
void PrintPoly(uint64 *p, uint32 r) {
    uint32 idx = 0;
    while (idx < r) {
        printf("%llu ", p[idx]);
        idx++;
    }
    printf("\n");
}
/*
 * (a + b) % mod for a, b < mod, correct even when a + b overflows 64 bits.
 * Overflow can only happen when mod > 2^63.
 */
uint64 Add64Mod(uint64 a, uint64 b, uint64 mod) {
    uint64 sum = a + b;
    if (sum < a) {
        /* The true sum is sum + 2^64 >= mod; subtracting mod wraps back into range. */
        return sum - mod;
    }
    return sum >= mod ? sum - mod : sum;
}
/*
 * (2 * a) % mod for a < mod, correct even when 2 * a overflows 64 bits.
 * Overflow happens only if a >= 2^63, which implies mod > 2^63.
 */
uint64 Double64Mod(uint64 a, uint64 mod) {
    uint64 twice = a << 1;   /* drops bit 63 when a >= 2^63 */
    if (a >> 63) {
        /* True value is twice + 2^64 > mod; subtracting mod wraps back into range. */
        return twice - mod;
    }
    return twice >= mod ? twice - mod : twice;
}
/*
 * (a * b) % mod without 64-bit overflow. Requires mod != 0.
 *
 * Uses the compiler's 128-bit integer type when available (GCC/Clang define
 * __SIZEOF_INT128__). Otherwise it falls back to double-and-add with
 * Add64Mod/Double64Mod. Before, that loop sat unreachable after an
 * unconditional return; now it is compiled only where it is actually needed.
 */
uint64 Mul64Mod(uint64 a, uint64 b, uint64 mod) {
#ifdef __SIZEOF_INT128__
    return (uint64)((__uint128_t)a * b % mod);
#else
    uint64 ans = 0;
    uint64 a2 = a % mod;   /* Add64Mod/Double64Mod need operands < mod */
    while (b != 0) {
        if (b & 0x01) {
            ans = Add64Mod(ans, a2, mod);
        }
        a2 = Double64Mod(a2, mod);
        b >>= 1;
    }
    return ans;
#endif
}
//compute (a * b) % mod
//make sure that a and b are less thant mod
// uint64 Mul64Mod (uint64 a, uint64 b, uint64 mod)
// {
// uint64 a_lo = (uint64)(uint32)a;
// uint64 a_hi = a >> 32;
// uint64 b_lo = (uint64)(uint32)b;
// uint64 b_hi = b >> 32;
// uint64 p0 = (a_lo * b_lo) % mod;
// uint64 p1 = (a_lo * b_hi) % mod;
// uint64 p2 = (a_hi * b_lo) % mod;
// uint64 p3 = (a_hi * b_hi) % mod;
// uint64 ans = p0;
// ans += ((p1 + p2) % mod * (1ULL<<32) % mod) % mod;
// if(ans >= mod) ans -= mod;
// ans += (((-mod) % mod) % mod ) * p3;
// if(ans >= mod) ans -= mod;
// return ans;
// }
/*
 * In-place product p1 <- p1 * p2 in Z_n[X] / (X^r - 1).
 *
 * Coefficient i of the product is the cyclic convolution
 *   sum_j p1[j] * p2[(i - j) mod r]   (mod n).
 * p1 and p2 may be the same array: the result is built in a local buffer and
 * copied back at the end. Requires n > 0 and r <= MAXR (local buffer size).
 * Accumulation goes through Add64Mod, because a plain "+=" followed by
 * "-= n" overflows when n > 2^63.
 */
void MulPoly(uint64 *p1, uint64 *p2, uint64 n, uint32 r) {
    uint64 ans[MAXR];
    uint32 i, j;
    if(debug_level >= DEBUG_LEVEL_VALUE) {
        printf("p1:"); PrintPoly(p1, r);
        printf("p2:"); PrintPoly(p2, r);
    }
    memset(ans, 0, sizeof(uint64) * r);
    for (i = 0; i < r; i++) {
        /* Terms X^j * X^(i-j) with no wrap-around. */
        for (j = 0; j <= i; j++) {
            ans[i] = Add64Mod(ans[i], Mul64Mod(p1[j], p2[i - j], n), n);
        }
        /* Terms whose exponent wraps past r (X^r == 1). */
        for (j = i + 1; j < r; j++) {
            ans[i] = Add64Mod(ans[i], Mul64Mod(p1[j], p2[r + i - j], n), n);
        }
    }
    memcpy(p1, ans, sizeof(uint64) * r);
}
/*
 * In-place power coff <- coff^n in Z_n[X] / (X^r - 1): the exponent is also
 * the coefficient modulus, as AKS needs. n == 0 yields the constant
 * polynomial 1. Requires r <= MAXR.
 *
 * Square-and-multiply. The base is squared only while exponent bits remain;
 * the old loop always did one final O(r^2) squaring whose result was discarded.
 */
void PowerPoly(uint64 *coff, uint64 n, uint32 r) {
    uint64 n0 = n;   /* modulus; n itself is consumed as the exponent */
    uint64 ans[MAXR];
    uint64 coff2[MAXR];
    memset(ans, 0, sizeof(uint64) * r); ans[0] = 1ULL; //ans = 1
    memcpy(coff2, coff, sizeof(uint64) * r);//coff2 = coff
    if(debug_level >= DEBUG_LEVEL_VALUE) {
        printf("ans:");PrintPoly(ans, r);
        printf("coff2:");PrintPoly(coff2, r);
    }
    while (n) {
        if (n & 0x01) {
            MulPoly(ans, coff2, n0, r);
            if(debug_level >= DEBUG_LEVEL_VALUE) {
                printf("ans:");PrintPoly(ans, r);
            }
        }
        n >>= 1;
        if (n) {
            MulPoly(coff2, coff2, n0, r);
            if(debug_level >= DEBUG_LEVEL_VALUE) {
                printf("coff2:");PrintPoly(coff2, r);
            }
        }
    }
    memcpy(coff, ans, sizeof(uint64) * r);
}
/*
 * Returns 1 if the r coefficients of p equal X^(n mod r) + a, else 0.
 * Callers must ensure n % r != 0; when n % r == 0 it only accepts a == 1 with
 * p[0] == 1 and all other coefficients zero.
 */
int CheckPoly(uint64 *p, uint64 a, uint64 n, uint32 r) {
    uint64 one_at = n % r;   /* index of the X^n term after reducing by X^r - 1 */
    uint32 idx;
    if (p[0] != a || p[one_at] != 1ULL) {
        return 0;
    }
    for (idx = 1; idx < r; idx++) {
        if (idx != one_at && p[idx] != 0) {
            return 0;
        }
    }
    return 1;
}
/*
 * Integer a-th root of n: returns m with m^a == n, or 0 if n is not a perfect
 * a-th power (or n == 0, or a == 0). For a == 1 it returns n truncated to
 * 32 bits. logn must be BitCount(n), i.e. 2^(logn-1) <= n < 2^logn.
 *
 * Everything is 64-bit integer arithmetic. Bounds come from exact shifts
 * rather than pow(), whose value 2^32 does not fit in uint32. m^a is built
 * one factor at a time and the loop stops as soon as it would exceed n, so
 * nothing overflows even for n close to 2^64.
 */
uint32 PerfectRoot(uint8 a, uint64 n, uint8 logn) {
    uint64 lo, hi, mid, acc;
    unsigned lo_shift, hi_shift;
    uint8 k;
    int too_big;

    if (a == 1) return (uint32)n;
    if (a == 0 || n == 0) return 0;
    /* n >= 2^(logn-1)  =>  root >= 2^floor((logn-1)/a);
       n <  2^logn      =>  root <  2^ceil(logn/a).
       For a >= 2 and n < 2^64 the root is below 2^32. */
    lo_shift = (logn > 0) ? (unsigned)(logn - 1) / a : 0;
    if (lo_shift > 31) lo_shift = 31;
    hi_shift = ((unsigned)logn + a - 1) / a;
    lo = 1ULL << lo_shift;
    hi = hi_shift >= 32 ? 0xFFFFFFFFULL : (1ULL << hi_shift);
    while (lo <= hi) {
        mid = lo + (hi - lo) / 2;   /* mid >= 1 */
        acc = 1;
        too_big = 0;
        for (k = 0; k < a; k++) {
            if (acc > n / mid) {    /* acc * mid would exceed n */
                too_big = 1;
                break;
            }
            acc *= mid;
        }
        if (!too_big && acc == n) return (uint32)mid;
        if (too_big || acc > n) {
            hi = mid - 1;
        } else {
            lo = mid + 1;
        }
    }
    return 0;//not PerfectRoot
}
/*
 * Returns 1 if n == m^k for some m >= 2 and k >= 2, else 0.
 * Exponents k >= BitCount(n) cannot work, because m >= 2 forces 2^k <= n.
 */
int IsPower(uint64 n) {
    uint8 bits = BitCount(n);
    uint8 exponent;
    for (exponent = 2; exponent < bits; ++exponent) {
        if (PerfectRoot(exponent, n, bits) != 0) {
            return 1;
        }
    }
    return 0;
}
/*
 * Factors r by trial division. Writes the distinct primes in increasing order
 * to factors[] and their multiplicities to exponents[], and returns how many
 * were written. r = 0 and r = 1 produce no factors.
 */
uint8 SmallFactors(uint32 r, uint32 *factors, uint32 *exponents) {
    uint8 count = 0;
    uint32 d;
    /* d <= r / d is d*d <= r, i.e. d <= floor(sqrt(r)), using the current r. */
    for (d = 2; d <= r / d; ++d) {
        if (r % d != 0) {
            continue;
        }
        factors[count] = d;
        exponents[count] = 0;
        do {
            r /= d;
            exponents[count]++;
        } while (r % d == 0);
        count++;
    }
    /* Whatever remains above 1 is a single prime factor. */
    if (r > 1) {
        factors[count] = r;
        exponents[count] = 1;
        count++;
    }
    return count;
}
/*
 * base^exp mod `mod` for SmallOrder, taking a full 32-bit exponent (ans / q
 * can exceed 255). Requires mod >= 1. Operands stay below 2^32, so the
 * products fit in 64 bits.
 */
static uint32 SmallOrderPowMod(uint32 base, uint32 exp, uint32 mod) {
    uint64 result = 1 % mod;
    uint64 b = base % mod;
    while (exp != 0) {
        if (exp & 1u) {
            result = result * b % mod;
        }
        b = b * b % mod;
        exp >>= 1;
    }
    return (uint32)result;
}

/*
 * Multiplicative order of n modulo r: the smallest k >= 1 with n^k == 1 (mod r).
 * Requires gcd(n, r) == 1; FindR only calls it for such r.
 *
 * The order divides phi(r). Starting from phi(r), each prime q of phi(r) is
 * divided out while n^(ans/q) == 1 still holds. Testing ans/q before dividing
 * guarantees ans never drops below the order.
 */
uint32 SmallOrder(uint32 n, uint32 r) {
    uint32 i;
    uint32 factors[MAX_FACTORS];
    uint32 exponents[MAX_FACTORS];
    uint32 p;
    uint32 ans = 1;

    //factor for compute Euler function: ans = phi(r) = prod q^(e-1) * (q - 1).
    p = SmallFactors(r, factors, exponents);
    for (i = 0; i < p; i++) {
        if (exponents[i] > 1) {
            ans *= Power(factors[i], exponents[i] - 1);
        }
        ans *= factors[i] - 1;
    }

    //factor for compute the possible order: strip primes not needed for n^ans == 1.
    p = SmallFactors(ans, factors, exponents);
    for (i = 0; i < p; i++) {
        while (ans % factors[i] == 0 &&
               SmallOrderPowMod(n, ans / factors[i], r) == 1) {
            ans /= factors[i];
        }
    }
    return ans;
}
/*
 * AKS step 2: returns the smallest r in [2, max(3, BitCount(n)^5)] with
 * gcd(n, r) == 1 and ord_r(n) > BitCount(n)^2.
 *
 * Lemma 4.3 of the AKS paper guarantees some r <= max(3, ceil(log2 n)^5)
 * with ord_r(n) > (log2 n)^2. BitCount(n) is used in place of log2 n here.
 * If no r in the range qualifies, the upper limit + 1 is returned.
 * Intended for n > 2.
 */
uint32 FindR(uint64 n) {
    uint8 bits = BitCount(n);
    uint32 order_bound = Power(bits, 2);
    uint32 r_limit = Power(bits, 5);
    uint32 r;
    if (r_limit < 3) {
        r_limit = 3;
    }
    for (r = 2; r <= r_limit; r++) {
        /* The order is only defined when r is coprime to n. */
        if (gcd(n, r) <= 1 && SmallOrder(n % r, r) > order_bound) {
            break;
        }
    }
    return r;
}
/*
 * Euler's totient phi(r) = prod q^(e-1) * (q - 1) over the factorisation
 * r = prod q^e. Returns 1 for r = 0 and r = 1.
 */
uint32 SmallPhi(uint32 r) {
    uint32 primes[MAX_FACTORS];
    uint32 powers[MAX_FACTORS];
    uint32 count = SmallFactors(r, primes, powers);
    uint32 phi = 1;
    uint32 idx, e;
    for (idx = 0; idx < count; ++idx) {
        for (e = 1; e < powers[idx]; ++e) {
            phi *= primes[idx];
        }
        phi *= primes[idx] - 1;
    }
    return phi;
}
/*
 * AKS primality test. Returns 1 if n is prime and 0 if it is composite.
 * Returns -1 (after printing an error) if the required r exceeds MAXR, the
 * capacity of the polynomial buffers.
 */
int IsPrimeAKS(uint64 n) {
    uint64 i = 0;
    uint32 r = 0;
    uint64 t = 0;
    uint32 logn = 0;
    uint64 maxa = 0;
    uint64 poly[MAXR];
    /* 0 and 1 are not prime; without this they fall through to step 4. */
    if (n < 2) return 0;
    if(debug_level >= DEBUG_LEVEL_INFO) {
        printf("Step 1: CHeck if it is a power with exponent greater than 1.\n");
    }
    if(IsPower(n)) return 0;
    if(debug_level >= DEBUG_LEVEL_INFO) {
        printf("Step 2: Find the smallest r such that O_r(n) > (log_2{n})^2.\n");
    }
    r = FindR(n);
    if(debug_level >= DEBUG_LEVEL_VALUE) {
        printf("r=%u\n", r);
    }
    /* poly[] here, and the buffers in MulPoly/PowerPoly, hold MAXR coefficients. */
    if(r > MAXR) {
        printf("ERROR: r is greater than MAXR (%d)\n", MAXR);
        return -1;
    }
    if(debug_level >= DEBUG_LEVEL_INFO) {
        printf("Step 3: Test for the factors not greater than r.\n");
    }
    for(i = 2; i <= r; i++) {
        t = gcd(n, i);
        if(t > 1 && t < n) return 0;
    }
    if(debug_level >= DEBUG_LEVEL_INFO) {
        printf("Step 4: Return PRIME if n is small and pass all the small factor tests.\n");
    }
    if(n <= r) return 1;
    if(debug_level >= DEBUG_LEVEL_INFO) {
        printf("Step 5: Check many equations.\n");
    }
    /* AKS checks a = 1..floor(sqrt(phi(r)) * log2 n). (floor(sqrt(phi)) + 1) *
       BitCount(n) is always >= that bound; the product floor(sqrt(phi)) *
       BitCount(n) can fall below it. */
    logn = BitCount(n);
    maxa = ((uint64)logn) * (SquareRoot(SmallPhi(r)) + 1ULL);
    if(maxa >= n) maxa = n - 1;
    /* For each a, check (X + a)^n == X^n + a in Z_n[X] / (X^r - 1). */
    for(i = 1; i <= maxa; i++) {
        if(debug_level >= DEBUG_LEVEL_INFO) {
            printf("\tTest (X+%llu)^(%llu) = X^(%llu)+%llu\n", i, n, n, i);
        }
        memset(poly, 0, sizeof(uint64) * r);
        poly[0] = i; poly[1] = 1;
        PowerPoly(poly, n, r);
        if(debug_level >= DEBUG_LEVEL_VALUE) {
            PrintPoly(poly, r);
        }
        if(!CheckPoly(poly, i, n, r)) return 0;
    }
    if(debug_level >= DEBUG_LEVEL_INFO) {
        printf("Step 6: Return PRIME if pass all the equation tests.\n");
    }
    return 1;
}
/*
 * Reference primality test by trial division, used to cross-check IsPrimeAKS.
 * Returns 1 if n is prime and 0 otherwise (including n = 0 and n = 1).
 *
 * Only odd divisors d with d <= n / d (i.e. d*d <= n) are tried. The bound is
 * written as a division so it cannot overflow. d is 64-bit, so the loop also
 * terminates when sqrt(n) is close to 2^32; a uint32 counter would wrap and
 * loop forever there.
 */
int IsPrimeBruteForce(uint64 n) {
    uint64 d;
    if (n < 2) return 0;
    if (n < 4) return 1;          /* 2 and 3 */
    if (n % 2 == 0) return 0;
    for (d = 3; d <= n / d; d += 2) {
        if (n % d == 0) return 0;
    }
    return 1;
}
/*
 * Runs IsPrimeBruteForce and then IsPrimeAKS on n. Prints a diagnostic line
 * only when the two verdicts differ; any nonzero result is shown as "Yes".
 */
void TestPrime(uint64 n) {
    int expected = IsPrimeBruteForce(n);
    int actual = IsPrimeAKS(n);
    if (expected == actual) {
        return;
    }
    printf("n=%llu, IsPrimeBruteForce return: %s, IsPrimeAKS return: %s \n", n,
           expected ? "Yes" : "No", actual ? "Yes" : "No");
}
/*
 * Test driver. It compares IsPrimeAKS with trial division for 2..1000 and for
 * a few large inputs. In verbose debug mode it traces a single large input
 * instead.
 */
int main()
{
    uint64 small_limit = 1000;
    uint64 big_inputs[] = {5000000029ULL, 29000000000000047ULL};
    uint32 big_count = sizeof(big_inputs) / sizeof(big_inputs[0]);
    uint64 k;

    if (debug_level >= DEBUG_LEVEL_VALUE) {
        TestPrime(5000000051ULL);
        return 0;
    }
    printf("1. Begin Small Test...\n");
    for (k = 2; k <= small_limit; ++k) {
        TestPrime(k);
    }
    printf("2. Begin Big Test...\n");
    for (k = 0; k < big_count; ++k) {
        TestPrime(big_inputs[k]);
    }
    printf("3. Test is over.\n");
    return 0;
}