How can we "associate" a Python context manager to the variables appearing in its block?

Question:

As I understand it, context managers are used in Python to define initialization and finalization code (__enter__ and __exit__) for an object.

However, in the tutorial for PyMC3 they show the following context manager example:

basic_model = pm.Model()

with basic_model:

    # Priors for unknown model parameters
    alpha = pm.Normal('alpha', mu=0, sd=10)
    beta = pm.Normal('beta', mu=0, sd=10, shape=2)
    sigma = pm.HalfNormal('sigma', sd=1)

    # Expected value of outcome
    mu = alpha + beta[0]*X1 + beta[1]*X2

    # Likelihood (sampling distribution) of observations
    Y_obs = pm.Normal('Y_obs', mu=mu, sd=sigma, observed=Y)

and mention that this serves to associate the variables alpha, beta, sigma, mu and Y_obs with the model basic_model.

I would like to understand how such a mechanism works. In the explanations of context managers I have found, I did not see anything suggesting how variables or objects defined within the context’s block somehow get “associated” with the context manager. It would seem that the library (PyMC3) somehow has access to the “current” context manager, so it can associate each newly created variable with it behind the scenes. But how can the library get access to the context manager?

Asked By: user118967


Answers:

I don’t know how it works in this specific case, but in general you would use some ‘behind the scenes magic’:

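class Parent:
    def __init__(self):
        # The child whose with block is currently active (None outside any block)
        self.active_child = None

    def ContextManager(self):
        # Return a new Child that can be used in a with statement
        return Child(self)

    def Attribute(self):
        # Delegate to whichever child is currently active
        if self.active_child is None:
            raise RuntimeError("Attribute() called outside of a with block")
        return self.active_child.Attribute()

class Child:
    def __init__(self, parent):
        self.parent = parent
        self.previous = None

    def __enter__(self):
        # Register this child as the active one, remembering any outer child
        self.previous = self.parent.active_child
        self.parent.active_child = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the outer child (or None) so that nested blocks unwind correctly
        self.parent.active_child = self.previous

    def Attribute(self):
        print("Called Attribute of child")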

Using this code:

p = Parent()
with p.ContextManager():
    attr = p.Attribute()

will yield the following output:

Called Attribute of child
Answered By: MegaIng

PyMC3 does this by maintaining a thread-local variable as a class variable inside the Context class. Models inherit from Context.

Each time you use a model in a with statement, the model is pushed onto the thread-specific context stack. The top of the stack thus always refers to the innermost (most recent) model used as a context manager.

Contexts (and thus Models) have a .get_context() class method to obtain the top of the context stack.

Distributions call Model.get_context() when they are created to associate themselves with the innermost model.
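The mechanism described above can be sketched as follows. This is a simplified, hypothetical reconstruction, not PyMC3's actual source: the names `Context`, `Model`, `get_context` and `contexts` follow the description above, while `get_contexts`, `add_variable` and the one-argument `Distribution` are assumptions made for illustration.

```python
import threading


class Context:
    # Class variable holding a thread-local object: every subclass and instance
    # shares it, but each thread sees its own attributes on it.
    contexts = threading.local()

    @classmethod
    def get_contexts(cls):
        # Lazily create this thread's stack on first use.
        if not hasattr(cls.contexts, "stack"):
            cls.contexts.stack = []
        return cls.contexts.stack

    @classmethod
    def get_context(cls):
        # The top of the stack is the innermost active context.
        stack = cls.get_contexts()
        if not stack:
            raise TypeError("No context on context stack")
        return stack[-1]

    def __enter__(self):
        type(self).get_contexts().append(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        type(self).get_contexts().pop()


class Model(Context):
    def __init__(self):
        self.variables = []

    def add_variable(self, var):  # hypothetical helper
        self.variables.append(var)


class Distribution:
    def __init__(self, name):
        self.name = name
        # Find the innermost model on this thread and register with it.
        Model.get_context().add_variable(self)


basic_model = Model()
with basic_model:
    alpha = Distribution("alpha")
    beta = Distribution("beta")

print([v.name for v in basic_model.variables])  # ['alpha', 'beta']
```

Because the stack lives on the class rather than on any instance, `Distribution` never needs a reference to the model object; it only needs to know the `Model` class.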

So in short:

  1. with model pushes model onto the context stack. This means that inside the with block, type(model).contexts, Model.contexts and Context.contexts all contain model as their last (top-most) element.
  2. Distribution.__init__() calls Model.get_context() (note the capital M), which returns the top of the context stack; in our case, this is model. The context stack is thread-local (there is one per thread), but it is not instance-specific: if there is only a single thread, there is also only a single context stack, regardless of the number of models.
  3. When the with block exits, model is popped from the context stack.
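Steps 1 to 3 can be checked with a compact, self-contained version of a hypothetical `Context` class (again a simplification, not PyMC3's code): nested with blocks push and pop in order, and a second thread starts with its own empty stack.

```python
import threading


class Context:
    # Shared by all instances and subclasses; one "stack" attribute per thread.
    contexts = threading.local()

    @classmethod
    def get_contexts(cls):
        if not hasattr(cls.contexts, "stack"):
            cls.contexts.stack = []
        return cls.contexts.stack

    @classmethod
    def get_context(cls):
        # Raises IndexError if no context is active (kept minimal here).
        return cls.get_contexts()[-1]

    def __enter__(self):
        type(self).get_contexts().append(self)  # step 1: push
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        type(self).get_contexts().pop()  # step 3: pop


class Model(Context):
    def __init__(self, name):
        self.name = name


outer, inner = Model("outer"), Model("inner")
seen = []
other = []

with outer:
    seen.append(Model.get_context().name)  # step 2: outer is on top
    with inner:
        seen.append(Model.get_context().name)  # the innermost model wins
        # A new thread sees its own, empty stack.
        t = threading.Thread(target=lambda: other.append(len(Model.get_contexts())))
        t.start()
        t.join()
    seen.append(Model.get_context().name)  # back to outer after inner exits

print(seen, other, Model.get_contexts())
# ['outer', 'inner', 'outer'] [0] []
```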
Answered By: dhke

One can also inspect the calling frame's local variables (its locals()) when entering and exiting the context manager block and identify which ones have changed.

import inspect

class VariablePostProcessor(object):
    """Context manager that applies a function to all newly defined variables in the context manager.

    with VariablePostProcessor(print):
        a = 1
        b = 3

    It uses the (name, id(obj)) pair of each variable and object to detect whether a variable has been added.
    If a name is already bound to an object before the block, an assignment to this name
    inside the block is detected only if the id of the object has changed.

    a = 1
    b = 2
    with VariablePostProcessor(print):
        a = 1
        b = 3
    # only 'b' is detected as a newly defined variable/object; 'a' is not detected because it
    # still points to the same object 1
    """

    @staticmethod
    def variables():
        # get the locals two frames up the stack
        # (0 is this function, 1 is __init__/__exit__, 2 is the frame containing the with statement)
        return {(k, id(v)): v for k, v in inspect.stack()[2].frame.f_locals.items()}

    def __init__(self, post_process):
        self.post_process = post_process
        # snapshot the caller's variables before the block runs
        self.dct = self.variables()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # compare the variables defined at __exit__ with those captured in __init__
        dct_exit, dct_enter = self.variables(), self.dct
        for (name, id_) in set(dct_exit).difference(dct_enter):
            self.post_process(name, dct_exit[(name, id_)])

A typical use could be:

# let us define a Variable object that has a 'name' attribute that can be defined at initialisation time or later
class Variable:
    def __init__(self, name=None):
        self.name = name

# the following code
x = Variable('x')
y = Variable('y')
print(x.name, y.name)

# can be replaced by
with VariablePostProcessor(lambda name, obj: setattr(obj, "name", name)):
    x = Variable()
    y = Variable()
print(x.name, y.name)

# in such case, you can also define as a convenience
import functools
AutoRenamer = functools.partial(VariablePostProcessor, post_process=lambda name, obj: setattr(obj, "name", name))

# and rewrite the above code as
with AutoRenamer():
    x = Variable()
    y = Variable()
print(x.name, y.name)  # => x y
Answered By: sdementen