From 74554a206c323d35b823bb469194cbf99eb00b5b Mon Sep 17 00:00:00 2001 From: jules Date: Tue, 31 May 2016 11:47:04 +0100 Subject: [PATCH] Improved performance of some BigInteger methods by adding Montgomery Multiplication and extended Euclidan algorithms --- .../Demo/Source/Demos/CryptographyDemo.cpp | 2 +- modules/juce_core/maths/juce_BigInteger.cpp | 177 +++++++++++++++--- modules/juce_core/maths/juce_BigInteger.h | 42 +++-- 3 files changed, 178 insertions(+), 43 deletions(-) diff --git a/examples/Demo/Source/Demos/CryptographyDemo.cpp b/examples/Demo/Source/Demos/CryptographyDemo.cpp index 90598407d7..3c7050c8d0 100644 --- a/examples/Demo/Source/Demos/CryptographyDemo.cpp +++ b/examples/Demo/Source/Demos/CryptographyDemo.cpp @@ -70,7 +70,7 @@ public: private: void createRSAKey() { - int bits = jlimit (32, 512, bitSize.getText().getIntValue()); + int bits = jlimit (32, 1024, bitSize.getText().getIntValue()); bitSize.setText (String (bits), dontSendNotification); // Create a key-pair... diff --git a/modules/juce_core/maths/juce_BigInteger.cpp b/modules/juce_core/maths/juce_BigInteger.cpp index 81bcf61beb..8b0d43728c 100644 --- a/modules/juce_core/maths/juce_BigInteger.cpp +++ b/modules/juce_core/maths/juce_BigInteger.cpp @@ -385,12 +385,12 @@ BigInteger& BigInteger::operator+= (const BigInteger& other) BigInteger temp (*this); temp.negate(); *this = other; - operator-= (temp); + *this -= temp; } else { negate(); - operator-= (other); + *this -= other; negate(); } } @@ -436,7 +436,7 @@ BigInteger& BigInteger::operator-= (const BigInteger& other) { BigInteger temp (other); swapWith (temp); - operator-= (temp); + *this -= temp; negate(); return *this; } @@ -444,7 +444,7 @@ BigInteger& BigInteger::operator-= (const BigInteger& other) else { negate(); - operator+= (other); + *this += other; negate(); return *this; } @@ -476,24 +476,40 @@ BigInteger& BigInteger::operator-= (const BigInteger& other) BigInteger& BigInteger::operator*= (const BigInteger& other) { - BigInteger total; - highestBit = getHighestBit(); + int n = getHighestBit(); + int t = other.getHighestBit(); + const bool wasNegative = isNegative(); setNegative (false); - for (int i = 0; i <= highestBit; ++i) + BigInteger total; + total.highestBit = n + t + 1; + + n >>= 5; + t >>= 5; + + total.ensureSize (n + t + 2); + + BigInteger m (other); + m.setNegative (false); + + for (int i = 0; i <= t; ++i) { - if (operator[](i)) + uint32 c = 0; + + for (int j = 0; j <= n; ++j) { - BigInteger n (other); - n.setNegative (false); - n <<= i; - total += n; + uint64 uv = (uint64) total.values[i + j] + (uint64) values[j] * (uint64) m.values[i] + (uint64) c; + total.values[i + j] = (uint32) uv; + c = uv >> 32; } + + total.values[i + n + 1] = c; } total.setNegative (wasNegative ^ other.isNegative()); swapWith (total); + return *this; } @@ -683,7 +699,7 @@ void BigInteger::shiftLeft (int bits, const int startBit) if (startBit > 0) { for (int i = highestBit + 1; --i >= startBit;) - setBit (i + bits, operator[] (i)); + setBit (i + bits, (*this) [i]); while (--bits >= 0) clearBit (bits + startBit); @@ -726,7 +742,7 @@ void BigInteger::shiftRight (int bits, const int startBit) if (startBit > 0) { for (int i = startBit; i <= highestBit; ++i) - setBit (i, operator[] (i + bits)); + setBit (i, (*this) [i + bits]); highestBit = getHighestBit(); } @@ -816,25 +832,128 @@ BigInteger BigInteger::findGreatestCommonDivisor (BigInteger n) const void BigInteger::exponentModulo (const BigInteger& exponent, const BigInteger& modulus) { + *this %= modulus; BigInteger exp (exponent); exp %= modulus; - BigInteger value (1); - swapWith (value); - value %= modulus; + if (modulus.getHighestBit() <= 32 || modulus % 2 == 0) + { + BigInteger a (*this); + + const int n = exp.getHighestBit(); + + for (int i = n; --i >= 0;) + { + *this *= *this; - while (! exp.isZero()) + if (exp[i]) + *this *= a; + + if (compareAbsolute (modulus) >= 0) + *this %= modulus; + } + } + else { - if (exp [0]) + const int Rfactor = modulus.getHighestBit() + 1; + BigInteger R (1); + R.shiftLeft (Rfactor, 0); + + BigInteger R1, m1, g; + g.extendedEuclidean (modulus, R, m1, R1); + + if (! g.isOne()) { - operator*= (value); - operator%= (modulus); + BigInteger a (*this); + + for (int i = exp.getHighestBit(); --i >= 0;) + { + *this *= *this; + + if (exp[i]) + *this *= a; + + if (compareAbsolute (modulus) >= 0) + *this %= modulus; + } } + else + { + BigInteger am (((*this) * R) % modulus); + BigInteger xm (am); + BigInteger um (R % modulus); - value *= value; - value %= modulus; - exp >>= 1; + for (int i = exp.getHighestBit(); --i >= 0;) + { + xm.montgomeryMultiplication (xm, modulus, m1, Rfactor); + + if (exp[i]) + xm.montgomeryMultiplication (am, modulus, m1, Rfactor); + } + + xm.montgomeryMultiplication (1, modulus, m1, Rfactor); + swapWith (xm); + } + } +} + +void BigInteger::montgomeryMultiplication (const BigInteger& other, const BigInteger& modulus, + const BigInteger& modulusp, const int k) +{ + *this *= other; + + BigInteger t (*this); + + setRange (k, highestBit - k + 1, false); + *this *= modulusp; + + setRange (k, highestBit - k + 1, false); + *this *= modulus; + *this += t; + shiftRight (k, 0); + + if (compare (modulus) >= 0) + *this -= modulus; + else if (isNegative()) + *this += modulus; +} + +void BigInteger::extendedEuclidean (const BigInteger& a, const BigInteger& b, + BigInteger& x, BigInteger& y) +{ + BigInteger p(a), q(b), gcd(1); + + Array tempValues; + + while (! q.isZero()) + { + tempValues.add (p / q); + gcd = q; + q = p % q; + p = gcd; + } + + x.clear(); + y = 1; + + for (int i = 1; i < tempValues.size(); ++i) + { + const BigInteger& v = tempValues.getReference (tempValues.size() - i - 1); + + if ((i & 1) != 0) + x += y * v; + else + y += x * v; } + + if (gcd.compareAbsolute (y * b - x * a) != 0) + { + x.negate(); + x.swapWith (y); + x.negate(); + } + + swapWith (gcd); } void BigInteger::inverseModulo (const BigInteger& modulus) @@ -846,7 +965,7 @@ void BigInteger::inverseModulo (const BigInteger& modulus) } if (isNegative() || compareAbsolute (modulus) >= 0) - operator%= (modulus); + *this %= modulus; if (isOne()) return; @@ -959,8 +1078,8 @@ void BigInteger::parseString (StringRef text, const int base) if (((uint32) digit) < (uint32) base) { - operator<<= (bits); - operator+= (digit); + *this <<= bits; + *this += digit; } else if (c == 0) { @@ -978,8 +1097,8 @@ void BigInteger::parseString (StringRef text, const int base) if (c >= '0' && c <= '9') { - operator*= (ten); - operator+= ((int) (c - '0')); + *this *= ten; + *this += (int) (c - '0'); } else if (c == 0) { diff --git a/modules/juce_core/maths/juce_BigInteger.h b/modules/juce_core/maths/juce_BigInteger.h index 2f99424a3b..ada0e89e5b 100644 --- a/modules/juce_core/maths/juce_BigInteger.h +++ b/modules/juce_core/maths/juce_BigInteger.h @@ -180,6 +180,22 @@ public: */ int getHighestBit() const noexcept; + //============================================================================== + /** Returns true if the value is less than zero. + @see setNegative, negate + */ + bool isNegative() const noexcept; + + /** Changes the sign of the number to be positive or negative. + @see isNegative, negate + */ + void setNegative (bool shouldBeNegative) noexcept; + + /** Inverts the sign of the number. + @see isNegative, setNegative + */ + void negate() noexcept; + //============================================================================== // All the standard arithmetic ops... @@ -236,6 +252,7 @@ public: */ int compareAbsolute (const BigInteger& other) const noexcept; + //============================================================================== /** Divides this value by another one and returns the remainder. This number is divided by other, leaving the quotient in this number, @@ -243,7 +260,7 @@ public: */ void divideBy (const BigInteger& divisor, BigInteger& remainder); - /** Returns the largest value that will divide both this value and the one passed-in. */ + /** Returns the largest value that will divide both this value and the argument. */ BigInteger findGreatestCommonDivisor (BigInteger other) const; /** Performs a combined exponent and modulo operation. @@ -256,21 +273,20 @@ public: */ void inverseModulo (const BigInteger& modulus); - //============================================================================== - /** Returns true if the value is less than zero. - @see setNegative, negate - */ - bool isNegative() const noexcept; - - /** Changes the sign of the number to be positive or negative. - @see isNegative, negate + /** Performs the Montgomery Multiplication with modulo. + This object is left containing the result value: ((this * other) * R1) % modulus. + To get this result, we need modulus, modulusp and k such as R = 2^k, with + modulus * modulusp - R * R1 = GCD(modulus, R) = 1 */ - void setNegative (bool shouldBeNegative) noexcept; + void montgomeryMultiplication (const BigInteger& other, const BigInteger& modulus, + const BigInteger& modulusp, int k); - /** Inverts the sign of the number. - @see isNegative, setNegative + /** Performs the Extended Euclidean algorithm. + This method will set the xOut and yOut arguments such that (a * xOut) - (b * yOut) = GCD (a, b). + On return, this object is left containing the value of the GCD. */ - void negate() noexcept; + void extendedEuclidean (const BigInteger& a, const BigInteger& b, + BigInteger& xOut, BigInteger& yOut); //============================================================================== /** Converts the number to a string.