Decorator causing my function to not accept positional argument?

Question:

I am learning how to use decorators. Inspired by the tqdm library (progress bars), I have cobbled together a decorator based heavily on the code linked in the comment in the full listing below. The goal is to use a thread to make a progress indicator blink, spin, or count in stdout while a long-running function that does not depend on iteration is in progress.

The decorator is changing something about the function it wraps. spendtime should take one positional argument, an integer, passed to sleep():

@progresswrapper(style="clock", msg="testing...")
def spendtime(x):
    result = 'foobar' #dosomething
    sleep(x)
    return result #returnanything

if __name__ == "__main__":
    print(spendtime(2))

But this throws a TypeError: spendtime() takes 0 positional arguments but 1 was given.
When I comment out the decorator, spendtime() works as expected.
When I make the following changes:

@progresswrapper(style="clock", msg="testing...")
def spendtime(x=2):  #### CHANGED TO KWARG
    result = 'foobar' #dosomething
    sleep(x)
    return result #returnanything

if __name__ == "__main__":
    print(spendtime())  ### RUNNING WITH NO INPUT

this code works too, but is not suitable; I want to be able to call spendtime() with an argument. In my work code, I intend to use this decorator on much more complex functions that take a whole range of args and kwargs.

Why is the use of the decorator here causing this behaviour? Why am I seeing this TypeError?

Full code:

import threading
from itertools import cycle
from time import sleep
import functools

# heavy inspiration from duckythescientist: https://github.com/tqdm/tqdm/issues/458

def monitor(func, style="clock", msg=None):
    """
    print progress indicator to same line in stdout while wrapped function progresses
    """

    styles = {"clock":(0.1, ['-', '\\', '|', '/']),
              "blink":(0.2, [u"\u2022", ' ']),
              "jump":(0.2, [u"\u2022", '.']),
              "ellipsis":(0.2, ["   ", ".  ", ".. ", "..."]),
              "shuffle":(0.2, [u"\u2022 ", u" \u2022"]),
              "counter":(1, (f"{n:02}s" for n in range(100))),
             }

    if style == "counter":
        msg_pad = 10
    else:
        msg_pad = len(styles[style][1][1]) + 7

    if msg:
        msg = (" "*msg_pad) + msg # ensure that progress indicator doesn't overwrite msg
        print(f"{msg}", end='\r')
    marker_cycle = cycle(styles[style][1]) # loop through list
    tstep = styles[style][0]

    ret = [None]
    def runningfunc(func, ret, *args, **kwargs):
        ret[0] = func(*args, **kwargs)

    thread = threading.Thread(target=runningfunc, args=(func, ret))

    thread.start()
    while thread.is_alive():
        thread.join(timeout=tstep) # call to runningfunc (?)
        print(f"  [ {next(marker_cycle)} ]", end='\r')
    print("done.\033[K")
    return ret[0]


def progresswrapper(style, msg):
    def real_decorator(func):
        @functools.wraps(func)
        def wrapper():
            return monitor(func, style=style, msg=msg)
        return wrapper
    return real_decorator





@progresswrapper(style="clock", msg="testing...")
def spendtime(x):
    result = 'foobar' #dosomething
    sleep(x)
    return result #returnanything

if __name__ == "__main__":
    print(spendtime(2))
Asked By: skytwosea

Answers:

When you apply the decorator to your spendtime function, the original function gets replaced by the wrapper function that is defined inside of the decorator’s code. wrapper doesn’t take any arguments, and even if it did, there’s no way for it to pass them along to the monitor function where the original function eventually gets called.
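
The replacement described above can be reproduced in a few lines. This is a minimal sketch, not the code from the question: `no_args_decorator` is a hypothetical stand-in for `progresswrapper`, stripped of the threading and styling so that only the zero-argument `wrapper` remains.

```python
import functools

def no_args_decorator(func):
    # Hypothetical decorator mirroring the wrapper() from the question.
    @functools.wraps(func)
    def wrapper():              # accepts no arguments at all
        return func()
    return wrapper

@no_args_decorator
def spendtime(x):
    return x

# The name spendtime now refers to wrapper, so passing an argument fails
# before the original function is ever reached.
try:
    spendtime(2)
except TypeError as exc:
    message = str(exc)
    print(message)
```

Because `functools.wraps` copies the original function's name onto `wrapper`, recent Python versions report the error against `spendtime()` even though it is `wrapper` that rejects the argument, which is what makes the message in the question confusing.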

One of the many nested functions you’re using, runningfunc, does accept *args and **kwargs, but the only place it is ever called from (the thread) never passes it anything that would be captured in those variables.

So the obvious fix is to change wrapper to accept arguments, and then pass them on via monitor and a thread to runningfunc. You probably don’t need to be using *args and **kwargs for most of those calls, and indeed, doing so would risk complications if the wrapped function happened to have argument names that matched any of your intermediate functions (e.g. style). Passing a tuple and dictionary around is fine though:

def monitor(func, args, kwargs, style="clock", msg=None):   # get args and kwargs
    # style stuff omitted for brevity

    def runningfunc(func, ret, args, kwargs): # don't expect unpacked args and kwargs
        ret[0] = func(*args, **kwargs)

    thread = threading.Thread(target=runningfunc, args=(func, ret, args, kwargs)) # ...

    # wait-loop omitted


def progresswrapper(style, msg):
    def real_decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):                    # accept arbitrary arguments
            return monitor(func, args, kwargs, style=style, msg=msg) # pass them on
        return wrapper
    return real_decorator
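
The name-collision risk mentioned above can be seen in a small sketch. The functions here (`monitor_unpacked`, `monitor_packed`, `render`) are hypothetical and exist only for illustration; the threading and indicator logic is left out so that only the argument passing is visible.

```python
def monitor_unpacked(func, *args, style="clock", **kwargs):
    # Hypothetical variant that forwards with *args/**kwargs: a 'style'
    # keyword meant for func is captured by this parameter instead.
    return func(*args, **kwargs)

def monitor_packed(func, args, kwargs, style="clock"):
    # Variant from the answer: the caller's arguments travel as a tuple
    # and a dict, so they cannot collide with monitor's own parameters.
    return func(*args, **kwargs)

def render(text, style="plain"):
    # Hypothetical wrapped function with its own 'style' parameter.
    return f"{text}:{style}"

lost = monitor_unpacked(render, "hi", style="bold")
kept = monitor_packed(render, ("hi",), {"style": "bold"})
print(lost)  # hi:plain  (the keyword never reached render)
print(kept)  # hi:bold
```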

It’s worth noting that you could simplify much of your argument-passing logic by moving the body of monitor into wrapper, rather than keeping them separate. The style, msg and func values could then be read from the enclosing namespaces, rather than needing to be passed explicitly, and *args and **kwargs are already collected in wrapper just as you need them.

Of course, you could do a similar thing with runningfunc, which has access to all the values defined in the enclosing local namespace of monitor, so you could make it argumentless and let it access func, ret, args and kwargs directly, rather than passing them through the threading startup process. This would even let you replace your ret list with a simple variable, as a nonlocal ret statement would let runningfunc write to the outer namespace just as easily as it can modify the contents of a list.
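
One way the two simplifications above might look when combined is sketched below. It assumes only two of the indicator styles and a plain closing message, and it is not the asker's final code: the body of monitor lives inside `wrapper`, `runningfunc` takes no arguments, and a `nonlocal` variable replaces the `ret` list.

```python
import functools
import threading
from itertools import cycle
from time import sleep

def progresswrapper(style="clock", msg="working..."):
    # Only two styles are reproduced in this sketch.
    styles = {"clock": (0.1, ['-', '\\', '|', '/']),
              "ellipsis": (0.2, ["   ", ".  ", ".. ", "..."])}

    def real_decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            tstep, markers = styles[style]
            marker_cycle = cycle(markers)
            ret = None

            def runningfunc():
                # func, args and kwargs come from the enclosing scopes;
                # nonlocal lets this function rebind wrapper's ret.
                nonlocal ret
                ret = func(*args, **kwargs)

            thread = threading.Thread(target=runningfunc)
            thread.start()
            while thread.is_alive():
                print(f"  [ {next(marker_cycle)} ] {msg}", end='\r')
                thread.join(timeout=tstep)  # wait up to tstep seconds
            print("done.\033[K")
            return ret
        return wrapper
    return real_decorator

@progresswrapper(style="clock", msg="testing...")
def spendtime(x):
    sleep(x)
    return 'foobar'

result = spendtime(0.3)
print(result)
```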

Answered By: Blckknght

Thank you @blckknght for your advice. Here is the working code for my decorator. It boiled down to simplifying and being careful with how I was passing args through the chain of nested functions.

The next problem is how to extract the stack trace from the exception raised in the thread (line 72): that line prints only the traceback object, like so:
<traceback object at 0x7f729f17f800>

import threading
from itertools import cycle
from time import sleep
import functools
import sys

# heavy inspiration from duckythescientist: https://github.com/tqdm/tqdm/issues/458

def progresswrapper(style="clock", msg="working..."):
    """
    wrapper for long-running functions: prints an active indicator to stdout until the 
    long-running function terminates.
    """

    def real_decorator(func):
        @functools.wraps(func)
        def wrapper(*args, style=style, msg=msg, **kwargs):

            # error handling: to print an error message via this wrapper,
            # need to raise an exception in wrapped function
            error, error_msg, error_trace = False, None, None
            def wrapper_exc_hook(e):
                nonlocal error, error_msg, error_trace
                error, error_msg, error_trace = True, str(e.exc_value), e

            # define indicator styles here. Must be a tuple: (float, list)
            # float in tuple is the tstep rate for call to threading; as such,
            # the rate of thread switching defines rate at which the indicator symbol changes 
            styles = {"clock": (0.1, ['-', '\\', '|', '/']),
                      "blink": (0.2, [u"\u2022", ' ']),
                      "jump": (0.2, [u"\u2022", '.']),
                      "ellipsis": (0.2, ["   ", ".  ", ".. ", "..."]),
                      "shuffle": (0.2, [u"\u2022 ", u" \u2022"]),
                      "counter": (1, (f"{n:02}s" for n in range(100))),
                      }

            # string formatting for messages 
            if style == "counter":
                indicator_length = 3
                msg_padding = 10
            else:
                indicator_length = len(styles[style][1][1])
                msg_padding = indicator_length + 7
            padded_msg = (" " * msg_padding) + msg  # ensure that progress indicator doesn't overwrite msg

            marker_cycle = cycle(styles[style][1])  # use itertools to create a generator from the selected list
            tstep = styles[style][0] # sets rate of thread switching
            ret = [None] # wrapped function return value

            # set thread
            def runningfunc(func, ret):
                ret[0] = func(*args, **kwargs)
            threading.excepthook = wrapper_exc_hook
            thread = threading.Thread(target=runningfunc, args=(func, ret))
            
            # execute 
            print(f"{padded_msg}", end='\r')
            thread.start()
            while thread.is_alive(): # wrapped function defines lifespan of indicator printout
                print(f"  [ {next(marker_cycle)} ]", end='\r')
                thread.join(timeout=tstep)  # wait up to tstep seconds for the thread to finish

            # error handling: set closing message
            if not error:
                terminal_symbol = u'\u2713'
                print(f"  [ {terminal_symbol:^{indicator_length}} ] {msg} Done.") # leaves message on line in stdout
            else:
                terminal_symbol = '!'
                print(f"  [ {terminal_symbol:^{indicator_length}} ] {error_msg}\033[K") # ANSI escape sequence clears the rest of the active line in stdout
                # for item in error_trace:
                #     print(item)
                print(f"{error_trace.exc_type}:\n{error_trace.exc_traceback}")
                print(f"{sys.exc_info()}")
            return ret[0]
        return wrapper
    return real_decorator



# test for successful function termination:
page = "x"
@progresswrapper(style="ellipsis", msg=f"reading page {page}...")
def spendtime(x):
    sleep(x) #dosomething
    return 0 #returnanything

# test for exception handling:
@progresswrapper(style="jump") # called with default style and message
def built_to_fail(x):
    try:
        letters = ["y", "z"]
        for i in range(x):
            z = letters[i]
            sleep(1)
        return 0
    except BaseException as error:
        raise Exception("problemski")

# go time
if __name__ == "__main__":
    # spendtime(3)
    built_to_fail(3)
Answered By: skytwosea