ConnectionPatch for 3D subplots

Question:

I'm trying to draw a line connecting a point on one 3D subplot to a point on another 3D subplot. In 2D this is easy to do using ConnectionPatch. I've tried to mimic an existing Arrow3D class (the original link is not preserved here), without luck.

I'd be happy with even just a workaround at this point. For example, in the figure generated by the code below, I want to connect the two green dots.

def cylinder(r, n):
    '''
    Returns the coordinates of a surface of revolution whose radius profile is r.

    INPUTS:  r - a vector of radii (row or column), one ring per entry
             n - number of angular segments per ring; each ring gets n+1
                 points, the last repeating the first so the ring closes

    OUTPUTS: x,y,z - coordinate arrays; row i holds ring i
    '''
    # Work with the radii as a column so they broadcast against the angles.
    radii = np.atleast_2d(r)
    n_rows, n_cols = radii.shape
    if n_rows < n_cols:
        radii = radii.T

    theta = np.linspace(0, 2*np.pi, n+1)
    x = radii*np.cos(theta)
    y = radii*np.sin(theta)

    # Rings are spaced evenly in z between 0 and 1; every point on a ring
    # shares that ring's height.
    heights = np.atleast_2d(np.linspace(0, 1, len(radii))).T
    z = heights*np.ones((1, n+1))

    return x, y, z


#---------------------------------------
# 3D example
#---------------------------------------
fig = plt.figure()

# Top panel: cyan rings whose radius runs from 2 down to 1,
# plus a green marker at (2, 0, 0).
ax = fig.add_subplot(2,1,1, projection='3d')
x,y,z = cylinder(np.linspace(2,1,num=10), 40)
for ring_x, ring_y, ring_z in zip(x, y, z):
    ax.plot(ring_x, ring_y, ring_z, 'c')
ax.plot([2], [0], [0],'go')

# Bottom panel: red rings whose radius runs from 0 up to 1,
# plus a green marker at (1, 0, 1).
ax2 = fig.add_subplot(2,1,2, projection='3d')
x,y,z = cylinder(np.linspace(0,1,num=10), 40)
for ring_x, ring_y, ring_z in zip(x, y, z):
    ax2.plot(ring_x, ring_y, ring_z, 'r')
ax2.plot([1], [0], [1],'go')

plt.show()
Asked By: benten

||

Answers:

I was trying to solve a very similar problem just tonight! Some of the code may be unnecessary, but it should give you the main idea… I hope.

Inspiration from http://hackmap.blogspot.com.au/2008/06/pylab-matplotlib-imagemap.html
and many other varied sources over the last two hours…

#! /usr/bin/env python

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d import proj3d
import matplotlib

N = 50
x = np.random.rand(N)
y = np.random.rand(N)
z = np.random.rand(N)

# indices of the points to join
p1 = 10
p2 = 20

fig = plt.figure()

# a full-figure background axis (kept from the original layout; the
# connecting line below is a figure-level artist, not part of this axis)
ax0 = plt.axes([0.,0.,1.,1.])
ax0.set_xlim(0,1)
ax0.set_ylim(0,1)

# first scatter plot, with the first point of interest enlarged
ax1 = plt.axes([0.05,0.05,0.9,0.425], projection='3d')
ax1.scatter(x, y, z)
ax1.scatter(x[p1], y[p1], z[p1], s=100.)

# second scatter plot (same data), with the second point of interest enlarged
ax2 = plt.axes([0.05,0.475,0.9,0.425], projection='3d')
ax2.scatter(x, y, z)
ax2.scatter(x[p2], y[p2], z[p2], s=100.)

# hide the figure and 3D-axes backgrounds
for item in [fig, ax1, ax2]:
    item.patch.set_visible(False)

# A 3D axes settles its final on-screen box (and therefore what its
# transData maps to) only while it is drawn. Projecting before the first
# draw can leave the line end offset from the marker, so draw once first.
fig.canvas.draw()


def to_figure_coords(ax, xs, ys, zs):
    """Project a 3D data point of `ax` to figure-fraction (x, y) coordinates."""
    xp, yp, _ = proj3d.proj_transform(xs, ys, zs, ax.get_proj())
    display = ax.transData.transform((xp, yp))  # projected 2D -> display pixels
    # display pixels -> figure fraction (0..1), without hand-computing the pixel size
    return ax.figure.transFigure.inverted().transform(display)


coord1 = to_figure_coords(ax1, x[p1], y[p1], z[p1])
coord2 = to_figure_coords(ax2, x[p2], y[p2], z[p2])

# Attach the line to the figure itself (not to an Axes) via the public
# add_artist API, rather than overwriting fig.lines with a tuple.
line = matplotlib.lines.Line2D((coord1[0],coord2[0]),(coord1[1],coord2[1]),
                               transform=fig.transFigure)
fig.add_artist(line)

plt.show()

success

Answered By: minillinim

My final code, just to have a workable example:

#! /usr/bin/env python

import numpy as np
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.axes3d as p3
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d import proj3d
import matplotlib



def cylinder(r, n):
    '''
    Returns the coordinates of a surface of revolution traced by radii r.

    INPUTS:  r - a vector of radii (row or column), one ring per entry
             n - number of angular steps; each ring has n+1 points and
                 its last point coincides with its first

    OUTPUTS: x,y,z - coordinate arrays, one row per ring
    '''
    col = np.atleast_2d(r)
    rows, cols = col.shape
    # Flip a row vector into a column so each radius gets its own row.
    col = col.T if cols > rows else col

    angles = np.linspace(0, 2*np.pi, n+1)
    x, y = col*np.cos(angles), col*np.sin(angles)

    # Ring heights run evenly from 0 to 1, repeated across every angle.
    ring_z = np.linspace(0, 1, len(col)).reshape(-1, 1)
    z = np.tile(ring_z, (1, n+1))

    return x, y, z



#---------------------------------------
# 3D example
#---------------------------------------
fig = plt.figure()

# a full-figure background axis (kept from the original layout)
ax0 = plt.axes([0.,0.,1.,1.])
ax0.set_xlim(0,1)
ax0.set_ylim(0,1)


# top figure
ax1 = fig.add_subplot(2,1,1, projection='3d')
x,y,z = cylinder(np.linspace(2,1,num=10), 40)
for ring_x, ring_y, ring_z in zip(x, y, z):
    ax1.plot(ring_x, ring_y, ring_z, 'c')


# bottom figure
ax2 = fig.add_subplot(2,1,2, projection='3d')
x,y,z = cylinder(np.linspace(0,1,num=10), 40)
for ring_x, ring_y, ring_z in zip(x, y, z):
    ax2.plot(ring_x, ring_y, ring_z, 'r')


# points of interest, one per subplot, as ([x], [y], [z])
p1 = ([2],[0],[0])
ax1.plot(p1[0], p1[1], p1[2],'go')
p2 = ([1], [0], [1])
ax2.plot(p2[0], p2[1], p2[2],'go')

# A 3D axes settles its final on-screen box only while it is drawn, so
# draw once before projecting; otherwise the line ends can miss the dots.
fig.canvas.draw()


def to_figure_coords(ax, point):
    """Project the single 3D point ([x], [y], [z]) of `ax` to figure-fraction (x, y)."""
    xp, yp, _ = proj3d.proj_transform(point[0], point[1], point[2], ax.get_proj())
    display = ax.transData.transform((xp[0], yp[0]))  # projected 2D -> display pixels
    return ax.figure.transFigure.inverted().transform(display)  # pixels -> figure fraction


coord1 = to_figure_coords(ax1, p1)
coord2 = to_figure_coords(ax2, p2)

# plot line between subplots as ONE figure-level artist. The original both
# plotted it on ax0 and assigned it to fig.lines, so it belonged to two
# containers and could be rendered twice.
fig.add_artist(matplotlib.lines.Line2D((coord1[0],coord2[0]),(coord1[1],coord2[1]),
                                       transform=fig.transFigure, linestyle='dashed'))

plt.show()
Answered By: benten

To fix the slight offset between the line and the point it connects,
calling fig.canvas.draw() before projecting,
and saving with fig.savefig('…'),
may work.

In my environment (Pydroid),
the coordinates of the dot and the end of the line do not match when the figure is displayed with plt.show(), maybe because Pydroid's interactive backend automatically changes the figure size, which then moves the line.
So I used fig.savefig('…') instead of plt.show().

Inserting fig.canvas.draw()
(after ax.set_xlim(…,…) etc. and before proj3d.proj_transform) also works.

fig

My code is below.

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d import proj3d
import  matplotlib



fig = plt.figure(figsize=(10, 12), dpi=100)

# full-figure background axis; used below as the canvas for the two labels
ax0 = plt.axes([0.,0.,1.,1.])
ax0.set_xlim(0,1)
ax0.set_ylim(0,1)


def to_figure_coords(ax, point):
    """Project the 3D data point `point` = [x, y, z] of `ax` to figure-fraction (x, y).

    The result reflects ax's on-screen box as of the last draw. A 3D axes
    only settles that box while being drawn, so call this after
    fig.canvas.draw() to get an accurate result.
    """
    xp, yp, _ = proj3d.proj_transform(point[0], point[1], point[2], ax.get_proj())
    display = ax.transData.transform((xp, yp))  # projected 2D -> display pixels
    return ax.figure.transFigure.inverted().transform(display)  # -> figure fraction


def connecting_line(ax_a, point_a, ax_b, point_b):
    """Return a dashed figure-level Line2D from point_a (in ax_a) to point_b (in ax_b)."""
    c1 = to_figure_coords(ax_a, point_a)
    c2 = to_figure_coords(ax_b, point_b)
    return matplotlib.lines.Line2D((c1[0], c2[0]), (c1[1], c2[1]),
                                   transform=fig.transFigure, linestyle='dashed')


def set_common_limits(*axes):
    """Give every 3D axes in `axes` the same x/y/z limits."""
    for ax in axes:
        ax.set_xlim(-2,2)
        ax.set_ylim(-2,2)
        ax.set_zlim(0,1)


p1 = [-2, 0, 0.5]
p2 = [0, 2, 1]

# Top row: the line is computed BEFORE the figure has been drawn, which
# leaves it misaligned with the dots ("Not good.").
ax1 = fig.add_subplot(2,2,1, projection='3d')
ax2 = fig.add_subplot(2,2,2, projection='3d')
ax1.plot(p1[0], p1[1], p1[2],'go')
ax2.plot(p2[0], p2[1], p2[2],'go')
set_common_limits(ax1, ax2)
line1 = connecting_line(ax1, p1, ax2, p2)

# Bottom row: the same setup, but the figure is drawn first so the 3D axes
# have their final on-screen boxes before projecting ("Good!").
ax3 = fig.add_subplot(2,2,3, projection='3d')
ax4 = fig.add_subplot(2,2,4, projection='3d')
ax3.plot(p1[0], p1[1], p1[2],'go')
ax4.plot(p2[0], p2[1], p2[2],'go')
set_common_limits(ax3, ax4)

fig.canvas.draw()

line2 = connecting_line(ax3, p1, ax4, p2)

# Attach both lines through the public API instead of replacing fig.lines
# with a tuple.
fig.add_artist(line1)
fig.add_artist(line2)


ax0.text(0.2, 0.88, "Not good.", fontsize=30)
ax0.text(0.2, 0.44, "Good!", fontsize=30)

fig.savefig("fig.png",dpi=100)

Answered By: Yusuke Yoshihara
Categories: questions Tags: , , ,
Answers are sorted by their score. The answer accepted by the question owner as the best is marked with
at the top-right corner.