How to plot a multidimensional array in a 3D plot?

Question:

I have a multidimensional array (P x N x M) and I want to plot each N x M slice in a 3D plot so that the P images are stacked along the z-axis.

Do you know how to do this in Python?

Thanks in advance

Asked By: Krystal

Answers:

If you want the N x M arrays drawn as "heatmaps" stacked above each other along the z-axis, this is one way to do it:

import numpy as np
import matplotlib.pyplot as plt

# Generate some dummy arrays
P, N, M = 5, 10, 10
data = np.random.rand(P, N, M)

# Create a 3D figure
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Create meshgrid for x, y values
x, y = np.meshgrid(np.arange(M), np.arange(N))

# Plot each N x M array as a flat, colored surface at height z = p.
# plt.cm.viridis expects values in [0, 1], which np.random.rand provides.
for p in range(P):
    heatmap = data[p]
    ax.plot_surface(
        x, y, np.full_like(heatmap, p),
        facecolors=plt.cm.viridis(heatmap),
        rstride=1, cstride=1, antialiased=True, shade=False,
    )

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('P')
ax.set_title('Stacked Heatmaps')
plt.show()
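
Note that passing the raw array to plt.cm.viridis only gives sensible colors because np.random.rand returns values in [0, 1]. For data on another scale, one option is to normalize the whole array once before mapping it to colors. A minimal sketch, where the helper name slice_facecolors and the sample data are illustrative assumptions and not part of the answer above:

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize


def slice_facecolors(data, cmap_name="viridis"):
    """Map a (P, N, M) array to RGBA colors of shape (P, N, M, 4).

    A single Normalize instance spans the whole array, so equal values
    get the same color in every slice.
    """
    norm = Normalize(vmin=data.min(), vmax=data.max())
    cmap = plt.get_cmap(cmap_name)
    return cmap(norm(data))


# Hypothetical sample data that is NOT in the [0, 1] range
rng = np.random.default_rng(0)
data = rng.normal(loc=50.0, scale=20.0, size=(5, 10, 10))
colors = slice_facecolors(data)
```

In the loop above, facecolors=colors[p] would then take the place of facecolors=plt.cm.viridis(heatmap).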

Answered By: Geom

You can achieve this with Matplotlib's mplot3d toolkit, that is, a 3D axes created with projection='3d'. This code generates a 3D scatter plot in which each 2D slice of the P x N x M array is drawn at a different height along the z-axis (slice p sits at z = p). The color of each point represents the value at that position in the slice, and a colorbar indicates the data values.

import numpy as np
import matplotlib.pyplot as plt

# Example multidimensional array of shape (P, N, M)
# Replace this with your actual data
P, N, M = 5, 10, 10
data = np.random.rand(P, N, M)

# Create a 3D scatter plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Create a meshgrid for the x and y values; the grids have shape (N, M), matching data[p]
x, y = np.meshgrid(np.arange(M), np.arange(N))

for p in range(P):
    # Place slice p at height z = p; vmin/vmax keep colors consistent across slices
    z = np.full((N, M), p)
    ax.scatter(x, y, z, c=data[p].ravel(), cmap='viridis',
               vmin=data.min(), vmax=data.max())

# Set labels for each axis
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

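# Add a colorbar that spans the full data range
norm = plt.Normalize(data.min(), data.max())
sm = plt.cm.ScalarMappable(cmap='viridis', norm=norm)
sm.set_array([])
# Pass ax explicitly: recent Matplotlib versions cannot infer it for a standalone mappable
fig.colorbar(sm, ax=ax, label='Data Values')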

plt.show()
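
When N and M differ, the coordinate grids need the same (N, M) shape as each slice, otherwise values end up at transposed positions. With the default indexing of np.meshgrid, that means passing the column coordinates (length M) first and the row coordinates (length N) second. A small check of this, using a deliberately non-square array whose sizes are arbitrary assumptions:

```python
import numpy as np

P, N, M = 3, 4, 6  # deliberately non-square so a transposed grid would show up
data = np.arange(P * N * M, dtype=float).reshape(P, N, M)

# First argument varies along columns (x), second along rows (y)
x, y = np.meshgrid(np.arange(M), np.arange(N))

# Flattened point coordinates and values for one slice
p = 1
xs, ys, vals = x.ravel(), y.ravel(), data[p].ravel()
zs = np.full(xs.shape, p)  # every point of slice p sits at height z = p
```

Each point (xs[k], ys[k], zs[k]) then carries the value data[p, ys[k], xs[k]], which is what ax.scatter needs for its c argument.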

Answered By: Aditya Tiwari