How to use a linear index to access a 2D array in Python

Question:

I have some MATLAB code that I'm trying to translate to Python. In MATLAB, I can write this:

x = [1,2,3;4,5,6;7,8,9];

which is just a 3×3 matrix. If I then use x(1:5), MATLAB first turns x into a 1×9 vector (reading it column by column) and then returns a 1×5 vector: ans = [1,4,7,2,5];

What simple Python code gives the same result?

Asked By: user7466160


Answers:

I'm not sure exactly what MATLAB's x(1:5) syntax does, but judging from your desired output, it seems to transpose the matrix, flatten it, and then return a slice. This is how to do that in Python:

>>> from itertools import chain
>>>
>>> x = [[1,2,3],
...      [4,5,6],
...      [7,8,9]]
>>>
>>> list(chain(*zip(*x)))[0:5]
[1, 4, 7, 2, 5]
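If you only need the first few elements, a lazier variant avoids building the whole transposed list. This is a minimal sketch, assuming a rectangular list of lists (zip stops at the shortest row, so ragged rows would silently lose elements). `matlab_style_slice` is a hypothetical helper name, and `start`/`stop` are 0-based, half-open Python bounds, so MATLAB's `x(1:5)` corresponds to `start=0, stop=5`.

```python
from itertools import chain, islice

def matlab_style_slice(x, start, stop):
    """Return elements start..stop-1 of x read in column-major order.

    zip(*x) yields the columns of x as tuples; chain.from_iterable walks
    them one after another, and islice stops once `stop` items are consumed.
    """
    columns = zip(*x)
    return list(islice(chain.from_iterable(columns), start, stop))

x = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]

print(matlab_style_slice(x, 0, 5))  # [1, 4, 7, 2, 5]
```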
Answered By: Dan

You can convert your matrix to a NumPy array and then use unravel_index to convert your linear indices into subscripts, which you can then use to index into your original matrix. Note that all commands below pass the order 'F' to use column-major ordering (the default in MATLAB) rather than row-major ordering (the default in NumPy).

import numpy as np

a = np.array([[1,2,3],[4,5,6],[7,8,9]])
inds = np.arange(5)  # 0-based equivalent of MATLAB's 1:5

result = a[np.unravel_index(inds, a.shape, order='F')]
#   array([1, 4, 7, 2, 5])
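Two details are worth spelling out when porting: MATLAB's linear indices start at 1, so they need to be shifted down by one, and `np.unravel_index` plays roughly the role of MATLAB's `ind2sub`, with `np.ravel_multi_index` going the other way (like `sub2ind`). A small sketch using the same matrix:

```python
import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# MATLAB's x(1:5) uses 1-based linear indices; subtract 1 for NumPy.
matlab_inds = np.arange(1, 6)   # 1, 2, 3, 4, 5
inds = matlab_inds - 1          # 0, 1, 2, 3, 4

# Linear indices -> (row, col) subscript arrays, column-major like MATLAB.
rows, cols = np.unravel_index(inds, a.shape, order="F")
print(rows, cols)     # [0 1 2 0 1] [0 0 0 1 1]
print(a[rows, cols])  # [1 4 7 2 5]

# (row, col) subscripts -> linear indices, again column-major.
back = np.ravel_multi_index((rows, cols), a.shape, order="F")
print(back)           # [0 1 2 3 4]
```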

Also, if you want to flatten a matrix like MATLAB you can do that as well:

a.flatten('F')
#   array([1, 4, 7, 2, 5, 8, 3, 6, 9])
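For the exact result in the question, you can also slice the column-major flattening directly. Note that `flatten` always returns a copy, whereas `ravel` returns a view when the requested order matches the array's memory layout; a column-major read of a default (row-major) array still has to copy. A short sketch checking this with `np.shares_memory`:

```python
import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# Equivalent of MATLAB's x(1:5): read column-major, then take the first five.
first_five = a.ravel(order="F")[:5]
print(first_five)  # [1 4 7 2 5]

# a is row-major (C order), so the default ravel is a view...
print(np.shares_memory(a, a.ravel()))           # True
# ...but reading it column-major requires a copy.
print(np.shares_memory(a, a.ravel(order="F")))  # False
# flatten() always copies.
print(np.shares_memory(a, a.flatten()))         # False
```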

If you are converting a lot of MATLAB code to Python, it's strongly recommended to use NumPy and to read its documentation on notable differences from MATLAB.

Answered By: Suever

An alternative that accesses the 2D array directly, without making a transformed copy, is to compute the row and column with integer division (//) and modulo (%).

import numpy as np

# example array
rect_arr = np.array([[1, 2, 3, 10], [4, 5, 6, 11], [7, 8, 9, 12]])
rows, cols = rect_arr.shape

print("Array is:\n", rect_arr)
print(f"rows = {rows}, cols = {cols}")

# Access by Linear Indexing
# Reference:
# https://upload.wikimedia.org/wikipedia/commons/4/4d/Row_and_column_major_order.svg

total_elems = rect_arr.size

# Row-major order
print("\nRow Major Sequence:")
for linear_index in range(total_elems):
    # do something with rect_arr[linear_index // cols, linear_index % cols]
    # Sequence will be 1, 2, 3, 10, 4, 5, 6, 11, 7, 8, 9, 12
    print(rect_arr[linear_index // cols, linear_index % cols])

# Column-major order
print("\nColumn Major Sequence:")
for linear_index in range(total_elems):
    # do something with rect_arr[linear_index % rows, linear_index // rows]
    # Sequence will be 1, 4, 7, 2, 5, 8, 3, 6, 9, 10, 11, 12
    print(rect_arr[linear_index % rows, linear_index // rows])


# With unravel_index
# Row major order
row_indices = np.arange(total_elems)
row_transformed_arr = rect_arr[np.unravel_index(row_indices, rect_arr.shape, "C")]
print(row_transformed_arr)

# Column-major order
col_indices = np.arange(total_elems)
col_transformed_arr = rect_arr[np.unravel_index(col_indices, rect_arr.shape, "F")]
print(col_transformed_arr)
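Python's built-in `divmod` returns the quotient and remainder together, so the two formulas above can be wrapped in one helper and checked against `np.unravel_index`. This is a minimal sketch; `linear_to_subscript` is a hypothetical name, and all indices are 0-based.

```python
import numpy as np

def linear_to_subscript(k, shape, order="C"):
    """Map a 0-based linear index k to (row, col) for a 2D shape.

    Row-major ("C"):    row = k // cols, col = k % cols
    Column-major ("F"): row = k % rows,  col = k // rows
    """
    rows, cols = shape
    if order == "C":
        row, col = divmod(k, cols)
    elif order == "F":
        col, row = divmod(k, rows)
    else:
        raise ValueError("order must be 'C' or 'F'")
    return row, col

rect_arr = np.array([[1, 2, 3, 10], [4, 5, 6, 11], [7, 8, 9, 12]])

# Check the hand-written formulas against np.unravel_index for every index.
for order in ("C", "F"):
    for k in range(rect_arr.size):
        expected = tuple(int(i) for i in np.unravel_index(k, rect_arr.shape, order=order))
        assert linear_to_subscript(k, rect_arr.shape, order) == expected

col_major = [int(rect_arr[linear_to_subscript(k, rect_arr.shape, "F")])
             for k in range(rect_arr.size)]
print(col_major)  # [1, 4, 7, 2, 5, 8, 3, 6, 9, 10, 11, 12]
```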

This is also useful for arranging plots in subplots:

import matplotlib.pyplot as plt

# <df> is a date-indexed DataFrame with 8 columns containing time-series data

# Order plots in row-major order
fig, axs = plt.subplots(nrows=4, ncols=2)
rows, cols = axs.shape
for i, colname in enumerate(df):
    df[colname].plot(ax=axs[i // cols, i % cols], title=colname)
plt.show()

# Order plots in column-major order (new figure, since the first one
# may already have been closed by plt.show())
fig, axs = plt.subplots(nrows=4, ncols=2)
for i, colname in enumerate(df):
    df[colname].plot(ax=axs[i % rows, i // rows], title=colname)
plt.show()
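When both nrows and ncols are greater than 1, plt.subplots returns axs as a 2D NumPy array of Axes, so the same ordering can be had without index arithmetic by iterating over `axs.ravel()` (row-major) or `axs.ravel(order="F")` (column-major). The sketch below uses a labelled object array as a stand-in for the 4×2 Axes grid so the visiting order is visible; the commented loop at the end assumes `df` as in the answer above.

```python
import numpy as np

# Stand-in for the 4x2 `axs` array from plt.subplots(nrows=4, ncols=2):
# each "axes" is just a label "r<row>c<col>" so the order is easy to see.
axs = np.array([[f"r{r}c{c}" for c in range(2)] for r in range(4)], dtype=object)

# Row-major: ravel() (default order="C") walks across each row first.
row_major = list(axs.ravel())
# Column-major: ravel(order="F") walks down each column first.
col_major = list(axs.ravel(order="F"))

print(row_major)  # ['r0c0', 'r0c1', 'r1c0', 'r1c1', 'r2c0', 'r2c1', 'r3c0', 'r3c1']
print(col_major)  # ['r0c0', 'r1c0', 'r2c0', 'r3c0', 'r0c1', 'r1c1', 'r2c1', 'r3c1']

# With real Axes, the column-major loop would become:
#     for ax, colname in zip(axs.ravel(order="F"), df):
#         df[colname].plot(ax=ax, title=colname)
```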
Answered By: Manoj Baishya