// zz_px.cpp
#include <NTL/ZZ_pX.h>
// The mul & sqr routines use routines from ZZX,
// which is faster for small degree polynomials.
// Define this macro to revert to old strategy.
#ifndef NTL_OLD_ZZ_pX_MUL
#include <NTL/ZZX.h>
#endif
#include <NTL/new.h>
// KARX: crossover used by the default (non-NTL_OLD_ZZ_pX_MUL) mul() below.
// When the smaller operand size s = min(deg(a), deg(b)) + 1 is below KARX
// (and the operands are too large for PlainMul), mul() converts to ZZX and
// uses KarMul; at or above KARX it switches to SSMul/FFTMul.
// The crossover is larger when NTL_GMP_LIP or NTL_GMP_HACK is defined
// (presumably GMP-backed big-integer arithmetic — confirm against the build config).
#if (defined(NTL_GMP_LIP) || defined(NTL_GMP_HACK))
#define KARX 200
#else
#define KARX 80
#endif
NTL_START_IMPL
// Returns a reference to a shared zero polynomial (a default-constructed
// ZZ_pX held in a function-local static, created on first call).
const ZZ_pX& ZZ_pX::zero()
{
   static ZZ_pX zero_poly;
   return zero_poly;
}
// Assignment from a long: *this becomes the constant polynomial a,
// delegating to conv(ZZ_pX&, long).
ZZ_pX& ZZ_pX::operator=(long a)
{
   ZZ_pX& self = *this;
   conv(self, a);
   return self;
}
// Assignment from a ZZ_p: *this becomes the constant polynomial a,
// delegating to conv(ZZ_pX&, const ZZ_p&).
ZZ_pX& ZZ_pX::operator=(const ZZ_p& a)
{
   ZZ_pX& self = *this;
   conv(self, a);
   return self;
}
// Reads the coefficient vector of x from s, then normalizes x so that any
// leading zero coefficients present in the input are dropped.
istream& operator>>(istream& s, ZZ_pX& x)
{
   ZZ_pX& target = x;
   s >> target.rep;     // raw coefficient vector
   target.normalize();  // input may carry leading zeros
   return s;
}
// Writes the coefficient vector a.rep to s, using the vector's own stream format.
ostream& operator<<(ostream& s, const ZZ_pX& a)
{
   const ZZ_pX& poly = a;
   return s << poly.rep;
}
void ZZ_pX::normalize()
{
long n;
const ZZ_p* p;
n = rep.length();
if (n == 0) return;
p = rep.elts() + n;
while (n > 0 && IsZero(*--p)) {
n--;
}
rep.SetLength(n);
}
// A (normalized) polynomial is zero exactly when it stores no coefficients.
long IsZero(const ZZ_pX& a)
{
   const long len = a.rep.length();
   return len == 0;
}
// Returns 1 iff a is the constant polynomial 1, and 0 otherwise.
long IsOne(const ZZ_pX& a)
{
   if (a.rep.length() != 1)
      return 0;
   return IsOne(a.rep[0]) ? 1 : 0;
}
// Copies the coefficient of X^i in a into x. Indices outside [0, deg(a)]
// yield zero.
void GetCoeff(ZZ_p& x, const ZZ_pX& a, long i)
{
   const bool inRange = (i >= 0 && i <= deg(a));
   if (inRange)
      x = a.rep[i];
   else
      clear(x);
}
// Sets the coefficient of X^i in x to a.
// If i > deg(x), x is grown to length i+1 and the new coefficients strictly
// between the old degree and i are cleared. The result is normalized, so
// writing a zero leading coefficient lowers the degree.
// Calls Error() when i is negative or when i+1 would overflow (NTL_OVERFLOW).
void SetCoeff(ZZ_pX& x, long i, const ZZ_p& a)
{
long j, m;
if (i < 0)
Error("SetCoeff: negative index");
if (NTL_OVERFLOW(i, 1, 0))
Error("overflow in SetCoeff");
m = deg(x);
if (i > m) {
/* careful: a may alias a coefficient of x */
long alloc = x.rep.allocated();
// Index i lies beyond the currently allocated storage, so SetLength presumably
// has to reallocate x.rep. That could invalidate a if it refers into x.rep,
// so a is copied to a temporary first. When alloc == 0, x owns no
// coefficients and a cannot alias one.
if (alloc > 0 && i >= alloc) {
ZZ_pTemp aa_tmp; ZZ_p& aa = aa_tmp.val();
aa = a;
x.rep.SetLength(i+1);
x.rep[i] = aa;
}
else {
x.rep.SetLength(i+1);
x.rep[i] = a;
}
// Zero the gap between the old top coefficient and the new one. This runs
// after a has been consumed, so an aliased a is read before it is cleared.
for (j = m+1; j < i; j++)
clear(x.rep[j]);
}
else
x.rep[i] = a;
x.normalize();
}
// Sets the coefficient of X^i in x to the long value a.
// a == 1 takes the dedicated SetCoeff(x, i) path. Any other value is first
// converted to ZZ_p in a temporary.
void SetCoeff(ZZ_pX& x, long i, long a)
{
   if (a != 1) {
      ZZ_pTemp tmp_holder; ZZ_p& coeff_val = tmp_holder.val();
      conv(coeff_val, a);
      SetCoeff(x, i, coeff_val);
      return;
   }
   SetCoeff(x, i);
}
// Sets the coefficient of X^i in x to 1.
// x is grown and gap-cleared if needed, then normalized.
// Calls Error() for a negative index or when i+1 would overflow.
void SetCoeff(ZZ_pX& x, long i)
{
   if (i < 0)
      Error("coefficient index out of range");
   if (NTL_OVERFLOW(i, 1, 0))
      Error("overflow in SetCoeff");

   long oldDeg = deg(x);
   if (i > oldDeg) {
      x.rep.SetLength(i+1);
      // zero the coefficients between the old top and the new one
      for (long k = oldDeg + 1; k < i; k++)
         clear(x.rep[k]);
   }

   set(x.rep[i]);
   x.normalize();
}
// Sets x to the monomial X: clear x, then set the coefficient of X^1 to 1.
void SetX(ZZ_pX& x)
{
clear(x);
SetCoeff(x, 1);
}
// Returns 1 iff a is exactly the monomial X (degree 1, leading coefficient 1,
// zero constant term), and 0 otherwise.
long IsX(const ZZ_pX& a)
{
   if (deg(a) != 1)
      return 0;
   if (!IsOne(LeadCoeff(a)))
      return 0;
   return IsZero(ConstTerm(a)) ? 1 : 0;
}
// Read-only access to the coefficient of X^i in a. Indices outside
// [0, deg(a)] return a reference to the shared ZZ_p zero.
const ZZ_p& coeff(const ZZ_pX& a, long i)
{
   const bool outside = (i < 0 || i > deg(a));
   return outside ? ZZ_p::zero() : a.rep[i];
}
// Returns the coefficient of the highest power of X in a, or the shared
// ZZ_p zero when a is the zero polynomial.
const ZZ_p& LeadCoeff(const ZZ_pX& a)
{
   return IsZero(a) ? ZZ_p::zero() : a.rep[deg(a)];
}
// Returns the constant coefficient (X^0) of a, or the shared ZZ_p zero when
// a is the zero polynomial.
const ZZ_p& ConstTerm(const ZZ_pX& a)
{
   return IsZero(a) ? ZZ_p::zero() : a.rep[0];
}
// Sets x to the constant polynomial a (the empty/zero polynomial if a == 0).
void conv(ZZ_pX& x, const ZZ_p& a)
{
   if (!IsZero(a)) {
      x.rep.SetLength(1);
      x.rep[0] = a;
      // This is still correct if a aliases x.rep[i] for some i > 0: it relies
      // on SetLength(1) neither relocating nor destroying x.rep[i].
      return;
   }
   x.rep.SetLength(0);
}
// Sets x to the constant polynomial a. The values 0 and 1 use clear/set
// directly; any other value is first converted to ZZ_p in a temporary.
void conv(ZZ_pX& x, long a)
{
   switch (a) {
   case 0:
      clear(x);
      break;
   case 1:
      set(x);
      break;
   default: {
      ZZ_pTemp tmp_holder; ZZ_p& val = tmp_holder.val();
      conv(val, a);
      conv(x, val);
      break;
   }
   }
}
// Sets x to the constant polynomial given by the integer a (converted to ZZ_p).
void conv(ZZ_pX& x, const ZZ& a)
{
   if (IsZero(a)) {
      clear(x);
      return;
   }

   ZZ_pTemp tmp_holder; ZZ_p& val = tmp_holder.val();
   conv(val, a);
   conv(x, val);
}
// Builds x from a coefficient vector: entry a[i] becomes the coefficient of
// X^i. The result is normalized, so trailing zero entries of a are dropped.
void conv(ZZ_pX& x, const vec_ZZ_p& a)
{
x.rep = a;
x.normalize();
}
// x = a + b. Any of x, a, b may alias each other.
void add(ZZ_pX& x, const ZZ_pX& a, const ZZ_pX& b)
{
long da = deg(a);
long db = deg(b);
long minab = min(da, db);
long maxab = max(da, db);
x.rep.SetLength(maxab+1);
long i;
const ZZ_p *ap, *bp;
ZZ_p* xp;
// Pointers are taken only after SetLength, so a resize of x (which may be a
// or b) cannot leave them stale. First add the overlapping low part 0..minab.
for (i = minab+1, ap = a.rep.elts(), bp = b.rep.elts(), xp = x.rep.elts();
i; i--, ap++, bp++, xp++)
add(*xp, (*ap), (*bp));
// Copy the remaining high coefficients from the longer operand, unless x
// already is that operand. In that case, or when the degrees are equal, fall
// through to normalize: equal-degree leading terms may cancel.
if (da > minab && &x != &a)
for (i = da-minab; i; i--, xp++, ap++)
*xp = *ap;
else if (db > minab && &x != &b)
for (i = db-minab; i; i--, xp++, bp++)
*xp = *bp;
else
x.normalize();
}
// x = a + b, where b is a constant: only the constant term changes.
void add(ZZ_pX& x, const ZZ_pX& a, const ZZ_p& b)
{
long n = a.rep.length();
if (n == 0) {
// a == 0, so the result is just the constant b
conv(x, b);
}
else if (&x == &a) {
// in place: adjust the constant term only
add(x.rep[0], a.rep[0], b);
x.normalize();
}
else if (x.rep.MaxLength() == 0) {
// x has no coefficient storage, so b cannot alias any coefficient of x;
// a plain copy followed by the constant-term update is safe
x = a;
add(x.rep[0], a.rep[0], b);
x.normalize();
}
else {
// ugly...b could alias a coeff of x
// so xp[0] is computed BEFORE x's other coefficients are overwritten; the
// remaining coefficients are copied afterwards
ZZ_p *xp = x.rep.elts();
add(xp[0], a.rep[0], b);
x.rep.SetLength(n);
xp = x.rep.elts();   // re-fetch: SetLength may have moved the storage
const ZZ_p *ap = a.rep.elts();
long i;
for (i = 1; i < n; i++)
xp[i] = ap[i];
x.normalize();
}
}
// x = a + b for a long constant b: only the constant term changes.
void add(ZZ_pX& x, const ZZ_pX& a, long b)
{
   if (a.rep.length() != 0) {
      if (&x != &a) x = a;
      add(x.rep[0], x.rep[0], b);
      x.normalize();
      return;
   }

   // a == 0: the result is the constant b
   conv(x, b);
}
// x = a - b. Any of x, a, b may alias each other.
void sub(ZZ_pX& x, const ZZ_pX& a, const ZZ_pX& b)
{
long da = deg(a);
long db = deg(b);
long minab = min(da, db);
long maxab = max(da, db);
x.rep.SetLength(maxab+1);
long i;
const ZZ_p *ap, *bp;
ZZ_p* xp;
// Pointers are taken after SetLength so a resize of x (which may alias a or
// b) cannot leave them stale. First subtract the overlapping part 0..minab.
for (i = minab+1, ap = a.rep.elts(), bp = b.rep.elts(), xp = x.rep.elts();
i; i--, ap++, bp++, xp++)
sub(*xp, (*ap), (*bp));
// a longer: copy a's high coefficients (skipped if x already is a).
// b longer: the high coefficients are -b[j]. Unlike add(), this must run even
// when x aliases b, because each coefficient has to be negated in place.
// Equal degrees: leading terms may cancel, so normalize.
if (da > minab && &x != &a)
for (i = da-minab; i; i--, xp++, ap++)
*xp = *ap;
else if (db > minab)
for (i = db-minab; i; i--, xp++, bp++)
negate(*xp, *bp);
else
x.normalize();
}
// x = a - b, where b is a constant: only the constant term changes.
void sub(ZZ_pX& x, const ZZ_pX& a, const ZZ_p& b)
{
long n = a.rep.length();
if (n == 0) {
// a == 0, so the result is -b
conv(x, b);
negate(x, x);
}
else if (&x == &a) {
// in place: adjust the constant term only
sub(x.rep[0], a.rep[0], b);
x.normalize();
}
else if (x.rep.MaxLength() == 0) {
// x has no coefficient storage, so b cannot alias a coefficient of x
x = a;
sub(x.rep[0], a.rep[0], b);
x.normalize();
}
else {
// ugly...b could alias a coeff of x
// so xp[0] is computed before any other coefficient of x is overwritten
ZZ_p *xp = x.rep.elts();
sub(xp[0], a.rep[0], b);
x.rep.SetLength(n);
xp = x.rep.elts();   // re-fetch: SetLength may have moved the storage
const ZZ_p *ap = a.rep.elts();
long i;
for (i = 1; i < n; i++)
xp[i] = ap[i];
x.normalize();
}
}
// x = a - b for a long constant b: only the constant term changes.
void sub(ZZ_pX& x, const ZZ_pX& a, long b)
{
   if (b == 0) {
      x = a;
      return;
   }

   if (a.rep.length() != 0) {
      if (&x != &a) x = a;
      sub(x.rep[0], x.rep[0], b);
   }
   else {
      // a == 0: store b as the constant term, then negate it in place
      x.rep.SetLength(1);
      x.rep[0] = b;
      negate(x.rep[0], x.rep[0]);
   }

   x.normalize();
}
// x = a - b for a constant a, computed as (-b) + a.
// a is copied first, presumably because it may alias a coefficient of x,
// which negate() overwrites.
void sub(ZZ_pX& x, const ZZ_p& a, const ZZ_pX& b)
{
   ZZ_pTemp a_holder; ZZ_p& a_copy = a_holder.val();
   a_copy = a;

   negate(x, b);       // x = -b
   add(x, x, a_copy);  // x = a - b
}
// x = a - b for a long constant a, computed as (-b) + a.
void sub(ZZ_pX& x, long a, const ZZ_pX& b)
{
   ZZ_pTemp a_holder; ZZ_p& a_modp = a_holder.val();
   a_modp = a;          // lift a into ZZ_p

   negate(x, b);        // x = -b
   add(x, x, a_modp);   // x = a - b
}
// x = -a, coefficient by coefficient (x may alias a). Negation preserves the
// degree, so no normalization is needed.
void negate(ZZ_pX& x, const ZZ_pX& a)
{
   const long len = a.rep.length();
   x.rep.SetLength(len);

   const ZZ_p* src = a.rep.elts();
   ZZ_p* dst = x.rep.elts();
   for (long k = 0; k < len; k++)
      negate(dst[k], src[k]);
}
#ifndef NTL_OLD_ZZ_pX_MUL
// These crossovers are tuned for a Pentium, but hopefully
// they should be OK on other machines as well.
// SS_kbound: mul()/sqr() consider SSMul/SSSqr only when
//            ZZ_p::ModulusSize() >= SS_kbound.
// SS_rbound: ...and only when SSRatio(...) for the operands is below this bound;
//            otherwise they use FFTMul/FFTSqr.
const long SS_kbound = 40;
const double SS_rbound = 1.25;
// c = a * b. Dispatches to an algorithm based on the modulus size
// k = ZZ_p::ModulusSize() and the smaller operand size s = min(deg a, deg b) + 1:
//   - small s relative to k      -> PlainMul (schoolbook)
//   - s < KARX                   -> Karatsuba via ZZX (KarMul)
//   - larger, k >= SS_kbound and
//     SSRatio(...) < SS_rbound   -> SSMul via ZZX
//   - otherwise                  -> FFTMul
// Squaring (&a == &b) is routed to sqr().
void mul(ZZ_pX& c, const ZZ_pX& a, const ZZ_pX& b)
{
if (IsZero(a) || IsZero(b)) {
clear(c);
return;
}
if (&a == &b) {
sqr(c, a);
return;
}
long k = ZZ_p::ModulusSize();
long s = min(deg(a), deg(b)) + 1;
// tuned schoolbook thresholds: larger moduli favour PlainMul only for smaller s
if (s == 1 || (k == 1 && s < 40) || (k == 2 && s < 20) ||
(k == 3 && s < 12) || (k <= 5 && s < 8) ||
(k <= 12 && s < 4) ) {
PlainMul(c, a, b);
}
else if (s < KARX) {
// lift to ZZX, multiply over Z with Karatsuba, convert back into ZZ_pX
ZZX A, B, C;
conv(A, a);
conv(B, b);
KarMul(C, A, B);
conv(c, C);
}
else {
long mbits;
mbits = NumBits(ZZ_p::modulus());
if (k >= SS_kbound &&
SSRatio(deg(a), mbits, deg(b), mbits) < SS_rbound) {
ZZX A, B, C;
conv(A, a);
conv(B, b);
SSMul(C, A, B);
conv(c, C);
}
else {
FFTMul(c, a, b);
}
}
}
// c = a^2. Same dispatch shape as mul(), with s = deg(a) + 1 and
// squaring-specific thresholds: PlainSqr, then Karatsuba via ZZX (KarSqr),
// then SSSqr via ZZX or FFTSqr.
// NOTE(review): the Karatsuba crossover here is a fixed 80, not KARX as in
// mul(); confirm whether that is intentional tuning.
void sqr(ZZ_pX& c, const ZZ_pX& a)
{
if (IsZero(a)) {
clear(c);
return;
}
long k = ZZ_p::ModulusSize();
long s = deg(a) + 1;
if (s == 1 || (k == 1 && s < 50) || (k == 2 && s < 25) ||
(k == 3 && s < 25) || (k <= 6 && s < 12) ||
(k <= 8 && s < 8) || (k == 9 && s < 6) ||
(k <= 30 && s < 4) ) {
PlainSqr(c, a);
}
else if (s < 80) {
ZZX C, A;
conv(A, a);
KarSqr(C, A);
conv(c, C);
}
else {
long mbits;
mbits = NumBits(ZZ_p::modulus());
if (k >= SS_kbound &&
SSRatio(deg(a), mbits, deg(a), mbits) < SS_rbound) {
ZZX A, C;
conv(A, a);
SSSqr(C, A);
conv(c, C);
}
else {
FFTSqr(c, a);
}
}
}
#else
// Old-strategy multiplication (NTL_OLD_ZZ_pX_MUL): x = a * b.
// FFTMul is used only when both degrees exceed NTL_ZZ_pX_FFT_CROSSOVER;
// otherwise PlainMul. Squaring is routed to sqr().
void mul(ZZ_pX& x, const ZZ_pX& a, const ZZ_pX& b)
{
   if (&a == &b) {
      sqr(x, a);
      return;
   }

   const bool useFFT = deg(a) > NTL_ZZ_pX_FFT_CROSSOVER &&
                       deg(b) > NTL_ZZ_pX_FFT_CROSSOVER;
   if (useFFT)
      FFTMul(x, a, b);
   else
      PlainMul(x, a, b);
}
// Old-strategy squaring (NTL_OLD_ZZ_pX_MUL): x = a^2, using PlainSqr up to
// NTL_ZZ_pX_FFT_CROSSOVER and FFTSqr beyond it.
void sqr(ZZ_pX& x, const ZZ_pX& a)
{
   if (deg(a) <= NTL_ZZ_pX_FFT_CROSSOVER) {
      PlainSqr(x, a);
      return;
   }
   FFTSqr(x, a);
}
#endif
// Schoolbook multiplication: x = a * b.
// Each output coefficient x[i] = sum_j a[j] * b[i-j] is accumulated in a ZZ
// (using the integer representatives rep(.)) and converted to ZZ_p once
// per coefficient.
// x may alias a and/or b: an aliased operand is copied before x is resized.
// NOTE(review): t and accum are function-local statics, so this routine is
// not reentrant or thread-safe. The statics presumably exist to reuse ZZ
// storage across calls.
void PlainMul(ZZ_pX& x, const ZZ_pX& a, const ZZ_pX& b)
{
long da = deg(a);
long db = deg(b);
if (da < 0 || db < 0) {
clear(x);
return;
}
// constant operand: fall back to polynomial-times-scalar
if (da == 0) {
mul(x, b, a.rep[0]);
return;
}
if (db == 0) {
mul(x, a, b.rep[0]);
return;
}
long d = da+db;
const ZZ_p *ap, *bp;
ZZ_p *xp;
ZZ_pX la, lb;
// copy any operand that x aliases, since x.rep is resized and overwritten below
if (&x == &a) {
la = a;
ap = la.rep.elts();
}
else
ap = a.rep.elts();
if (&x == &b) {
lb = b;
bp = lb.rep.elts();
}
else
bp = b.rep.elts();
x.rep.SetLength(d+1);
xp = x.rep.elts();
long i, j, jmin, jmax;
static ZZ t, accum;
for (i = 0; i <= d; i++) {
// j ranges over indices with 0 <= j <= da and 0 <= i-j <= db
jmin = max(0, i-db);
jmax = min(da, i);
clear(accum);
for (j = jmin; j <= jmax; j++) {
mul(t, rep(ap[j]), rep(bp[i-j]));
add(accum, accum, t);
}
conv(xp[i], accum);
}
x.normalize();
}
// Schoolbook squaring: x = a^2.
// Uses symmetry: for each output index i, only the first half of the products
// a[j]*a[i-j] is summed and then doubled. When the number of terms m is odd,
// the middle square a[j]^2 is added once. Sums are accumulated in ZZ and
// converted to ZZ_p once per coefficient.
// x may alias a: in that case a is copied before x is resized.
// NOTE(review): t and accum are function-local statics, so this routine is
// not reentrant or thread-safe.
void PlainSqr(ZZ_pX& x, const ZZ_pX& a)
{
long da = deg(a);
if (da < 0) {
clear(x);
return;
}
long d = 2*da;
const ZZ_p *ap;
ZZ_p *xp;
ZZ_pX la;
if (&x == &a) {
la = a;
ap = la.rep.elts();
}
else
ap = a.rep.elts();
x.rep.SetLength(d+1);
xp = x.rep.elts();
long i, j, jmin, jmax;
long m, m2;
static ZZ t, accum;
for (i = 0; i <= d; i++) {
jmin = max(0, i-da);
jmax = min(da, i);
m = jmax - jmin + 1;   // number of (j, i-j) terms for this coefficient
m2 = m >> 1;           // number of symmetric pairs
jmax = jmin + m2 - 1;  // sum only the lower half of the pairs
clear(accum);
for (j = jmin; j <= jmax; j++) {
mul(t, rep(ap[j]), rep(ap[i-j]));
add(accum, accum, t);
}
add(accum, accum, accum);
// odd term count: the middle index j = jmax+1 satisfies j == i-j
if (m & 1) {
sqr(t, rep(ap[jmax + 1]));
add(accum, accum, t);
}
conv(xp[i], accum);
}
x.normalize();
}
// Schoolbook long division: a = q*b + r with deg(r) < deg(b).
// Calls Error() if b is zero. If deg(a) < deg(b), the result is q = 0, r = a.
// The working remainder is kept as unreduced ZZ integers in x. Each step
// reduces only the current leading entry to ZZ_p, and the final remainder
// entries are converted at the end.
// q may alias b (b is copied in that case). a is fully copied into x before
// q or r is written.
void PlainDivRem(ZZ_pX& q, ZZ_pX& r, const ZZ_pX& a, const ZZ_pX& b)
{
long da, db, dq, i, j, LCIsOne;
const ZZ_p *bp;
ZZ_p *qp;
ZZ *xp;
ZZ_p LCInv, t;
static ZZ s;   // NOTE(review): function-local static; not reentrant/thread-safe
da = deg(a);
db = deg(b);
if (db < 0) Error("ZZ_pX: division by zero");
if (da < db) {
r = a;
clear(q);
return;
}
ZZ_pX lb;
if (&q == &b) {
lb = b;
bp = lb.rep.elts();
}
else
bp = b.rep.elts();
// precompute the inverse of b's leading coefficient unless it is 1
if (IsOne(bp[db]))
LCIsOne = 1;
else {
LCIsOne = 0;
inv(LCInv, bp[db]);
}
// working copy of a as integers; presumably sized via ExtendedModulusSize to
// hold unreduced accumulations — confirm against ZZ_pInfo
ZZVec x(da + 1, ZZ_pInfo->ExtendedModulusSize);
for (i = 0; i <= da; i++)
x[i] = rep(a.rep[i]);
xp = x.elts();
dq = da - db;
q.rep.SetLength(dq+1);
qp = q.rep.elts();
for (i = dq; i >= 0; i--) {
// next quotient coefficient: (current leading entry) / LC(b)
conv(t, xp[i+db]);
if (!LCIsOne)
mul(t, t, LCInv);
qp[i] = t;
// subtract t * X^i * b from the working remainder (as an add of -t*b)
negate(t, t);
for (j = db-1; j >= 0; j--) {
mul(s, rep(t), rep(bp[j]));
add(xp[i+j], xp[i+j], s);
}
}
// the low db entries form the remainder
r.rep.SetLength(db);
for (i = 0; i < db; i++)
conv(r.rep[i], xp[i]);
r.normalize();
}
void PlainRem(ZZ_pX& r, const ZZ_pX& a, const ZZ_pX& b, ZZVec& x)
{
long da, db, dq, i, j, LCIsOne;
const ZZ_p *bp;
ZZ *xp;
ZZ_p LCInv, t;
static ZZ s;
da = deg(a);
db = deg(b);
if (db < 0) Error("ZZ_pX: division by zero");
if (da < db) {
r = a;
return;
}
bp = b.rep.elts();
if (IsOne(bp[db]))
LCIsOne = 1;
else {
LCIsOne = 0;
inv(LCInv, bp[db]);
}
for (i = 0; i <= da; i++)
x[i] = rep(a.rep[i]);
xp = x.elts();
dq = da - db;
for (i = dq; i >= 0; i--) {
conv(t, xp[i+db]);
if (!LCIsOne)
mul(t, t, LCInv);
negate(t, t);
for (j = db-1; j >= 0; j--) {
mul(s, rep(t), rep(bp[j]));
add(xp[i+j], xp[i+j], s);
}
}
r.rep.SetLength(db);
for (i = 0; i < db; i++)
conv(r.rep[i], xp[i]);
?? 快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -