How to find out if (the source code of) a function contains a loop?

Question:

Let’s say I have a bunch of functions a, b, c, d and e, and I want to find out whether each of them directly uses a loop:

def a():
    for i in range(3):
        print(i**2)

def b():
    i = 0
    while i < 3:
        print(i**2)
        i += 1

def c():
    print("\n".join([str(i**2) for i in range(3)]))

def d():
    print("\n".join(["0", "1", "4"]))

def e():
    "for"

I want to write a function uses_loop so that these assertions pass:

assert uses_loop(a) == True
assert uses_loop(b) == True
assert uses_loop(c) == False
assert uses_loop(d) == False
assert uses_loop(e) == False

(I expect uses_loop(c) to return False because c uses a list comprehension instead of a loop.)

I can’t modify a, b, c, d and e, so I thought I could use ast to walk the function’s source code, which I get from inspect.getsource. I’m open to other approaches, though; this was just one idea for how it could work.

This is as far as I’ve come with ast:

def uses_loop(function):
    import ast
    import inspect
    nodes = ast.walk(ast.parse(inspect.getsource(function)))
    for node in nodes:
        print(node.__dict__)
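Printing node.__dict__ is noisy. Printing only each node’s class name makes it easier to see what to look for. The following is a minimal sketch that parses source strings directly, so it also runs where inspect.getsource is unavailable. The helper node_type_names and the SOURCE_A / SOURCE_C strings are hypothetical and exist only for illustration:

```python
import ast

# Source text of a() and c() from the question, as plain strings.
SOURCE_A = """
def a():
    for i in range(3):
        print(i**2)
"""

SOURCE_C = """
def c():
    print("\\n".join([str(i**2) for i in range(3)]))
"""


def node_type_names(source: str) -> set:
    """Return the class names of every node that ast.walk visits in source."""
    return {type(node).__name__ for node in ast.walk(ast.parse(source))}


print(sorted(node_type_names(SOURCE_A)))  # 'For' is in this list
print(sorted(node_type_names(SOURCE_C)))  # 'ListComp' and 'comprehension', but no 'For'
```

The loop in a shows up as a For node. The comprehension in c is represented by ListComp and comprehension nodes instead, which is what lets an AST-based check tell the two apart.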
Asked By: finefoot


Answers:

If you only want to check whether the function’s source contains the keywords ‘for’ or ‘while’, you can do the following. Note that this is a plain substring search, so it also returns True for c and e:

def uses_loop(func):
    import inspect
    source = inspect.getsource(func)
    return 'for' in source or 'while' in source
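A substring test also matches ‘for’ inside string literals, comments and names such as format. A somewhat less fragile keyword check uses the standard tokenize module so that ‘for’ and ‘while’ only count when they appear as NAME tokens, which is how keywords are tokenized. This is only a sketch: the helper name has_loop_keyword is hypothetical, and the check still counts the for of a comprehension, so it cannot tell c apart from a.

```python
import io
import tokenize


def has_loop_keyword(source: str) -> bool:
    """Return True if 'for' or 'while' appears as a keyword token in source.

    String literals and comments become STRING/COMMENT tokens, and an
    identifier like 'format' is a NAME token with a different string,
    so none of them match. A comprehension's 'for' still matches.
    """
    tokens = tokenize.generate_tokens(io.StringIO(source).readline)
    return any(
        tok.type == tokenize.NAME and tok.string in ("for", "while")
        for tok in tokens
    )


print(has_loop_keyword('def e():\n    "for"\n'))  # False, unlike the substring test
```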
Answered By: Rafiul Sabbir

You were almost there! ast.walk already yields every node in the tree, including the statements nested inside other nodes’ body attributes, so all that’s left is to check whether any of them is an ast.For or ast.While node and return True if so.

Note: I was just tinkering with the code and am not sure whether this is documented and can be relied upon. Maybe you can look it up? 🙂

def a():
    for i in range(3):
        print(i**2)

def b():
    i = 0
    while i < 3:
        print(i**2)
        i += 1

def c():
    print("\n".join([str(i**2) for i in range(3)]))

def d():
    print("\n".join(["0", "1", "4"]))

def uses_loop(function):
    import ast
    import inspect
    nodes = ast.walk(ast.parse(inspect.getsource(function)))
    # ast.For and ast.While are the public names of the private _ast node classes
    return any(isinstance(node, (ast.For, ast.While)) for node in nodes)


print(uses_loop(a))    # True
print(uses_loop(b))    # True
print(uses_loop(c))    # False
print(uses_loop(d))    # False
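Looking only at a node’s body attribute, one level at a time, misses loops nested inside other statements unless you recurse yourself. ast.walk does that recursion for you. The following small illustration uses a hypothetical function f that is not part of the question:

```python
import ast

SOURCE = """
def f(flag):
    if flag:
        while flag:
            flag = False
"""

func = ast.parse(SOURCE).body[0]  # the FunctionDef node for f

# Statements found directly in the function's body attribute:
top_level = [type(stmt).__name__ for stmt in func.body]
print(top_level)  # ['If']

# ast.walk also visits the If node's own body, where the loop lives:
walked = {type(node).__name__ for node in ast.walk(func)}
print("While" in walked)  # True
```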
Answered By: UltraInstinct

Check whether the function’s abstract syntax tree (AST) contains any ast.For, ast.While or ast.AsyncFor nodes. Use ast.walk() to visit every node of the AST:

import ast
import inspect

def uses_loop(function):
    loop_statements = ast.For, ast.While, ast.AsyncFor

    nodes = ast.walk(ast.parse(inspect.getsource(function)))
    return any(isinstance(node, loop_statements) for node in nodes)
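ast.AsyncFor has to be listed separately because an async for statement is parsed into its own node class, not into ast.For. The following quick check uses a hypothetical coroutine consume:

```python
import ast

SOURCE = """
async def consume(stream):
    async for item in stream:
        print(item)
"""

node_types = {type(node) for node in ast.walk(ast.parse(SOURCE))}
print(ast.For in node_types)       # False: async for is not an ast.For
print(ast.AsyncFor in node_types)  # True
```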

See the documentation for ast for details. The async for statement (ast.AsyncFor) was added in Python 3.5.
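Walking the whole parsed source has two caveats:

1. inspect.getsource returns a method’s source with its class indentation, which ast.parse rejects with an IndentationError.
2. ast.walk also descends into nested functions and classes, whose loops arguably are not used "directly".

The sketch below addresses both, assuming function is an ordinary def rather than a lambda. The names source_uses_loop, LOOP_NODES and NESTED_SCOPES are hypothetical.

```python
import ast
import inspect
import textwrap

# Loop statements; comprehensions use ast.comprehension, so they are not matched.
LOOP_NODES = (ast.For, ast.AsyncFor, ast.While)
# Nodes that open a new scope; loops inside them are not counted.
NESTED_SCOPES = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)


def source_uses_loop(source: str) -> bool:
    """Return True if the first statement in source (assumed to be a def)
    contains a loop outside of any nested function or class."""
    # dedent() lets ast.parse accept the source of methods and nested functions.
    func = ast.parse(textwrap.dedent(source)).body[0]
    stack = list(ast.iter_child_nodes(func))
    while stack:
        node = stack.pop()
        if isinstance(node, LOOP_NODES):
            return True
        if not isinstance(node, NESTED_SCOPES):
            stack.extend(ast.iter_child_nodes(node))
    return False


def uses_loop(function) -> bool:
    return source_uses_loop(inspect.getsource(function))
```

With this version, uses_loop still returns True for a and b and False for c, d and e. A loop inside a def nested in the function no longer counts.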

Answered By: Boris Verkhovskiy