PyTorch Python Distributed Multiprocessing: Gather/Concatenate tensor arrays of different lengths/sizes

Question:

If you have tensor arrays of different lengths across several GPU ranks, the default all_gather method does not work, because it requires all the lengths to be the same.

For example, if you have:

if gpu == 0:
    q = torch.tensor([1.5, 2.3], device=torch.device(gpu))
else:
    q = torch.tensor([5.3], device=torch.device(gpu))

If I need to gather these two tensor arrays as follows:

all_q = [torch.tensor([1.5, 2.3]), torch.tensor([5.3])]

then the default torch.distributed.all_gather does not work, because the lengths (2 and 1) are different.

Asked By: omsrisagar


Answers:

Since dist.all_gather only works with tensors of the same size, we need to write a custom function that performs the following steps:

  1. Use dist.all_gather to get sizes of all arrays.
  2. Find the max size.
  3. Pad local array to max size using zeros/constants.
  4. Use dist.all_gather to get all padded arrays.
  5. Unpad the added zeros/constants using sizes found in step 1.
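To make the five steps concrete without launching several processes, here is a minimal single-process sketch. It assumes each rank's local 1-D tensor is simply one element of a Python list, uses torch.stack in place of dist.all_gather, and uses a hypothetical helper name, simulate_padded_gather. It only illustrates the size bookkeeping; it is not a distributed implementation.

```python
import torch


def simulate_padded_gather(per_rank):
    """Hypothetical single-process simulation of steps 1-5 (no torch.distributed).

    Each element of `per_rank` stands in for the 1-D tensor one rank holds.
    """
    # Step 1: every rank learns every length (all_gather of the sizes).
    sizes = [t.size(0) for t in per_rank]
    # Step 2: the longest array decides the common padded length.
    max_size = max(sizes)
    # Step 3: each rank pads its local tensor with zeros up to max_size.
    padded = [
        torch.cat((t, torch.zeros(max_size - t.size(0), dtype=t.dtype)))
        for t in per_rank
    ]
    # Step 4: equally sized tensors can now be gathered; stacking stands in for it.
    gathered = torch.stack(padded)
    # Step 5: strip the padding again using the sizes from step 1.
    return [row[:n] for row, n in zip(gathered, sizes)]


result = simulate_padded_gather([torch.tensor([1.5, 2.3]), torch.tensor([5.3])])
print(result)
```

In a real job, steps 1 and 4 are the two dist.all_gather calls; everything else runs locally on each rank.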

The function below does this:

import torch
import torch.distributed as dist


def all_gather(q, ws, device):
    """
    Gathers 1-D tensor arrays of different lengths across multiple GPUs

    Parameters
    ----------
        q : tensor array
        ws : world size
        device : current GPU device

    Returns
    -------
        all_q : list of gathered tensor arrays from all the GPUs

    """
    # Step 1: share every rank's length.
    local_size = torch.tensor(q.size(), device=device)
    all_sizes = [torch.zeros_like(local_size) for _ in range(ws)]
    dist.all_gather(all_sizes, local_size)
    # Step 2: the longest array decides the common padded length.
    max_size = max(size.item() for size in all_sizes)

    # Step 3: pad the local array with zeros up to max_size.
    size_diff = max_size - local_size.item()
    if size_diff:
        padding = torch.zeros(size_diff, device=device, dtype=q.dtype)
        q = torch.cat((q, padding))

    # Step 4: gather the equally sized, padded arrays.
    all_qs_padded = [torch.zeros_like(q) for _ in range(ws)]
    dist.all_gather(all_qs_padded, q)
    # Step 5: drop each rank's padding using the sizes from step 1.
    all_qs = []
    for q_padded, size in zip(all_qs_padded, all_sizes):
        all_qs.append(q_padded[:size.item()])
    return all_qs

Once we have this list, we can use torch.cat to concatenate it into a single array if needed:

torch.cat(all_q)
# same values as torch.tensor([1.5, 2.3, 5.3])
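If the rank boundaries are still needed after concatenating, the per-rank lengths are enough to recover them. A minimal single-process sketch, assuming the gathered list is already available (it is hard-coded here as a stand-in) and using torch.split, which is not part of the original answer:

```python
import torch

# Stand-in for the list that all_gather(...) returns on every rank.
all_q = [torch.tensor([1.5, 2.3]), torch.tensor([5.3])]

# One flat tensor holding every rank's values, in rank order.
flat = torch.cat(all_q)

# The per-rank lengths are enough to cut the flat tensor back into pieces.
lengths = [t.size(0) for t in all_q]
pieces = torch.split(flat, lengths)

print(flat)
print(lengths)  # [2, 1]
```

torch.split returns views into flat, so recovering the pieces does not copy data.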

Adapted from: GitHub

Answered By: omsrisagar

Here is an extension of @omsrisagar’s solution that supports tensors of any number of dimensions (not only 1-dimensional tensors).

import torch
import torch.distributed as dist


def all_gather_nd(tensor):
    """
    Gathers tensor arrays of different lengths in a list.
    The length dimension is 0. This supports any number of extra dimensions in the tensors.
    All the other dimensions should be equal between the tensors.

    Args:
        tensor (Tensor): Tensor to be gathered from the current process.

    Returns:
        (list[Tensor]): output list of tensors that can be of different sizes
    """
    world_size = dist.get_world_size()
    # Full shape of the local tensor: [length, d1, d2, ...].
    local_size = torch.tensor(tensor.size(), device=tensor.device)
    all_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(all_sizes, local_size)

    max_length = max(size[0].item() for size in all_sizes)

    # Pad only along dim 0; the trailing dimensions stay unchanged.
    length_diff = max_length - local_size[0].item()
    if length_diff:
        pad_size = (length_diff, *tensor.size()[1:])
        padding = torch.zeros(pad_size, device=tensor.device, dtype=tensor.dtype)
        tensor = torch.cat((tensor, padding))

    all_tensors_padded = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(all_tensors_padded, tensor)
    # Trim each rank's padding using its original length.
    all_tensors = []
    for tensor_, size in zip(all_tensors_padded, all_sizes):
        all_tensors.append(tensor_[:size[0].item()])
    return all_tensors

Note that this requires that all the tensors have the same number of dimensions and have all their dimensions equal, except for the first dimension.
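The padding along dimension 0 and the shape requirement in the note can be checked on a single process. The sketch below is a hypothetical simulation (simulate_gather_nd is not a PyTorch function): each list element stands in for one rank's local tensor, torch.stack stands in for dist.all_gather, and it swaps in torch.nn.functional.pad for the torch.cat-with-zeros padding used in the answer.

```python
import torch
import torch.nn.functional as F


def simulate_gather_nd(per_rank):
    """Hypothetical single-process simulation of all_gather_nd (no torch.distributed)."""
    trailing = per_rank[0].shape[1:]
    for t in per_rank:
        # The requirement from the note: only dimension 0 may differ.
        if t.shape[1:] != trailing:
            raise ValueError(f"trailing dims {tuple(t.shape[1:])} != {tuple(trailing)}")
    lengths = [t.size(0) for t in per_rank]
    max_length = max(lengths)
    padded = []
    for t in per_rank:
        # F.pad takes (before, after) pairs starting from the LAST dimension,
        # so the padding for dimension 0 is the final pair.
        pad = [0, 0] * (t.dim() - 1) + [0, max_length - t.size(0)]
        padded.append(F.pad(t, pad))  # constant mode, zero-filled by default
    gathered = torch.stack(padded)  # stands in for dist.all_gather
    return [row[:n] for row, n in zip(gathered, lengths)]


out = simulate_gather_nd([torch.ones(2, 3), torch.zeros(1, 3)])
print([tuple(t.shape) for t in out])  # [(2, 3), (1, 3)]
```

Passing tensors whose trailing shapes differ, such as (2, 3) and (2, 4), raises ValueError here; in the real distributed function such a mismatch would instead surface as a failure inside dist.all_gather.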

Answered By: jb0u