Why does PyTorch use so much memory despite the tensor not being that big and not requiring gradients?

Question:

I’m having some unexpected out-of-memory issues when running a script locally that uses torch 1.13.1+cpu.

Below is a minimal example that reproduces the issue. Why does it take 10GB of RAM, and why is that memory not released until the script ends? The actual size of the tensor is quite small. What is taking that much memory, and how can I prevent it / garbage-collect it?

# memory_issue.py

import torch
from memory_profiler import profile
import sys

@profile
def get_random_tensor():
    t = torch.randn([3, 429, 1080, 1920], dtype=torch.float32, requires_grad=False)
    print(f"Allocated memory for tensor: {sys.getsizeof(t)} Bytes")
    return t

@profile
def permute_tensor(t):
    t = torch.permute(t, (3, 0, 1, 2))
    return t

@profile
def do_nothing(t):
    return t

if __name__ == '__main__':
    with torch.no_grad():
        my_tensor = get_random_tensor()
        my_tensor = permute_tensor(my_tensor)
        my_tensor = do_nothing(my_tensor)

When I run:

python memory_issue.py

This is the output:

Allocated memory for tensor: 72 Bytes

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
     5    197.4 MiB    197.4 MiB           1   @profile
     6                                         def get_random_tensor():
     7  10379.1 MiB  10181.8 MiB           1       t = torch.randn([3, 429, 1080, 1920], dtype=torch.float32, requires_grad=False)
     8  10379.1 MiB      0.0 MiB           1       print(f"Allocated memory for tensor: {sys.getsizeof(t)} Bytes")
     9  10379.1 MiB      0.0 MiB           1       return t


Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    11  10379.1 MiB  10379.1 MiB           1   @profile
    12                                         def permute_tensor(t):
    13  10380.0 MiB      0.8 MiB           1       t = torch.permute(t, (3, 0, 1, 2))
    14  10380.0 MiB      0.0 MiB           1       return t


Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    16  10380.0 MiB  10380.0 MiB           1   @profile
    17                                         def do_nothing(t):    
    18  10380.0 MiB      0.0 MiB           1       return t

Thanks!

Asked By: RR_28023


Answers:

Your description is not right: the tensor you created is in fact a really big tensor.

torch.randn([3, 429, 1080, 1920], dtype=torch.float32, requires_grad=False)

allocates 3 × 429 × 1080 × 1920 elements at 4 bytes each (a float32 occupies 4 bytes), i.e. 10,674,892,800 bytes ≈ 10GB of memory.
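As for the 72 bytes in your output: sys.getsizeof only measures the Python wrapper object, not the underlying storage it points to, so it will report a small, near-constant number for any tensor. A minimal sketch (using a scaled-down tensor so it is cheap to run) of how you could check the real buffer size instead:

# size_check.py -- a sketch; the tensor shape here is deliberately
# smaller than in the question so it runs in a few MB of RAM.
import sys
import torch

t = torch.randn([3, 4, 108, 192], dtype=torch.float32, requires_grad=False)

# Only the Python wrapper is measured here, not the data buffer,
# which is why the question's script printed 72 bytes.
print(sys.getsizeof(t))

# The actual buffer size is number-of-elements * bytes-per-element:
# 3 * 4 * 108 * 192 * 4 = 995,328 bytes.
print(t.nelement() * t.element_size())

# permute returns a view sharing the same storage, which is why
# permute_tensor() barely increased memory in the profile.
p = t.permute(3, 0, 1, 2)
print(p.data_ptr() == t.data_ptr())  # True

# On CPU, the buffer is returned once the last reference to the
# storage is gone, so dropping all references frees the 10GB.
del t, p

So to release the memory before the script ends, drop every reference to the tensor (e.g. del my_tensor); there is no need to call the garbage collector explicitly for a tensor that is not part of a reference cycle.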

Answered By: nnzzll