Why can I update the .data attribute of a PyTorch tensor when the variable is outside the local namespace?

Question:

I'm able to access and update the .data attribute of a PyTorch tensor when the variable is outside a function's namespace:

import torch

x = torch.zeros(5)
def my_function():
    x.data += torch.ones(5)
my_function()
print(x)       # tensor([1., 1., 1., 1., 1.])

When I (attempt to) update x in the regular fashion, though (i.e. x += y), I get the error "UnboundLocalError: local variable 'x' referenced before assignment". This is expected because x is outside of my_function's namespace.

import torch

x = torch.zeros(5)
def my_function():
    x += torch.ones(5)   # UnboundLocalError: local variable 'x' referenced before assignment
my_function()

Why can I update x via .data but not with its regular += operator?

Asked By: jstm


Answers:

This doesn't have to do with PyTorch specifically. If a name is assigned anywhere in a function body, Python treats it as a local variable for that entire function, unless the name is explicitly declared global (or nonlocal) in that scope. A similar question: Why does this UnboundLocalError occur (closure)?
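To see that this is decided per function when the code is compiled, rather than line by line at run time, here is a small plain-Python sketch (the function names are illustrative only). It inspects each function's code object, where co_varnames lists the names the compiler classified as local:

```python
x = 10

def read_only():
    return x   # no assignment to x anywhere in this body, so x is global

def read_then_assign():
    y = x      # fails: the assignment below makes x local to the whole function
    x = y + 1
    return x

# The classification is fixed at compile time, before either function runs.
print("x" in read_only.__code__.co_varnames)         # False
print("x" in read_then_assign.__code__.co_varnames)  # True
print(read_only())                                    # 10

try:
    read_then_assign()
except UnboundLocalError as err:
    # The exact message text differs between Python versions.
    print("UnboundLocalError:", err)
```

Note that the failing read comes before the assignment in read_then_assign, yet it still fails, because the whole body is scanned for assignments first.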

For your particular question, the problem is that x is defined only in the global scope, so you can't assign a new value to the name x inside the function without declaring it global. On the other hand, x.data += ... does not assign to the name x: it only reads the global x and then assigns to its data attribute. The attribute is not a variable in any scope, so you can assign to it without using the global keyword.

As an example, consider the following code:

class Foo():
    def __init__(self):
        self.data = 1

x = Foo()

def f():
    x.data += 1

f()
print(x.data)  # 2

This code will update x.data as expected since x.data is not a global variable.
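The same holds for any store into an object rather than into a name. As a plain-Python sketch (the SimpleNamespace and the list are hypothetical stand-ins for a tensor), attribute updates, item assignment and method calls on globals all work without a global declaration, and the function ends up with no local names at all:

```python
from types import SimpleNamespace

# Stand-ins for global objects; no `global` statement is used below.
x = SimpleNamespace(data=0)   # plays the role of a tensor with a .data attribute
items = [0, 0]

def update():
    x.data += 1      # stores into an attribute of the object x refers to
    items[0] += 5    # stores into an element of the list items refers to
    items.append(7)  # a method call binds no name at all

update()
print(x.data, items)                # 1 [5, 0, 7]
print(update.__code__.co_varnames)  # ()
```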

On the other hand

class Foo():
    def __init__(self):
        self.data = 1
    def __iadd__(self, v):
        self.data += v
        return self

x = Foo()

def f():
    x += 1    # UnboundLocalError

f()
print(x.data)

will raise an UnboundLocalError because the Python compiler interprets x += 1 as an assignment to x, so x must refer to a local variable. Since no local x has been assigned before this line, you get an exception.

In order for the previous code to work, we need to explicitly declare x to be global within the function's scope.

class Foo():
    def __init__(self):
        self.data = 1
    def __iadd__(self, v):
        self.data += v
        return self

x = Foo()

def f():
    global x   # tell python that x refers to a global variable
    x += 1

f()
print(x.data)  # 2
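It is the rebinding of the name, not the in-place mutation, that triggers the error: x += 1 roughly means x = x.__iadd__(1). Performing only the __iadd__ call, for example through operator.iadd, mutates the global object without assigning to the name x, so no global declaration is needed. This is a sketch for illustration only, repeating the Foo class from above, not a recommended replacement for global:

```python
import operator

class Foo:
    def __init__(self):
        self.data = 1

    def __iadd__(self, v):
        self.data += v
        return self

x = Foo()

def f():
    # `x += 1` roughly performs two steps:
    #   1) call x.__iadd__(1), which mutates the object in place;
    #   2) bind the result back to the name x (the step that makes x local).
    # Doing only step 1 leaves the name x untouched, so it stays global.
    operator.iadd(x, 1)

f()
print(x.data)  # 2
```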
Answered By: jodag

You actually can. I suppose the reason lies in how PyTorch handles augmented assignment operations: for a tensor, += does not create a new object but modifies the object that was passed in. You just need to pass that object to your function as an argument. In my opinion, though, this approach contradicts Python's rules and shouldn't be used.

>>> import torch
>>> def fn(x):
...     x += 1
...
>>> a = 0
>>> fn(a)
>>> a
0
>>> a = torch.tensor([0.])
>>> a
tensor([0.])
>>> fn(a)
>>> a
tensor([1.])
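The difference in this transcript comes from the type's __iadd__ rather than from scoping: inside fn, x is an ordinary local parameter either way, and += on a mutable object mutates the same object the caller holds. A plain-Python sketch, using a list as a hypothetical stand-in for a tensor (both implement in-place +=); add_in_place is an illustrative name:

```python
def add_in_place(x, v):
    x += v   # rebinds the local parameter x; the caller's name is never touched
    return x

a = 0
add_in_place(a, 1)
print(a)  # 0: ints are immutable, so += made a new int bound only to the local x

b = [0]
before = id(b)
add_in_place(b, [1])
print(b, id(b) == before)  # [0, 1] True: list.__iadd__ extended the same object
```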
Answered By: Sergey Skrebnev