How to merge sub-matrices of a high-dimensional tensor while preserving the relative positions of the sub-matrices?

Question:

Suppose I have a tensor x with shape [z, d, d], representing a series of image frames, like video data. Let pz = int(z**0.5) and let x = x.view(pz, pz, d, d). Then we get a grid of images with grid size pz*pz, where each image has shape [d, d]. Now, I want to get a tensor with shape [1, 1, pz*d, pz*d], and every element MUST keep the same relative position it had within its original image.

For example:

    x =    [[[ 0,  1],
             [ 2,  3]],

            [[ 4,  5],
             [ 6,  7]],

            [[ 8,  9],
             [10, 11]],
    
            [[12, 13],
             [14, 15]]]

which represents a series of images with shape [2, 2] and z = 4.
I want to get a tensor like:

tensor([[ 0,  1,  4,  5],
        [ 2,  3,  6,  7],
        [ 8,  9, 12, 13],
        [10, 11, 14, 15]])

I can use x = x.view(1, 1, 4, 4) to get a tensor with the same shape, but it looks like this:

tensor([[[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11],
          [12, 13, 14, 15]]]])

which I don’t want.

Furthermore, what if x has more dimensions, such as [b, c, z, d, d]? How should that case be handled?

Any suggestions would be helpful.

I have a solution for the three-dimensional case. If x.shape = [z, d, d], the code below works, but it does not work for higher-dimensional tensors. Nested loops would work, but are too heavy.
My solution for the three-dimensional case:


    import torch

    d = 2
    z = 4
    b, c = 1, 1
    x = torch.arange(z*d*d).view(z, d, d)
    # x = torch.tensor([[[ 1,  2],
    #          [ 4,  6]],
    #
    #         [[ 8, 10],
    #          [12, 14]],
    #
    #         [[16, 18],
    #          [20, 22]],
    #
    #         [[24, 26],
    #          [28, 30]],
    #
    #         [[32, 34],
    #          [36, 38]],
    #
    #         [[40, 42],
    #          [44, 46]],
    #
    #         [[48, 50],
    #          [52, 54]],
    #
    #         [[56, 58],
    #          [60, 62]],
    #
    #         [[64, 66],
    #          [68, 70]]])
    # arrange the z frames into a pz x pz grid layout
    grid_side_len = int(z**0.5)
    grid_x = x.view(grid_side_len, grid_side_len, d, d)
    # for each row of crops, horizontally stack them together
    plane = []
    for i in range(grid_x.shape[0]):
        cat_crops = torch.hstack([crop for crop in grid_x[i]])
        plane.append(cat_crops)

    plane = torch.vstack([p for p in plane])
    print("3D crop to 2D crop plane:")
    print(x)
    print(plane)
    print(plane.shape)


    print("2D crop plane to 3D crop:")
    # split the plane back into rows of the grid
    split = torch.chunk(plane, plane.shape[0]//d, dim=0)
    spat_flatten = torch.cat([torch.cat(torch.chunk(p, p.shape[1]//d, dim=1), dim=0) for p in split], dim=0)
    crops = [t[None,:,:] for t in torch.chunk(spat_flatten, spat_flatten.shape[0]//d, dim=0)]
    spat_crops = torch.cat(crops, dim=0)
    print(spat_crops)
    print(spat_crops.shape)
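For reference, the loop solution above can be condensed into comprehensions and checked end-to-end. This is a minimal, self-contained sketch of the same logic (not a new algorithm), with an assert confirming that the inverse exactly recovers the original tensor:

```python
import torch

d, z = 2, 4
x = torch.arange(z * d * d).view(z, d, d)
pz = int(z ** 0.5)
grid_x = x.view(pz, pz, d, d)

# forward: stack each grid row horizontally, then stack the rows vertically
plane = torch.vstack([torch.hstack(list(row)) for row in grid_x])

# inverse: cut the plane into d-tall strips, then each strip into d-wide crops
rows = torch.chunk(plane, pz, dim=0)
crops = [crop for row in rows for crop in torch.chunk(row, pz, dim=1)]
spat_crops = torch.stack(crops)

assert torch.equal(spat_crops, x)  # round trip recovers the original frames
```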
Asked By: Skipper


Answers:

This operation can be solved with a combination of torch.transpose and torch.reshape. Starting from your example tensor:

>>> x = torch.arange(16).view(4,2,2)
  1. Start by transposing the tensor so that the dimension you want to collate on stands "vertically". This can be done with x.transpose(dim0=1, dim1=2), although I recommend working with negative dimensions instead:

    >>> x.transpose(-1,-2)
    tensor([[[ 0,  2],
             [ 1,  3]],
    
            [[ 4,  6],
             [ 5,  7]],
    
            [[ 8, 10],
             [ 9, 11]],
    
            [[12, 14],
             [13, 15]]])
    
  2. Then reshape to collate the dimension:

    >>> x.transpose(-1,-2).reshape(2,4,2)
    tensor([[[ 0,  2],
             [ 1,  3],
             [ 4,  6],
             [ 5,  7]],
    
            [[ 8, 10],
             [ 9, 11],
             [12, 14],
             [13, 15]]])
    
  3. Then transpose back to recover the element order from step 1:

    >>> x.transpose(-1,-2).reshape(2,4,2).transpose(-1,-2)
    tensor([[[ 0,  1,  4,  5],
             [ 2,  3,  6,  7]],
    
            [[ 8,  9, 12, 13],
             [10, 11, 14, 15]]])
    
  4. Finally, reshape to the desired form:

    >>> x.transpose(-1,-2).reshape(2,4,2).transpose(-1,-2).reshape(len(x),-1)
    tensor([[ 0,  1,  4,  5],
            [ 2,  3,  6,  7],
            [ 8,  9, 12, 13],
            [10, 11, 14, 15]])
    

From there you can adapt this to your needs by changing the dimension sizes, and even extend it to tensors with more dimensions, such as the [b, c, z, d, d] case you described. If you play around with this example until you understand the approach, you will be able to work out any similar problem.
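As one way to follow that suggestion (a sketch only; `tile_frames` is a helper name introduced here, not part of the answer above), the four steps can be folded into a single function that handles any number of leading batch dimensions by pairing each grid axis with its in-frame axis before flattening:

```python
import torch

def tile_frames(x: torch.Tensor) -> torch.Tensor:
    """Tile the last three dims [z, d, d] into a [pz*d, pz*d] plane.

    Accepts any leading batch dims, e.g. [z, d, d] or [b, c, z, d, d],
    assuming z is a perfect square (pz = sqrt(z)).
    """
    *batch, z, d, _ = x.shape
    pz = int(z ** 0.5)
    assert pz * pz == z, "z must be a perfect square"
    x = x.reshape(*batch, pz, pz, d, d)  # [..., grid_row, grid_col, h, w]
    x = x.movedim(-3, -2)                # [..., grid_row, h, grid_col, w]
    return x.reshape(*batch, pz * d, pz * d)

# matches the worked example above
print(tile_frames(torch.arange(16).view(4, 2, 2)))
```

The same trick works for the batched case: `tile_frames(x)` with `x` of shape [b, c, z, d, d] returns [b, c, pz*d, pz*d].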

Answered By: Ivan

Thanks for @Ivan’s answer.
Based on it, here is the exact solution to my question:

b, c, z, d = 1, 1, 4, 2
pz = int(z**0.5)
crop_img = x.transpose(-1, -2).reshape(b, c, pz, pz*d, d).transpose(-1, -2).reshape(b, c, pz*d, pz*d)

# and the inverse process
x = crop_img.reshape(b, c, pz, d, pz*d).transpose(-1, -2).reshape(b, c, z, d, d).transpose(-1, -2)
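For completeness, the pair of expressions above can be checked with a quick self-contained round trip (a sketch assuming x holds the frames in shape [b, c, z, d, d]):

```python
import torch

b, c, z, d = 1, 1, 4, 2
pz = int(z ** 0.5)
x = torch.arange(b * c * z * d * d).view(b, c, z, d, d)

# forward: [b, c, z, d, d] -> [b, c, pz*d, pz*d]
crop_img = (x.transpose(-1, -2)
             .reshape(b, c, pz, pz * d, d)
             .transpose(-1, -2)
             .reshape(b, c, pz * d, pz * d))

# inverse: [b, c, pz*d, pz*d] -> [b, c, z, d, d]
x_back = (crop_img.reshape(b, c, pz, d, pz * d)
                  .transpose(-1, -2)
                  .reshape(b, c, z, d, d)
                  .transpose(-1, -2))

assert torch.equal(x_back, x)  # the inverse exactly recovers the input
```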
Answered By: Skipper