Efficiently finding primitive roots modulo n using Python?

Question:

I’m using the following code for finding primitive roots modulo n in Python:

Code:

def gcd(a,b):
    while b != 0:
        a, b = b, a % b
    return a

def primRoots(modulo):
    roots = []
    required_set = set(num for num in range (1, modulo) if gcd(num, modulo) == 1)

    for g in range(1, modulo):
        actual_set = set(pow(g, powers) % modulo for powers in range (1, modulo))
        if required_set == actual_set:
            roots.append(g)           
    return roots

if __name__ == "__main__":
    p = 17
    primitive_roots = primRoots(p)
    print(primitive_roots)

Output:

[3, 5, 6, 7, 10, 11, 12, 14]   

Code fragment extracted from: Diffie-Hellman (GitHub)


Can the primRoots method be simplified or optimized in terms of memory usage and performance/efficiency?

Asked By: Erba Aitbayev

||

Answers:

One quick change that you can make here (not yet optimal in terms of efficiency) is to use list and set comprehensions:

def primRoots(modulo):
    coprime_set = {num for num in range(1, modulo) if gcd(num, modulo) == 1}
    return [g for g in range(1, modulo) if coprime_set == {pow(g, powers, modulo)
            for powers in range(1, modulo)}]

Now, one powerful and interesting algorithmic change that you can make here is to optimize your gcd function using memoization. Or, even better, you can simply use the built-in gcd function from the math module in Python 3.5+, or from the fractions module in earlier versions:

from functools import wraps
def cache_gcd(f):
    cache = {}

    @wraps(f)
    def wrapped(a, b):
        key = (a, b)
        try:
            result = cache[key]
        except KeyError:
            result = cache[key] = f(a, b)
        return result
    return wrapped

@cache_gcd
def gcd(a,b):
    while b != 0:
        a, b = b, a % b
    return a
# or just do the following (recommended)
# from math import gcd

Then:

def primRoots(modulo):
    coprime_set = {num for num in range(1, modulo) if gcd(num, modulo) == 1}
    return [g for g in range(1, modulo) if coprime_set == {pow(g, powers, modulo)
            for powers in range(1, modulo)}]

As mentioned in the comments, a more Pythonic and faster option is to use fractions.gcd (or, for Python 3.5+, math.gcd).
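
The hand-written cache decorator can also be expressed with functools.lru_cache from the standard library. This is only a minimal sketch. Note that within a single primRoots call each (num, modulo) pair is queried just once, so a cache mainly pays off across repeated calls with the same modulus:

```python
from functools import lru_cache

@lru_cache(maxsize=None)  # unbounded cache keyed on the (a, b) arguments
def gcd(a, b):
    while b != 0:
        a, b = b, a % b
    return a

gcd(12, 18)              # computed
gcd(12, 18)              # served from the cache
print(gcd.cache_info())  # reports hits=1, misses=1 after the two calls above
```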

Answered By: Mazdak

Based on the comment of Pete and answer of Kasramvd, I can suggest this:

from math import gcd as bltin_gcd

def primRoots(modulo):
    required_set = {num for num in range(1, modulo) if bltin_gcd(num, modulo) == 1}
    return [g for g in range(1, modulo) if required_set == {pow(g, powers, modulo)
            for powers in range(1, modulo)}]

print(primRoots(17))

Output:

[3, 5, 6, 7, 10, 11, 12, 14]

Changes:

  • It now uses the third argument of pow for the modulus.
  • It switches to the built-in gcd function defined in math (Python 3.5+) for a speed boost.

Additional info about built-in gcd is here: Co-primes checking
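
For a composite modulus, the comparison only works if the reference set holds exactly the residues coprime to the modulus. A small sketch of that general case (the name prim_roots is used here only for illustration):

```python
from math import gcd

def prim_roots(modulo):
    # Residues in 1..modulo-1 that are coprime to the modulus.
    coprime = {a for a in range(1, modulo) if gcd(a, modulo) == 1}
    # g is a primitive root when its powers hit every coprime residue.
    return [g for g in range(1, modulo)
            if coprime == {pow(g, k, modulo) for k in range(1, modulo)}]

print(prim_roots(9))   # [2, 5]
print(prim_roots(8))   # [] (8 has no primitive roots)
```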

Answered By: Erba Aitbayev

In the special case that p is prime, the following is a good bit faster:

import sys

# translated to Python from http://www.bluetulip.org/2014/programs/primitive.js
# (some rights may remain with the author of the above javascript code)

def isNotPrime(possible):
    # We only test this here to protect people who copy and paste
    # the code without reading the first sentence of the answer.
    # In an application where you know the numbers are prime you
    # will remove this function (and the call). If you need to
    # test for primality, look for a more efficient algorithm, see
    # for example Joseph F's answer on this page.
    i = 2
    while i*i <= possible:
        if (possible % i) == 0:
            return True
        i = i + 1
    return False

def primRoots(theNum):
    if isNotPrime(theNum):
        raise ValueError("Sorry, the number must be prime.")
    o = 1
    roots = []
    r = 2
    while r < theNum:
        k = pow(r, o, theNum)
        while (k > 1):
            o = o + 1
            k = (k * r) % theNum
        if o == (theNum - 1):
            roots.append(r)
        o = 1
        r = r + 1
    return roots

print(primRoots(int(sys.argv[1])))

Answered By: Joachim Wagner

You can greatly improve your isNotPrime function by using a more efficient algorithm. You could double the speed by doing a special test for even numbers and then testing only odd numbers up to the square root, but this is still very inefficient compared to an algorithm such as the Miller-Rabin test. The version on the Rosetta Code site will always give the correct answer for any number with fewer than 25 digits or so. For large primes, it runs in a tiny fraction of the time that trial division takes.
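
As a rough illustration (this is not the Rosetta Code version), a Miller-Rabin test with fixed bases could look like the sketch below. It uses the first thirteen primes as bases, which is reported to give a deterministic answer for n below about 3.3 × 10^24; treat that bound as an assumption to verify before relying on it. The name is_prime is hypothetical:

```python
def is_prime(n):
    """Miller-Rabin with fixed bases (the first 13 primes)."""
    bases = (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41)
    if n < 2:
        return False
    # Cheap trial division by the bases themselves.
    for p in bases:
        if n % p == 0:
            return n == p
    # Write n - 1 as d * 2**s with d odd.
    d, s = n - 1, 0
    while d % 2 == 0:
        d //= 2
        s += 1
    for a in bases:
        x = pow(a, d, n)
        if x == 1 or x == n - 1:
            continue
        for _ in range(s - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            return False  # base a witnesses that n is composite
    return True

print(is_prime(20963))   # True
print(is_prime(1763))    # False (41 * 43)
```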

Also, you should avoid floating-point exponentiation (for example, n ** 0.5 to get a square root) when you are dealing with integers, as in this case (even though the Rosetta Code version that I just linked to does the same thing!). It might work fine in a particular case, but it can be a subtle source of error when Python has to convert between floating point and integers, or when an integer is too large to be represented exactly in floating point. There are efficient integer square root algorithms that you can use instead. Here’s a simple one:

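def int_sqrt(n):
    # Integer Newton iteration: returns floor(sqrt(n)) for n >= 0.
    if n == 0:
        return 0
    x = n
    y = (x + n // x) // 2

    # The estimates decrease until they reach the integer square root.
    while y < x:
        x = y
        y = (x + n // x) // 2

    return x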
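
On Python 3.8 and later, the standard library also provides math.isqrt, which returns the exact integer square root without going through floating point. A quick sketch:

```python
import math

n = 10**30 + 1
print(math.isqrt(n))   # 1000000000000000, the exact floor of the square root
print(int(n ** 0.5))   # goes through a float, so it may be inexact for large n
```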

Answered By: Joseph F

Those solutions are all inefficient in several ways. First of all, you do not need to iterate over all coprime remainders of n. The order of g always divides the value of Euler’s function of n, so you only need to check the powers that are divisors of that value. If n is prime, Euler’s function of n is n-1, so you factorize n-1 and check only those divisors, not every power. There is simple mathematics behind this.
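
For a prime modulus p, this idea can be sketched as follows. It is in fact enough to test the exponents (p-1)/q for each prime factor q of p-1: g is a primitive root exactly when none of those powers is 1 modulo p. The helper names are illustrative, and trial division is used only to keep the sketch short:

```python
def prime_factors(m):
    """Distinct prime factors of m, found by trial division."""
    factors = set()
    q = 2
    while q * q <= m:
        while m % q == 0:
            factors.add(q)
            m //= q
        q += 1
    if m > 1:
        factors.add(m)
    return factors

def is_primitive_root(g, p):
    """True if g generates the multiplicative group modulo the prime p."""
    if g % p == 0:
        return False
    return all(pow(g, (p - 1) // q, p) != 1 for q in prime_factors(p - 1))

print([g for g in range(1, 17) if is_primitive_root(g, 17)])
# [3, 5, 6, 7, 10, 11, 12, 14]
```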

Second, you need a better function for exponentiation, because the power may be very large. Python has the built-in pow(g, powers, modulo), which takes the remainder ( _ % modulo ) at each step, so the intermediate values never grow large.
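
To illustrate what reducing at each step means, here is a square-and-multiply sketch. The name mod_pow is hypothetical and the code is only for explanation; in practice the built-in three-argument pow is the one to use:

```python
def mod_pow(base, exponent, modulus):
    """Square-and-multiply, reducing modulo `modulus` after every step."""
    result = 1 % modulus
    base %= modulus
    while exponent > 0:
        if exponent & 1:                    # current binary digit is 1
            result = (result * base) % modulus
        base = (base * base) % modulus      # square for the next digit
        exponent >>= 1
    return result

# Agrees with the built-in for a sample input.
assert mod_pow(3, 200, 17) == pow(3, 200, 17)
```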

If you are going to implement the Diffie-Hellman algorithm, it is better to use safe primes. If p is a prime and 2p+1 is also prime, then p is called a Sophie Germain prime and 2p+1 is called a safe prime. If you take n = 2*p+1, then the divisors of n-1 (n is prime, so Euler’s function of n is n-1) are 1, 2, p and 2p. You only need to check g to the power 2 and g to the power p: if either of them gives 1 modulo n, then g is not a primitive root, and you can throw that g away and try the next one, g+1. If g^2 and g^p are both different from 1 modulo n, then g is a primitive root. That check guarantees that every power of g with an exponent smaller than 2p is different from 1 modulo n.

The example code uses a Sophie Germain prime p and the corresponding safe prime 2p+1, and calculates the primitive roots of that safe prime.

You can easily rework the code for any prime number, or any other number, by adding a function to calculate Euler’s function and to find all divisors of that value. But this is only a demo, not complete code, and there might be better ways.

class SGPrime:
    '''
    Expects a Sophie Germain prime p; it does not check that the input actually is one.
    Euler's function of a prime n is n-1, and the order (see method get_order) of any
    remainder coprime to n can only be a divisor of that value.
    '''
    def __init__(self, pSophieGermain):
        self.n = 2*pSophieGermain + 1
        # TODO: check that pSophieGermain is prime.
        # TODO: check that n is also prime.
        # Both have to be prime, otherwise the code does not work!

        # Euler's function is n-1. TODO: for any n, calculate Euler's function of n.
        self.elrfunc = self.n - 1

        # All divisors of Euler's function value, in ascending order.
        # TODO: for any n, get all divisors of the Euler's function value.
        self.elrfunc_divisors = [1, 2, pSophieGermain, self.elrfunc]

    def get_order(self, r):
        '''
        Calculate the order of r: the minimal power at which r is congruent to 1 modulo n.
        '''
        r = r % self.n
        for d in self.elrfunc_divisors:
            if pow(r, d, self.n) == 1:
                return d
        return 0  # no such order; not possible if n is prime and r is coprime to n (Fermat's little theorem)

    def is_primitive_root(self, r):
        '''
        Check whether r is a primitive root modulo n. One always exists if n is prime.
        '''
        return self.get_order(r) == self.elrfunc

    def find_all_primitive_roots(self, max_num_of_roots=None):
        '''
        Find all primitive roots. For demo only: if n is large the list is large, so for DH
        or any similar algorithm it is better to stop at the first primitive root.
        '''
        primitive_roots = []
        for g in range(1, self.n):
            if self.is_primitive_root(g):
                primitive_roots.append(g)
                if max_num_of_roots is not None and len(primitive_roots) >= max_num_of_roots:
                    break
        return primitive_roots

#demo, Sophie Germain's prime
p =  20963
sggen = SGPrime(p)
print (f"Safe prime : {sggen.n}, and primitive roots of {sggen.n} are : " )
print(sggen.find_all_primitive_roots())
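
The generalization mentioned above (Euler's function plus its divisors, for any n) might look like the following sketch. All function names are illustrative, and trial division limits it to fairly small n:

```python
from math import gcd

def euler_phi(n):
    """Euler's function of n, via trial-division factorization."""
    result, m, q = n, n, 2
    while q * q <= m:
        if m % q == 0:
            while m % q == 0:
                m //= q
            result -= result // q
        q += 1
    if m > 1:
        result -= result // m
    return result

def divisors(m):
    """All positive divisors of m in ascending order."""
    small, large = [], []
    d = 1
    while d * d <= m:
        if m % d == 0:
            small.append(d)
            if d != m // d:
                large.append(m // d)
        d += 1
    return small + large[::-1]

def multiplicative_order(g, n):
    """Smallest d with g**d == 1 (mod n); assumes gcd(g, n) == 1."""
    for d in divisors(euler_phi(n)):
        if pow(g, d, n) == 1:
            return d

def primitive_roots(n):
    phi = euler_phi(n)
    return [g for g in range(1, n)
            if gcd(g, n) == 1 and multiplicative_order(g, n) == phi]

print(primitive_roots(9))    # [2, 5]
print(primitive_roots(17))   # [3, 5, 6, 7, 10, 11, 12, 14]
```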

Regards

Answered By: Ognyan Gerassimov