PyTorch Python Distributed Multiprocessing: Gather/Concatenate tensor arrays of different lengths/sizes
Question:
If you have tensor arrays of different lengths across several GPU ranks, the default all_gather
method does not work, as it requires the lengths to be the same.
For example, if you have:
if gpu == 0:
    q = torch.tensor([1.5, 2.3], device=torch.device(gpu))
else:
    q = torch.tensor([5.3], device=torch.device(gpu))
If I need to gather these two tensor arrays as follows:
all_q = [torch.tensor([1.5, 2.3]), torch.tensor([5.3])]
the default torch.distributed.all_gather does not work because the lengths, 2 and 1, are different.
Answers:
As it is not directly possible to gather using built-in methods, we need to write a custom function with the following steps:
- Use dist.all_gather to get the sizes of all arrays.
- Find the max size.
- Pad the local array to the max size using zeros/constants.
- Use dist.all_gather to get all padded arrays.
- Unpad the added zeros/constants using the sizes found in step 1.
The function below does this:
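Before involving any process group, the pad/unpad bookkeeping of these steps can be sketched locally. The two tensors below stand in for the per-rank values from the example above; in the real function, the size exchange and the gather of the padded arrays happen via collective calls instead of a local list:

```python
import torch

# Stand-ins for the per-rank tensors from the example; in the real function
# these live on different ranks and are exchanged via dist.all_gather.
tensors = [torch.tensor([1.5, 2.3]), torch.tensor([5.3])]

# Steps 1-2: collect the lengths and find the max.
sizes = [t.size(0) for t in tensors]
max_size = max(sizes)

# Step 3: pad each tensor with zeros up to the max length.
padded = [torch.cat((t, t.new_zeros(max_size - t.size(0)))) for t in tensors]

# Step 5: unpad using the sizes recorded in step 1; the originals come back.
recovered = [p[:n] for p, n in zip(padded, sizes)]
```

After padding, every tensor has the same length, which is exactly the precondition dist.all_gather needs; the recorded sizes let each rank strip the filler afterwards.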
import torch
import torch.distributed as dist

def all_gather(q, ws, device):
    """
    Gathers tensor arrays of different lengths across multiple gpus

    Parameters
    ----------
    q : tensor array
    ws : world size
    device : current gpu device

    Returns
    -------
    all_q : list of gathered tensor arrays from all the gpus
    """
    local_size = torch.tensor(q.size(), device=device)
    all_sizes = [torch.zeros_like(local_size) for _ in range(ws)]
    dist.all_gather(all_sizes, local_size)
    max_size = max(all_sizes)

    size_diff = max_size.item() - local_size.item()
    if size_diff:
        padding = torch.zeros(size_diff, device=device, dtype=q.dtype)
        q = torch.cat((q, padding))

    all_qs_padded = [torch.zeros_like(q) for _ in range(ws)]
    dist.all_gather(all_qs_padded, q)
    all_qs = []
    for q, size in zip(all_qs_padded, all_sizes):
        all_qs.append(q[:size])
    return all_qs
Once we are able to do the above, we can then easily use torch.cat to further concatenate into a single array if needed:
torch.cat(all_q)
tensor([1.5000, 2.3000, 5.3000])
Adapted from: github
Here is an extension of @omsrisagar’s solution that supports tensors of any number of dimensions (not only 1-dimensional tensors).
import torch
import torch.distributed as dist

def all_gather_nd(tensor):
    """
    Gathers tensor arrays of different lengths in a list.
    The length dimension is 0. This supports any number of extra dimensions in the tensors.
    All the other dimensions should be equal between the tensors.

    Args:
        tensor (Tensor): Tensor to be broadcast from current process.

    Returns:
        (list[Tensor]): output list of tensors that can be of different sizes
    """
    world_size = dist.get_world_size()
    local_size = torch.tensor(tensor.size(), device=tensor.device)
    all_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(all_sizes, local_size)

    max_length = max(size[0] for size in all_sizes)

    length_diff = max_length.item() - local_size[0].item()
    if length_diff:
        pad_size = (length_diff, *tensor.size()[1:])
        padding = torch.zeros(pad_size, device=tensor.device, dtype=tensor.dtype)
        tensor = torch.cat((tensor, padding))

    all_tensors_padded = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(all_tensors_padded, tensor)
    all_tensors = []
    for tensor_, size in zip(all_tensors_padded, all_sizes):
        all_tensors.append(tensor_[:size[0]])
    return all_tensors
Note that this requires that all the tensors have the same number of dimensions and have all their dimensions equal, except for the first dimension.
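The only part that changes relative to the 1-D version is the shape of the zero padding, which keeps all trailing dimensions of the local tensor and extends only dim 0. Illustrated locally with hypothetical shapes (no process group involved):

```python
import torch

tensor = torch.randn(2, 4, 3)   # local tensor: length 2 along dim 0
max_length = 5                  # pretend another rank reported length 5

# Pad only along dim 0; the trailing dims (4, 3) are copied from the tensor.
length_diff = max_length - tensor.size(0)
pad_size = (length_diff, *tensor.size()[1:])         # -> (3, 4, 3)
padding = torch.zeros(pad_size, dtype=tensor.dtype)
padded = torch.cat((tensor, padding))                # shape (5, 4, 3)
```

This is also why all trailing dimensions must match across ranks: torch.cat along dim 0 and the subsequent all_gather both require identical shapes in every other dimension.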