Sieve Of Atkin Implementation in Python

Question:

I am trying to implement the Sieve of Atkin algorithm as described on Wikipedia:

Sieve Of Atkin

What I’ve tried so far is the following Python implementation:

import math
is_prime = list()
limit = 100
for i in range(5,limit):
    is_prime.append(False)

for x in range(1,int(math.sqrt(limit))+1):
    for y in range(1,int(math.sqrt(limit))+1):
        n = 4*x**2 + y**2

        if n<=limit and (n%12==1 or n%12==5):
            # print "1st if"
            is_prime[n] = not is_prime[n]
        n = 3*x**2+y**2
        if n<= limit and n%12==7:
            # print "Second if"
            is_prime[n] = not is_prime[n]
        n = 3*x**2 - y**2
        if x>y and n<=limit and n%12==11:
            # print "third if"
            is_prime[n] = not is_prime[n]

for n in range(5,int(math.sqrt(limit))):
    if is_prime[n]:
        for k in range(n**2,limit+1,n**2):
            is_prime[k] = False
print 2,3
for n in range(5,limit):
    if is_prime[n]: print n

Now I get this error:

is_prime[n] = not is_prime[n]
IndexError: list index out of range

This means that I am accessing the list at an index greater than its length. Consider the case when x, y = 100: then of course n = 4x^2 + y^2 will give a value greater than the length of the list. Am I doing something wrong here? Please help.

EDIT 1
As suggested by Gabe, using

is_prime = [False] * (limit + 1)

instead of:

for i in range(5,limit):
    is_prime.append(False)

solved the problem.

Asked By: Mahadeva


Answers:

Your problem is that your limit is 100, but your is_prime list only has limit-5 elements in it because it is initialized with range(5, limit).

Since this code assumes it can access indices up to limit, the list needs limit+1 elements: is_prime = [False] * (limit + 1)

Note that it doesn’t matter that 4x^2+y^2 is greater than limit because it always checks n <= limit.
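To make the size mismatch concrete, here is a small sketch. The names short_list and full_list are only for illustration and do not appear in the question; the value n = 97 comes from x = 2, y = 9 in the first formula.

```python
limit = 100

# Built the way the question does it: one entry per value in range(5, limit).
short_list = []
for i in range(5, limit):
    short_list.append(False)

# Built the way suggested above: every index from 0 to limit is valid.
full_list = [False] * (limit + 1)

print(len(short_list))  # 95, so the valid indices are 0..94
print(len(full_list))   # 101, so the valid indices are 0..100

# x = 2, y = 9 gives n = 4*x**2 + y**2 = 97, which passes the n <= limit check.
n = 4 * 2**2 + 9**2
try:
    short_list[n]
except IndexError as exc:
    print("short_list:", exc)
print("full_list:", full_list[n])
```

The list is indexed by n itself, not by n - 5, which is why it has to cover the whole range 0..limit.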

Answered By: Gabe

Here is a solution

import math

def sieveOfAtkin(limit):
    P = [2,3]
    sieve=[False]*(limit+1)
    for x in range(1,int(math.sqrt(limit))+1):
        for y in range(1,int(math.sqrt(limit))+1):
            n = 4*x**2 + y**2
            if n<=limit and (n%12==1 or n%12==5) : sieve[n] = not sieve[n]
            n = 3*x**2+y**2
            if n<= limit and n%12==7 : sieve[n] = not sieve[n]
            n = 3*x**2 - y**2
            if x>y and n<=limit and n%12==11 : sieve[n] = not sieve[n]
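    # +1 so that x = floor(sqrt(limit)) is included (otherwise e.g. 25 survives for limit = 26)
    for x in range(5,int(math.sqrt(limit))+1):
        if sieve[x]:
            for y in range(x**2,limit+1,x**2):
                sieve[y] = False
    # +1 so that limit itself is reported when it is prime
    for p in range(5,limit+1):
        if sieve[p] : P.append(p)
    return P

print(sieveOfAtkin(100))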
Answered By: Zsolt KOVACS

This is an optimized version of the implementation proposed by Zsolt KOVACS:

    import math
    import sys
    
    def sieveOfAtkin(limit):
        P = [2,3]
        r = range(1,int(math.sqrt(limit))+1)
        sieve=[False]*(limit+1)
        for x in r:
            for y in r:
                xx=x*x
                yy=y*y
                xx3 = 3*xx
                n = 4*xx + yy
                if n<=limit and (n%12==1 or n%12==5) : sieve[n] = not sieve[n]
                n = xx3 + yy
                if n<=limit and n%12==7 : sieve[n] = not sieve[n]
                n = xx3 - yy
                if x>y and n<=limit and n%12==11 : sieve[n] = not sieve[n]
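        # +1 so that x = floor(sqrt(limit)) is included
        for x in range(5,int(math.sqrt(limit))+1):
            if sieve[x]:
                xx=x*x
                for y in range(xx,limit+1,xx):
                    sieve[y] = False
        # +1 so that limit itself is reported when it is prime
        for p in range(5,limit+1):
            if sieve[p] : P.append(p)
        return P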
    
    primes = sieveOfAtkin(int(sys.argv[1]))    
    print (primes)

You pass the upper limit as the first command-line argument. For a limit of 10 million, this program runs in about 6 s on my machine, compared to 21 s for the original. What I did:

  • replaced exponentiation with multiplication
  • precalculated some multiplications
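
A rough way to check the first point on your own machine is a timeit comparison. This is only a sketch: the helper names squares_pow and squares_mul are made up for illustration, and the measured times will differ between machines and Python versions.

```python
import timeit

def squares_pow(values):
    # Square each value with the exponentiation operator.
    return [v ** 2 for v in values]

def squares_mul(values):
    # Square each value with a plain multiplication.
    return [v * v for v in values]

values = range(1, 1000)

# Both variants must produce the same numbers.
assert squares_pow(values) == squares_mul(values)

t_pow = timeit.timeit(lambda: squares_pow(values), number=200)
t_mul = timeit.timeit(lambda: squares_mul(values), number=200)
print(f"v ** 2: {t_pow:.4f} s, v * v: {t_mul:.4f} s")
```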
Answered By: bergee

Thanks for a very interesting question!

Because the bugs in your code are already fixed by other answers, I decided to implement from scratch my own heavily optimized versions of the Sieve of Atkin and also the Sieve of Eratosthenes (for comparison).

Of the 4 functions that I implemented, the best one is 193 times faster than your original code. On my slow laptop, a limit of 10 million takes 50 seconds with your code, while the same limit takes just 0.26 seconds with my function.

The best speedup in my code is achieved with the help of the Numba and NumPy packages. They are used only to speed up the code (through compilation); I didn’t use any scientific functions from these packages.

First I’ll show the timings. This is the console output you get if you run my code, located at the end of this post.

Limit 10_000_000
SieveOfAtkin_Mahadeva           : time  50.513 sec, boost    1.00x
SieveOfAtkin_bergee             : time  13.016 sec, boost    3.88x
SieveOfEratosthenes_Arty_Python : time   5.768 sec, boost    8.76x
SieveOfAtkin_Arty_Python        : time   3.632 sec, boost   13.91x
SieveOfEratosthenes_Arty_Numba  : time   0.445 sec, boost  113.51x
SieveOfAtkin_Arty_Numba         : time   0.261 sec, boost  193.54x

Besides my 4 functions, I used the questioner’s original code (@Mahadeva) and the fastest code from the other answers (@bergee).

I wrote 2 versions of the Atkin function (the first in pure Python, the second using Numba and NumPy) and 2 versions of the Sieve of Eratosthenes (likewise, one in pure Python, another with Numba/NumPy).

I also made the following optimizations, most of which you can find in the pseudocode section of the Wikipedia article on the Sieve of Atkin:

  1. Instead of running both the X and Y loops up to Sqrt(limit) with step 1, according to Wikipedia you can use a step of 2 in half of the loops.

  2. With a little math it is easy to see that the X loops don’t need to run up to Sqrt(limit): the first X loop can stop at Sqrt(limit / 4), the second at Sqrt(limit / 3), and the third at Sqrt(limit / 2). These bounds follow from solving n <= limit for x in the terms 4 * x * x and 3 * x * x (in the third loop y < x, so n = 3 * x * x - y * y > 2 * x * x).

  3. Wikipedia also says that it is not necessary to process ALL values with n % 12 == 1 or n % 12 == 5, n % 12 == 7, and n % 12 == 11, but only a subset of remainders modulo 60 (16 of the 20 candidate remainders; the multiples of 5 are dropped). My function SieveOfAtkin_Arty_Numba() (and Wikipedia too) shows which remainders to use.

  4. Instead of keeping an array is_prime[] of bool values or byte values (in the case of NumPy), it is enough to keep an array of bits. Compared to a byte array, this reduces memory usage 8 times. This not only speeds up the computation through better use of the CPU cache, but also lets you compute many more primes when memory is limited. The two Numba versions use bit arithmetic to work with the bit array.

  5. Numba compilation does most of the optimization work. It converts the code to LLVM Intermediate Representation, a low-level assembly-like form, which LLVM then compiles to machine code similar in speed to optimized C/C++. In effect, with Numba the code becomes about as fast as if you had written it in C/C++, while it is still Python code, automatically optimized by Numba.
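
Points 3 and 4 of the list above can be sketched in plain Python without Numba. This is only an illustration: the helper names (residues_mod_60, make_bits, get_bit, flip_bit, clear_bit) are made up here, and a bytearray stands in for the NumPy uint8 array used in the Numba functions.

```python
def residues_mod_60(mod12_values):
    """Remainders r modulo 60 with r % 12 in mod12_values and r not divisible by 5."""
    return {r for r in range(60) if r % 12 in mod12_values and r % 5 != 0}

# These reproduce the three sets hard-coded in SieveOfAtkin_Arty_Numba().
set0 = residues_mod_60({1, 5})
set1 = residues_mod_60({7})
set2 = residues_mod_60({11})
print(sorted(set0))  # [1, 13, 17, 29, 37, 41, 49, 53]
print(sorted(set1))  # [7, 19, 31, 43]
print(sorted(set2))  # [11, 23, 47, 59]

def make_bits(count):
    # One bit per number, packed 8 per byte and rounded up to whole bytes.
    return bytearray((count + 7) // 8)

def get_bit(bits, n):
    # Read bit number n: byte n // 8, position n % 8 inside that byte.
    return (bits[n // 8] >> (n % 8)) & 1

def flip_bit(bits, n):
    # Toggle bit n, the packed equivalent of is_prime[n] = not is_prime[n].
    bits[n // 8] ^= 1 << (n % 8)

def clear_bit(bits, n):
    # Reset bit n to 0; the mask is limited to 8 bits so it fits in a byte.
    bits[n // 8] &= ~(1 << (n % 8)) & 0xFF

bits = make_bits(101)                # enough bits for the indices 0..100
flip_bit(bits, 97)
print(len(bits), get_bit(bits, 97))  # 13 1
clear_bit(bits, 97)
print(get_bit(bits, 97))             # 0
```

The 101 flags fit into 13 bytes here, whereas a byte-per-flag array would need 101 bytes.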

Take a look at the function SieveOfAtkin_Arty_Python(): this is the one to study if you want to understand my code. It is pure Python (without Numba), but 13.9 times faster than your original code and 3.58 times faster than the best other answer, by @bergee.

If you don’t want the heavy Numba dependency in your projects, copy the SieveOfAtkin_Arty_Python() function; it is the best of the pure Python solutions.

Before running the code, install the packages once with python -m pip install numba numpy -U. If you don’t like Numba, remove the Numba and NumPy imports from the first line of my code and also delete the two functions SieveOfEratosthenes_Arty_Numba() and SieveOfAtkin_Arty_Numba(). Note that Test() also uses NumPy to compare the results, so keep the NumPy import if you still want to run it.
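
If you would rather keep one script that runs with or without Numba, one possible approach is an optional import with a do-nothing fallback decorator. This is a sketch under assumptions, not part of the code below: the name njit and the example function count_odds are made up for illustration, and the cache option used in the full code is left out for simplicity.

```python
try:
    import numba as nb
    # Use Numba's JIT compiler as the decorator when the package is available.
    njit = nb.njit
except ImportError:
    def njit(func):
        # Fallback: return the function unchanged, so it runs as plain Python.
        return func

@njit
def count_odds(limit):
    # Tiny stand-in for a sieve function: count the odd numbers below limit.
    total = 0
    for i in range(limit):
        if i % 2 == 1:
            total += 1
    return total

print(count_odds(10))  # 5
```

With the fallback the decorated functions still work, only without the compiled speedup.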


import numba as nb, numpy as np, math, time

def SieveOfAtkin_Arty_Python(limit):
    # https://en.wikipedia.org/wiki/Sieve_of_Atkin
    end = limit + 1
    sqrt_end = int(end ** 0.5 + 2.01)
    primes = [2, 3]
    is_prime = [False] * end
    for x in range(1, int((end / 4) ** 0.5 + 2.01)):
        xx4 = 4 * x * x
        for y in range(1, sqrt_end, 2):
            n = xx4 + y * y
            if n >= end:
                break
            if n % 12 == 1 or n % 12 == 5:
                is_prime[n] = not is_prime[n]
    for x in range(1, int((end / 3) ** 0.5 + 2.01), 2):
        xx3 = 3 * x * x
        for y in range(2, sqrt_end, 2):
            n = xx3 + y * y
            if n >= end:
                break
            if n % 12 == 7:
                is_prime[n] = not is_prime[n]
    for x in range(2, int((end / 2) ** 0.5 + 2.01)):
        xx3 = 3 * x * x
        for y in range(x - 1, 0, -2):
            n = xx3 - y * y
            if n >= end:
                break
            if n % 12 == 11:
                is_prime[n] = not is_prime[n]
    for x in range(5, sqrt_end):
        if is_prime[x]:
            for y in range(x * x, end, x * x):
                is_prime[y] = False
    for p in range(5, end, 2):
        if is_prime[p]:
            primes.append(p)
    return primes

def SieveOfEratosthenes_Arty_Python(limit):
    # https://en.wikipedia.org/wiki/Sieve_of_Eratosthenes
    end = limit + 1
    composite = [False] * end
    for i in range(3, int(end ** 0.5 + 2.01)):
        if not composite[i]:
            for j in range(i * i, end, i):
                composite[j] = True
    return [2] + [i for i in range(3, end, 2) if not composite[i]]

@nb.njit(cache = True)
def SieveOfEratosthenes_Arty_Numba(limit):
    # https://en.wikipedia.org/wiki/Sieve_of_Eratosthenes
    end = limit + 1
    composite = np.zeros(((end + 7) // 8,), dtype = np.uint8)
    for i in range(3, int(end ** 0.5 + 2.01)):
        if not (composite[i // 8] & (1 << (i % 8))):
            for j in range(i * i, end, i):
                composite[j // 8] |= 1 << (j % 8)
    return np.array([2] + [i for i in range(3, end, 2)
        if not (composite[i // 8] & (1 << (i % 8)))], dtype = np.uint32)

@nb.njit(cache = True)
def SieveOfAtkin_Arty_Numba(limit):
    # https://en.wikipedia.org/wiki/Sieve_of_Atkin
    # https://github.com/mccricardo/sieve_of_atkin/blob/master/sieve_of_atkin.py
    # https://stackoverflow.com/questions/21783160/
    end = limit + 1
    is_prime = np.zeros(((end + 7) // 8,), dtype = np.uint8)
    # Subset of n % 12 == 1 or n % 12 == 5
    set0 = np.array([int(i in {1, 13, 17, 29, 37, 41, 49, 53}) for i in range(60)], dtype = np.uint8)
    # Subset of n % 12 == 7
    set1 = np.array([int(i in {7, 19, 31, 43}) for i in range(60)], dtype = np.uint8)
    # Subset of n % 12 == 11
    set2 = np.array([int(i in {11, 23, 47, 59}) for i in range(60)], dtype = np.uint8)
    sqrt_end = int(math.sqrt(end) + 1.01)
    
    for x in range(1, int(sqrt_end / math.sqrt(4) + 2.01)):
        xx4 = 4 * x * x
        for y in range(1, sqrt_end, 2):
            n = xx4 + y * y
            if n >= end:
                break
            if set0[n % 60]:
                is_prime[n // 8] ^= 1 << (n % 8)
    for x in range(1, int(sqrt_end / math.sqrt(3) + 2.01), 2):
        xx3 = 3 * x * x
        for y in range(2, sqrt_end, 2):
            n = xx3 + y * y
            if n >= end:
                break
            if set1[n % 60]:
                is_prime[n // 8] ^= 1 << (n % 8)
    for x in range(2, int(sqrt_end / math.sqrt(2) + 2.01)):
        xx3 = 3 * x * x
        for y in range(x - 1, 0, -2):
            n = xx3 - y * y
            if n >= end:
                break
            if set2[n % 60]:
                is_prime[n // 8] ^= 1 << (n % 8)
    
    for n in range(7, sqrt_end):
        if is_prime[n // 8] & (1 << (n % 8)):
            for k in range(n * n, end, n * n):
                is_prime[k // 8] &= ~np.uint8(1 << (k % 8))
    
    return np.array([2, 3, 5] + [n for n in range(7, end, 2)
        if is_prime[n // 8] & (1 << (n % 8))], dtype = np.uint32)
        
def SieveOfAtkin_Mahadeva(limit):
    # https://stackoverflow.com/q/21783160/941531
    
    is_prime = [False] * (limit + 1)
    
    for x in range(1,int(math.sqrt(limit))+1):
        for y in range(1,int(math.sqrt(limit))+1):
            n = 4*x**2 + y**2

            if n<=limit and (n%12==1 or n%12==5):
                # print "1st if"
                is_prime[n] = not is_prime[n]
            n = 3*x**2+y**2
            if n<= limit and n%12==7:
                # print "Second if"
                is_prime[n] = not is_prime[n]
            n = 3*x**2 - y**2
            if x>y and n<=limit and n%12==11:
                # print "third if"
                is_prime[n] = not is_prime[n]

    for n in range(5,int(math.sqrt(limit))):
        if is_prime[n]:
            for k in range(n**2,limit+1,n**2):
                is_prime[k] = False
    return [2,3] + [n for n in range(5,limit) if is_prime[n]]

def SieveOfAtkin_bergee(limit):
    # https://stackoverflow.com/a/71490622/941531
    P = [2,3]
    r = range(1,int(math.sqrt(limit))+1)
    sieve=[False]*(limit+1)
    for x in r:
        for y in r:
            xx=x*x
            yy=y*y
            xx3 = 3*xx
            n = 4*xx + yy
            if n<=limit and (n%12==1 or n%12==5) : sieve[n] = not sieve[n]
            n = xx3 + yy
            if n<=limit and n%12==7 : sieve[n] = not sieve[n]
            n = xx3 - yy
            if x>y and n<=limit and n%12==11 : sieve[n] = not sieve[n]
    for x in range(5,int(math.sqrt(limit))):
        if sieve[x]:
            xx=x*x
            for y in range(xx,limit+1,xx):
                sieve[y] = False
    for p in range(5,limit):
        if sieve[p] : P.append(p)
    return P

def Test():
    limit = 5 * 10 ** 6
    # Pretty-print the limit with "_" as the thousands separator
    print(f'Limit {limit:_}')
    rtim, rres = None, None
    for f in [
        SieveOfAtkin_Mahadeva,
        SieveOfAtkin_bergee, 
        SieveOfEratosthenes_Arty_Python,
        SieveOfAtkin_Arty_Python,
        SieveOfEratosthenes_Arty_Numba,
        SieveOfAtkin_Arty_Numba,
    ]:
        fname = f.__name__
        print(f'{fname:<31} : ', end = '', flush = True)
        f(1 << 10) # Pre-compute function, Numba needs it for pre-compilation
        tim = time.time()
        res = np.array(f(limit), dtype = np.uint32)
        tim = time.time() - tim
        if rtim is None:
            rtim = tim
        if rres is None:
            rres = res
        else:
            assert np.all(rres == res)
        print(f'time {tim:>7.3f} sec, boost {rtim / tim:>7.2f}x', flush = True)
        
if __name__ == '__main__':
    Test()
Answered By: Arty