Efficient algorithm to calculate the rightmost non-zero digit of a number's factorial in Python

Question:

Calculate the rightmost non-zero digit of n factorial efficiently

I want to calculate the rightmost non-zero digit of a given number’s factorial and print it. What I’ve done so far is:

import math
n = int(input())
fact = math.factorial(n)
print(str(fact).rstrip('0')[-1])

but I still exceed the time limit, so I’m looking for a faster solution.
Note that I must use Python to solve this problem.
Also, n ranges from 1 to 65536, the time limit is 0.5 seconds, and the memory limit is 256 megabytes.

Asked By: Shayan

Answers:

There is a neat recursive formula you can use. Let D(n) be the last non-zero digit of n!:

  • If n < 10, use a lookup table.
  • If the tens digit of n is odd, D(n) = 4 * D(n//5) * D(units digit of n) mod 10.
  • If the tens digit of n is even, D(n) = 6 * D(n//5) * D(units digit of n) mod 10.

See this Mathematics Stack Exchange post for a proof.

Translating it into code:

def last_nonzero_factorial_digit(n):
    lookup = [1, 1, 2, 6, 4, 2, 2, 4, 2, 8]
    if n < 10:
        return lookup[n]

    if ((n // 10) % 10) % 2 == 0:
        return (6 * last_nonzero_factorial_digit(n // 5) * lookup[n % 10]) % 10
    else:
        return (4 * last_nonzero_factorial_digit(n // 5) * lookup[n % 10]) % 10

On my laptop, this version runs ~14,000 times faster on a 5-digit number.
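
As a sanity check, the recurrence can be compared against a brute-force computation for small n. This is only a sketch: `brute_force_digit` and `first_mismatch` are names made up for illustration, and the recursive function is repeated so the snippet runs on its own.

```python
import math


def last_nonzero_factorial_digit(n):
    """Recursive formula from above, repeated so the sketch is self-contained."""
    lookup = [1, 1, 2, 6, 4, 2, 2, 4, 2, 8]
    if n < 10:
        return lookup[n]
    # 6 when the tens digit of n is even, 4 when it is odd
    multiplier = 6 if ((n // 10) % 10) % 2 == 0 else 4
    return (multiplier * last_nonzero_factorial_digit(n // 5) * lookup[n % 10]) % 10


def brute_force_digit(n):
    """Hypothetical reference helper: strip the trailing zeros of n! arithmetically."""
    fact = math.factorial(n)
    while fact % 10 == 0:
        fact //= 10
    return fact % 10


def first_mismatch(limit):
    """Return the first n <= limit where the two methods disagree, or None."""
    for n in range(limit + 1):
        if last_nonzero_factorial_digit(n) != brute_force_digit(n):
            return n
    return None


print(first_mismatch(500))
```

If the formula holds on the tested range, `first_mismatch` should find nothing to report.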

Answered By: Seon

A simple one that is fast enough: just compute the factorial and then remove as many trailing zeros as there are occurrences of the prime factor 5 (since that’s what causes the zeros, together with the more frequent prime factor 2). Every fifth number from 1 to n contributes a prime factor 5, every fifth of those contributes another, and so on.

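import math

def Kelly(n):
    res = math.factorial(n)
    # n//5 + n//25 + n//125 + ... counts the prime factors 5 in n!,
    # which is exactly the number of trailing zeros to strip.
    while n:
        n //= 5
        res //= 10**n
    return res % 10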
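
The counting of prime factors 5 described above can also be looked at in isolation. A minimal sketch, assuming two helper names (`count_factor_five`, `trailing_zeros_brute_force`) that are made up for illustration, which compares the count with the actual number of trailing zeros of n!:

```python
import math


def count_factor_five(n):
    """Count how often the prime factor 5 occurs in n! (n//5 + n//25 + ...)."""
    count = 0
    while n:
        n //= 5
        count += n
    return count


def trailing_zeros_brute_force(n):
    """Hypothetical check helper: count trailing zeros of n! by repeated division."""
    fact = math.factorial(n)
    zeros = 0
    while fact % 10 == 0:
        fact //= 10
        zeros += 1
    return zeros


# Print n, the counted factors of 5, and the measured trailing zeros side by side.
for n in (10, 25, 100, 1000):
    print(n, count_factor_five(n), trailing_zeros_brute_force(n))
```

The two columns should agree, because 2s are always more plentiful than 5s, so every 5 finds a partner to form a 10.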

Benchmark with your limit n=65536:

0.108791 seconds  just_factorial
1.434466 seconds  Shayan_original
0.553055 seconds  Shayan_answer
0.000016 seconds  Seon
0.208012 seconds  Kelly
0.029500 seconds  Kelly2

We see that just computing the factorial, without extracting the desired digit, is easily fast enough. Only your original string-based extraction of the digit is slow.
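
To see where the time goes, one could time the two phases separately. A rough sketch: the helper name `time_phases` is made up for illustration, timings will vary by machine and Python version, and on Python 3.11+ the default limit on int-to-str conversion has to be lifted first.

```python
import math
import sys
from time import perf_counter

# Python 3.11+ caps int -> str conversion at 4300 digits by default;
# lift the cap so a large factorial can be converted at all.
if hasattr(sys, "set_int_max_str_digits"):
    sys.set_int_max_str_digits(0)


def time_phases(n):
    """Return (seconds for factorial, seconds for str conversion, digit)."""
    start = perf_counter()
    fact = math.factorial(n)
    after_factorial = perf_counter()
    digits = str(fact)
    after_str = perf_counter()
    digit = int(digits.rstrip("0")[-1])
    return after_factorial - start, after_str - after_factorial, digit


# A smaller n than the limit keeps the demo quick; raise it to 65536 to compare.
factorial_seconds, str_seconds, digit = time_phases(20000)
print(f"factorial: {factorial_seconds:.6f} s, str: {str_seconds:.6f} s, digit: {digit}")
```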

(Note: OP Shayan said their answer’s solution "works for me and gets the job done", and mine is faster. That’s mainly why I said mine is fast enough (also because it’s well under 0.5 s for me). It looks like they only deleted their answer because they couldn’t explain it.)

Another simple and faster way:

def Kelly2(n):
    prod = 1
    twos = 0
    for i in range(2, n + 1):
        while not i % 2:
            i //= 2
            twos += 1
        while not i % 5:
            i //= 5
            twos -= 1
        prod = prod * i % 10
    return prod * pow(2, twos, 10) % 10

Here we multiply the numbers from 1 to n ourselves, but extract the prime factors 2 and 5 and count them separately. Then n! = 2^twos * 5^fives * ProductOfReducedFactors. Since pairs of 2 and 5 cause the unwanted trailing zeros, my code instead counts how many 2s we have that can’t be paired with 5s. With twos now meaning that unpaired count, nFactorialWithoutTrailingZeros = 2^twos * ProductOfReducedFactors. We get its last digit with % 10, which we can apply throughout the calculation to keep ProductOfReducedFactors small.
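
To make the decomposition concrete, here is a sketch that keeps the counts of 2s and 5s separate and the reduced product exact, so the identity can be checked directly. The names `decompose_factorial` and `last_nonzero_digit` are made up for illustration. Unlike Kelly2, it does not reduce mod 10 along the way, so it is only meant for small n.

```python
import math


def decompose_factorial(n):
    """Split n! into (twos, fives, reduced) with n! == 2**twos * 5**fives * reduced."""
    twos = fives = 0
    reduced = 1
    for i in range(2, n + 1):
        while i % 2 == 0:
            i //= 2
            twos += 1
        while i % 5 == 0:
            i //= 5
            fives += 1
        reduced *= i  # kept exact here; Kelly2 reduces it mod 10 at every step
    return twos, fives, reduced


def last_nonzero_digit(n):
    """Last non-zero digit of n! from the unpaired 2s and the reduced product."""
    twos, fives, reduced = decompose_factorial(n)
    # Each 5 is paired with one 2 to form a trailing zero; twos - fives 2s remain.
    return pow(2, twos - fives, 10) * (reduced % 10) % 10


twos, fives, reduced = decompose_factorial(20)
print(twos, fives, 2**twos * 5**fives * reduced == math.factorial(20))
print(last_nonzero_digit(20))
```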

Testing code (Attempt This Online!):

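import math
import sys
from time import time

# Python 3.11+ limits int -> str conversion to 4300 digits by default;
# lift the limit so Shayan_original can convert 65536! to a string.
if hasattr(sys, 'set_int_max_str_digits'):
    sys.set_int_max_str_digits(0)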


def just_factorial(n):
    math.factorial(n)
    return -1


def Shayan_original(n):
    fact = math.factorial(n)
    return str(fact).rstrip('0')[-1]


def Shayan_answer(n):
    a = n // 5
    b = n - 5 * a
    fact_a = math.factorial(a)
    fact_b = math.factorial(b)
    power_a = 2 ** a
    res = fact_a * fact_b * power_a
    while (res % 10 == 0):
        res //= 10
    return int(res % 10)


def Seon(n):
    lookup = [1, 1, 2, 6, 4, 2, 2, 4, 2, 8]
    if n < 10:
        return lookup[n]
    if ((n // 10) % 10) % 2 == 0:
        return (6 * Seon(n // 5) * lookup[n % 10]) % 10
    else:
        return (4 * Seon(n // 5) * lookup[n % 10]) % 10


def Kelly(n):
    res = math.factorial(n)
    while n:
        n //= 5
        res //= 10**n
    return res % 10


def Kelly2(n):
    prod = 1
    twos = 0
    for i in range(2, n + 1):
        while not i % 2:
            i //= 2
            twos += 1
        while not i % 5:
            i //= 5
            twos -= 1
        prod = prod * i % 10
    return prod * pow(2, twos, 10) % 10


funcs = just_factorial, Shayan_original, Shayan_answer, Seon, Kelly, Kelly2

# Correctness
for n in *range(1001), 65536:
    expect = funcs[1](n)
    for f in funcs[2:]:
        result = str(f(n))
        assert result == expect

# Speed
for f in funcs:
    t = time()
    f(65536)
    print(f'{time() - t :8.6f} seconds ', f.__name__)

Answered By: Kelly Bundy

A solution based on your answer (just repeating your reduction step), an optimized version of it that beats Seon’s, and an optimized version of Seon’s. Times for n=65536:

  2.04 ± 0.03 μs  Seon_optimized
  2.54 ± 0.04 μs  Kelly3_optimized
  4.18 ± 0.04 μs  Seon
 20.51 ± 0.09 μs  Kelly3

Code (Attempt This Online!):

def Kelly3(n):
    res = 1
    while n:
        res *= factorial(n % 5)
        n //= 5
        res <<= n
    return res % 10


def Kelly3_optimized(n):
    res = 1
    twos = 0
    while n:
        res *= (1, 1, 2, 6, 24)[n % 5]
        n //= 5
        twos += n
    # 2**k mod 10 cycles with period 4 for k >= 1, so reduce twos to 1..4 (0 stays 0).
    shift = twos and (twos % 4 or 4)
    return (res << shift) % 10


def Seon_optimized(n):
    lookup = 1, 1, 2, 6, 4, 2, 2, 4, 2, 8
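    # Lookup[n % 20] folds Seon's multiplier (6 for an even tens digit, 4 for odd) into lookup[n % 10].
    Lookup = 6, 6, 2, 6, 4, 2, 2, 4, 2, 8, 4, 4, 8, 4, 6, 8, 8, 6, 8, 2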
    res = 1
    while n >= 10:
        res *= Lookup[n % 20]
        n //= 5
    return res * lookup[n] % 10


def Seon(n):
    lookup = [1, 1, 2, 6, 4, 2, 2, 4, 2, 8]
    if n < 10:
        return lookup[n]
    if ((n // 10) % 10) % 2 == 0:
        return (6 * Seon(n // 5) * lookup[n % 10]) % 10
    else:
        return (4 * Seon(n // 5) * lookup[n % 10]) % 10


from timeit import timeit
from statistics import mean, stdev
from math import factorial

funcs = Seon, Kelly3, Kelly3_optimized, Seon_optimized

# Correctness
for n in *range(10001), 65536:
    expect = str(funcs[0](n))
    for f in funcs[1:]:
        result = str(f(n))
        assert result == expect

# Speed
times = {f: [] for f in funcs}
def stats(f):
    ts = [t * 1e6 for t in sorted(times[f])[:100]]
    return f'{mean(ts):6.2f} ± {stdev(ts):4.2f} μs '
for _ in range(1000):
    for f in funcs:
        t = timeit(lambda: f(65536), number=1)
        times[f].append(t)
for f in sorted(funcs, key=stats):
    print(stats(f), f.__name__)
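
For readers wondering where the 20-entry table in Seon_optimized and the shift in Kelly3_optimized come from, here is a small sketch that rebuilds both. The helper names `build_combined_lookup` and `reduced_shift` are made up for illustration; the logic just restates the 4/6 multiplier rule and the period-4 cycle of powers of 2 mod 10.

```python
def build_combined_lookup():
    """Rebuild the 20-entry table: multiplier (6 or 4) times the unit-digit lookup, mod 10."""
    lookup = (1, 1, 2, 6, 4, 2, 2, 4, 2, 8)
    table = []
    for r in range(20):
        # r = n % 20; r < 10 means the tens digit of n is even
        multiplier = 6 if r < 10 else 4
        table.append(multiplier * lookup[r % 10] % 10)
    return tuple(table)


def reduced_shift(twos):
    """Map an exponent to 0 (if zero) or 1..4 with the same value of 2**k mod 10."""
    return twos and (twos % 4 or 4)


print(build_combined_lookup())
# Check that the reduced shift gives the same last digit as the full power of 2.
print(all(pow(2, k, 10) == (1 << reduced_shift(k)) % 10 for k in range(200)))
```
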

Answered By: Kelly Bundy