Plot many PartialDependencePlot lines in one plot for multiclass classification

Question:

This is kind of a broad question, but I need to plot many partial dependence (PDP) lines in the same plot: one line for each target of a multiclass classification, and one such plot for each variable in the dataset. So for the variable age I’d have one plot with many PDP lines, one per target (I have 10), and so on for the rest of the variables.

There seems to be no way to do this with sklearn.inspection.PartialDependenceDisplay. I’ve tried working with sklearn.inspection.partial_dependence directly and got as far as the output below, but I don’t really know where to go from here:

from sklearn.inspection import partial_dependence

pd_results = partial_dependence(xgb_clf, X_test, features=['age', 'score1', 'score2'], kind="average", grid_resolution=5)
pd_results

{'average': array([[[[0.811337  , 0.811337  , 0.811337  , 0.811337  , 0.811337  ],
          [0.811337  , 0.811337  , 0.811337  , 0.811337  , 0.811337  ],
          [0.811337  , 0.811337  , 0.811337  , 0.811337  , 0.811337  ],
          [0.811337  , 0.811337  , 0.811337  , 0.811337  , 0.811337  ],
          [0.811337  , 0.811337  , 0.811337  , 0.811337  , 0.811337  ]],


     [[0.811337  , 0.811337  , 0.811337  , 0.811337  , 0.811337  ],
      [0.811337  , 0.811337  , 0.811337  , 0.811337  , 0.811337  ],
      [0.811337  , 0.811337  , 0.811337  , 0.811337  , 0.811337  ],
      [0.811337  , 0.811337  , 0.811337  , 0.811337  , 0.811337  ],
      [0.811337  , 0.811337  , 0.811337  , 0.811337  , 0.811337  ]],

     [[0.8237547 , 0.8237547 , 0.8237547 , 0.8237547 , 0.8237547 ],
      [0.8237547 , 0.8237547 , 0.8237547 , 0.8237547 , 0.8237547 ],
      [0.8237547 , 0.8237547 , 0.8237547 , 0.8237547 , 0.8237547 ],
      [0.8237547 , 0.8237547 , 0.8237547 , 0.8237547 , 0.8237547 ],
      [0.8237547 , 0.8237547 , 0.8237547 , 0.8237547 , 0.8237547 ]],

     [[0.82299083, 0.82299083, 0.82299083, 0.82299083, 0.82299083],
      [0.82299083, 0.82299083, 0.82299083, 0.82299083, 0.82299083],
      [0.82299083, 0.82299083, 0.82299083, 0.82299083, 0.82299083],
      [0.82299083, 0.82299083, 0.82299083, 0.82299083, 0.82299083],
      [0.82299083, 0.82299083, 0.82299083, 0.82299083, 0.82299083]],

     [[0.82412416, 0.82412416, 0.82412416, 0.82412416, 0.82412416],
      [0.82412416, 0.82412416, 0.82412416, 0.82412416, 0.82412416],
      [0.82412416, 0.82412416, 0.82412416, 0.82412416, 0.82412416],
      [0.82412416, 0.82412416, 0.82412416, 0.82412416, 0.82412416],
      [0.82412416, 0.82412416, 0.82412416, 0.82412416, 0.82412416]]],


    [[[0.01702061, 0.01702061, 0.01702061, 0.01702061, 0.01702061],
      [0.01702061, 0.01702061, 0.01702061, 0.01702061, 0.01702061],
      [0.01702061, 0.01702061, 0.01702061, 0.01702061, 0.01702061],
      [0.01702061, 0.01702061, 0.01702061, 0.01702061, 0.01702061],
      [0.01702061, 0.01702061, 0.01702061, 0.01702061, 0.01702061]],

     [[0.01702061, 0.01702061, 0.01702061, 0.01702061, 0.01702061],
      [0.01702061, 0.01702061, 0.01702061, 0.01702061, 0.01702061],
      [0.01702061, 0.01702061, 0.01702061, 0.01702061, 0.01702061],
      [0.01702061, 0.01702061, 0.01702061, 0.01702061, 0.01702061],
      [0.01702061, 0.01702061, 0.01702061, 0.01702061, 0.01702061]],

     [[0.01730013, 0.01730013, 0.01730013, 0.01730013, 0.01730013],
      [0.01730013, 0.01730013, 0.01730013, 0.01730013, 0.01730013],
      [0.01730013, 0.01730013, 0.01730013, 0.01730013, 0.01730013],
      [0.01730013, 0.01730013, 0.01730013, 0.01730013, 0.01730013],
      [0.01730013, 0.01730013, 0.01730013, 0.01730013, 0.01730013]],

     [[0.01728426, 0.01728426, 0.01728426, 0.01728426, 0.01728426],
      [0.01728426, 0.01728426, 0.01728426, 0.01728426, 0.01728426],
      [0.01728426, 0.01728426, 0.01728426, 0.01728426, 0.01728426],
      [0.01728426, 0.01728426, 0.01728426, 0.01728426, 0.01728426],
      [0.01728426, 0.01728426, 0.01728426, 0.01728426, 0.01728426]],

     [[0.01731277, 0.01731277, 0.01731277, 0.01731277, 0.01731277],
      [0.01731277, 0.01731277, 0.01731277, 0.01731277, 0.01731277],
      [0.01731277, 0.01731277, 0.01731277, 0.01731277, 0.01731277],
      [0.01731277, 0.01731277, 0.01731277, 0.01731277, 0.01731277],
      [0.01731277, 0.01731277, 0.01731277, 0.01731277, 0.01731277]]],


    [[[0.00188252, 0.00188252, 0.00188252, 0.00188252, 0.00188252],
      [0.00188252, 0.00188252, 0.00188252, 0.00188252, 0.00188252],
      [0.00188252, 0.00188252, 0.00188252, 0.00188252, 0.00188252],
      [0.00188252, 0.00188252, 0.00188252, 0.00188252, 0.00188252],
      [0.00188252, 0.00188252, 0.00188252, 0.00188252, 0.00188252]],

     [[0.00188252, 0.00188252, 0.00188252, 0.00188252, 0.00188252],
      [0.00188252, 0.00188252, 0.00188252, 0.00188252, 0.00188252],
      [0.00188252, 0.00188252, 0.00188252, 0.00188252, 0.00188252],
      [0.00188252, 0.00188252, 0.00188252, 0.00188252, 0.00188252],
      [0.00188252, 0.00188252, 0.00188252, 0.00188252, 0.00188252]],

     [[0.00202412, 0.00202412, 0.00202412, 0.00202412, 0.00202412],
      [0.00202412, 0.00202412, 0.00202412, 0.00202412, 0.00202412],
      [0.00202412, 0.00202412, 0.00202412, 0.00202412, 0.00202412],
      [0.00202412, 0.00202412, 0.00202412, 0.00202412, 0.00202412],
      [0.00202412, 0.00202412, 0.00202412, 0.00202412, 0.00202412]],

     [[0.00294247, 0.00294247, 0.00294247, 0.00294247, 0.00294247],
      [0.00294247, 0.00294247, 0.00294247, 0.00294247, 0.00294247],
      [0.00294247, 0.00294247, 0.00294247, 0.00294247, 0.00294247],
      [0.00294247, 0.00294247, 0.00294247, 0.00294247, 0.00294247],
      [0.00294247, 0.00294247, 0.00294247, 0.00294247, 0.00294247]],

     [[0.00294639, 0.00294639, 0.00294639, 0.00294639, 0.00294639],
      [0.00294639, 0.00294639, 0.00294639, 0.00294639, 0.00294639],
      [0.00294639, 0.00294639, 0.00294639, 0.00294639, 0.00294639],
      [0.00294639, 0.00294639, 0.00294639, 0.00294639, 0.00294639],
      [0.00294639, 0.00294639, 0.00294639, 0.00294639, 0.00294639]]],


    ...,


    [[[0.08890533, 0.08890533, 0.08890533, 0.08890533, 0.08890533],
      [0.08890533, 0.08890533, 0.08890533, 0.08890533, 0.08890533],
      [0.08890533, 0.08890533, 0.08890533, 0.08890533, 0.08890533],
      [0.08890533, 0.08890533, 0.08890533, 0.08890533, 0.08890533],
      [0.08890533, 0.08890533, 0.08890533, 0.08890533, 0.08890533]],

     [[0.08890533, 0.08890533, 0.08890533, 0.08890533, 0.08890533],
      [0.08890533, 0.08890533, 0.08890533, 0.08890533, 0.08890533],
      [0.08890533, 0.08890533, 0.08890533, 0.08890533, 0.08890533],
      [0.08890533, 0.08890533, 0.08890533, 0.08890533, 0.08890533],
      [0.08890533, 0.08890533, 0.08890533, 0.08890533, 0.08890533]],

     [[0.07579581, 0.07579581, 0.07579581, 0.07579581, 0.07579581],
      [0.07579581, 0.07579581, 0.07579581, 0.07579581, 0.07579581],
      [0.07579581, 0.07579581, 0.07579581, 0.07579581, 0.07579581],
      [0.07579581, 0.07579581, 0.07579581, 0.07579581, 0.07579581],
      [0.07579581, 0.07579581, 0.07579581, 0.07579581, 0.07579581]],

     [[0.0757297 , 0.0757297 , 0.0757297 , 0.0757297 , 0.0757297 ],
      [0.0757297 , 0.0757297 , 0.0757297 , 0.0757297 , 0.0757297 ],
      [0.0757297 , 0.0757297 , 0.0757297 , 0.0757297 , 0.0757297 ],
      [0.0757297 , 0.0757297 , 0.0757297 , 0.0757297 , 0.0757297 ],
      [0.0757297 , 0.0757297 , 0.0757297 , 0.0757297 , 0.0757297 ]],

     [[0.07584671, 0.07584671, 0.07584671, 0.07584671, 0.07584671],
      [0.07584671, 0.07584671, 0.07584671, 0.07584671, 0.07584671],
      [0.07584671, 0.07584671, 0.07584671, 0.07584671, 0.07584671],
      [0.07584671, 0.07584671, 0.07584671, 0.07584671, 0.07584671],
      [0.07584671, 0.07584671, 0.07584671, 0.07584671, 0.07584671]]],


    [[[0.00334371, 0.00334371, 0.00334371, 0.00334371, 0.00334371],
      [0.00334371, 0.00334371, 0.00334371, 0.00334371, 0.00334371],
      [0.00334371, 0.00334371, 0.00334371, 0.00334371, 0.00334371],
      [0.00334371, 0.00334371, 0.00334371, 0.00334371, 0.00334371],
      [0.00334371, 0.00334371, 0.00334371, 0.00334371, 0.00334371]],

     [[0.00334371, 0.00334371, 0.00334371, 0.00334371, 0.00334371],
      [0.00334371, 0.00334371, 0.00334371, 0.00334371, 0.00334371],
      [0.00334371, 0.00334371, 0.00334371, 0.00334371, 0.00334371],
      [0.00334371, 0.00334371, 0.00334371, 0.00334371, 0.00334371],
      [0.00334371, 0.00334371, 0.00334371, 0.00334371, 0.00334371]],

     [[0.00339652, 0.00339652, 0.00339652, 0.00339652, 0.00339652],
      [0.00339652, 0.00339652, 0.00339652, 0.00339652, 0.00339652],
      [0.00339652, 0.00339652, 0.00339652, 0.00339652, 0.00339652],
      [0.00339652, 0.00339652, 0.00339652, 0.00339652, 0.00339652],
      [0.00339652, 0.00339652, 0.00339652, 0.00339652, 0.00339652]],

     [[0.0033935 , 0.0033935 , 0.0033935 , 0.0033935 , 0.0033935 ],
      [0.0033935 , 0.0033935 , 0.0033935 , 0.0033935 , 0.0033935 ],
      [0.0033935 , 0.0033935 , 0.0033935 , 0.0033935 , 0.0033935 ],
      [0.0033935 , 0.0033935 , 0.0033935 , 0.0033935 , 0.0033935 ],
      [0.0033935 , 0.0033935 , 0.0033935 , 0.0033935 , 0.0033935 ]],

     [[0.00339899, 0.00339899, 0.00339899, 0.00339899, 0.00339899],
      [0.00339899, 0.00339899, 0.00339899, 0.00339899, 0.00339899],
      [0.00339899, 0.00339899, 0.00339899, 0.00339899, 0.00339899],
      [0.00339899, 0.00339899, 0.00339899, 0.00339899, 0.00339899],
      [0.00339899, 0.00339899, 0.00339899, 0.00339899, 0.00339899]]],


    [[[0.00560438, 0.00560438, 0.00560438, 0.00560438, 0.00560438],
      [0.00560438, 0.00560438, 0.00560438, 0.00560438, 0.00560438],
      [0.00560438, 0.00560438, 0.00560438, 0.00560438, 0.00560438],
      [0.00560438, 0.00560438, 0.00560438, 0.00560438, 0.00560438],
      [0.00560438, 0.00560438, 0.00560438, 0.00560438, 0.00560438]],

     [[0.00560438, 0.00560438, 0.00560438, 0.00560438, 0.00560438],
      [0.00560438, 0.00560438, 0.00560438, 0.00560438, 0.00560438],
      [0.00560438, 0.00560438, 0.00560438, 0.00560438, 0.00560438],
      [0.00560438, 0.00560438, 0.00560438, 0.00560438, 0.00560438],
      [0.00560438, 0.00560438, 0.00560438, 0.00560438, 0.00560438]],

     [[0.00569604, 0.00569604, 0.00569604, 0.00569604, 0.00569604],
      [0.00569604, 0.00569604, 0.00569604, 0.00569604, 0.00569604],
      [0.00569604, 0.00569604, 0.00569604, 0.00569604, 0.00569604],
      [0.00569604, 0.00569604, 0.00569604, 0.00569604, 0.00569604],
      [0.00569604, 0.00569604, 0.00569604, 0.00569604, 0.00569604]],

     [[0.00569026, 0.00569026, 0.00569026, 0.00569026, 0.00569026],
      [0.00569026, 0.00569026, 0.00569026, 0.00569026, 0.00569026],
      [0.00569026, 0.00569026, 0.00569026, 0.00569026, 0.00569026],
      [0.00569026, 0.00569026, 0.00569026, 0.00569026, 0.00569026],
      [0.00569026, 0.00569026, 0.00569026, 0.00569026, 0.00569026]],

     [[0.0056994 , 0.0056994 , 0.0056994 , 0.0056994 , 0.0056994 ],
      [0.0056994 , 0.0056994 , 0.0056994 , 0.0056994 , 0.0056994 ],
      [0.0056994 , 0.0056994 , 0.0056994 , 0.0056994 , 0.0056994 ],
      [0.0056994 , 0.0056994 , 0.0056994 , 0.0056994 , 0.0056994 ],
      [0.0056994 , 0.0056994 , 0.0056994 , 0.0056994 , 0.0056994 ]]]],
   dtype=float32),


'values': [array([21.  , 30.25, 39.5 , 48.75, 58.  ]),
  array([403.91 , 434.205, 464.5  , 494.795, 525.09 ]),
  array([nan, nan, nan, nan, nan])]}

I’m not very hopeful, but if anyone has done anything similar, I’d appreciate the help.

Asked By: amestrian

||

Answers:

Since no one proposed an answer and I found it, I’ll explain to close the thread.

First of all, it’s easier to call partial_dependence for each variable individually, like so:

pd_results = partial_dependence(xgb_clf, X_test, features=['age'], kind="average", grid_resolution=10)
pd_results

{'average': array([[8.6179382e-01, 8.6179382e-01, 8.6179382e-01, 8.7144512e-01,
         8.7144512e-01, 8.7203729e-01, 8.7203729e-01, 8.7216240e-01,
         8.7216240e-01, 8.7216240e-01],
        [1.1322678e-02, 1.1322678e-02, 1.1322678e-02, 1.1472423e-02,
         1.1472423e-02, 1.1484091e-02, 1.1484091e-02, 1.1490046e-02,
         1.1490046e-02, 1.1490046e-02],
        [1.5599880e-03, 1.5599880e-03, 1.5599880e-03, 1.6714102e-03,
         1.6714102e-03, 1.6726020e-03, 1.6726020e-03, 2.4344050e-03,
         2.4344050e-03, 2.4344050e-03],
        [1.2759878e-03, 1.2759878e-03, 1.2759878e-03, 1.2905410e-03,
         1.2905410e-03, 1.2914637e-03, 1.2914637e-03, 1.2915849e-03,
         1.2915849e-03, 1.2915849e-03],
        [4.4725675e-02, 4.4725675e-02, 4.4725675e-02, 4.5341633e-02,
         4.5341633e-02, 4.4625707e-02, 4.4625707e-02, 4.3705072e-02,
         4.3705072e-02, 4.3705072e-02],
        [8.4771524e-04, 8.4771524e-04, 8.4771524e-04, 8.5622317e-04,
         8.5622317e-04, 9.0603990e-04, 9.0603990e-04, 9.0637564e-04,
         9.0637564e-04, 9.0637564e-04],
        [2.1257352e-03, 2.1257352e-03, 2.1257352e-03, 2.1444152e-03,
         2.1444152e-03, 2.1457684e-03, 2.1457684e-03, 2.1463418e-03,
         2.1463418e-03, 2.1463418e-03],
        [7.0548818e-02, 7.0548818e-02, 7.0548818e-02, 5.9901968e-02,
         5.9901968e-02, 5.9954390e-02, 5.9954390e-02, 5.9978280e-02,
         5.9978280e-02, 5.9978280e-02],
        [2.8337168e-03, 2.8337168e-03, 2.8337168e-03, 2.8685690e-03,
         2.8685690e-03, 2.8712715e-03, 2.8712715e-03, 2.8726130e-03,
         2.8726130e-03, 2.8726130e-03],
        [2.9649595e-03, 2.9649595e-03, 2.9649595e-03, 3.0075535e-03,
         3.0075535e-03, 3.0111559e-03, 3.0111559e-03, 3.0130986e-03,
         3.0130986e-03, 3.0130986e-03]], dtype=float32),
 'values': [array([21.        , 25.11111111, 29.22222222, 33.33333333, 37.44444444,
         41.55555556, 45.66666667, 49.77777778, 53.88888889, 58.        ])]}

Here 'values' holds the x-axis grid points, and 'average' is a (10, 10) NumPy array with one row per target and one column per grid point: each row contains the partial dependence of that target at each grid point. I have 10 targets, so there are 10 rows; there are 10 columns only because grid_resolution=10.
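To check this layout on data anyone can run, here is a minimal sketch. The iris dataset and LogisticRegression are stand-ins I chose for illustration (assumptions, not the original xgb_clf and X_test). As far as I know, newer scikit-learn releases expose the grid under 'grid_values' instead of 'values', so the sketch looks for either key.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.inspection import partial_dependence
from sklearn.linear_model import LogisticRegression

# Stand-in data and model (assumptions, not the original xgb_clf / X_test).
iris = load_iris()
X, y = iris.data, iris.target
clf = LogisticRegression(max_iter=1000).fit(X, y)

# One feature at a time, selected by column index because X is a NumPy array.
pd_results = partial_dependence(
    clf, X, features=[0], kind="average", grid_resolution=10
)

# Newer scikit-learn releases store the grid under "grid_values",
# older ones under "values"; pick whichever key is present.
grid_key = "grid_values" if "grid_values" in pd_results else "values"
grid = pd_results[grid_key][0]
average = pd_results["average"]

# Rows are classes, columns are grid points.
print(average.shape)
# The rows are averaged class probabilities here, so each column sums to 1.
print(np.allclose(average.sum(axis=0), 1.0))
```

The column sums are a quick sanity check that the rows really are per-class curves. They only add up to 1 when the model's response is a probability, which is the case for this stand-in model.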

From here it’s pretty simple: plot one line per target, using the grid points as x and that target’s row of 'average' as y.

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(5, 3))
# One line per target: row i of 'average' is the PDP curve for target i
for i, target_curve in enumerate(pd_results['average']):
    ax.plot(pd_results['values'][0], target_curve, label=f'target{i}')
ax.set_yscale('log')
ax.set_title('age')
ax.legend()  # the labels only show up if a legend is drawn
plt.show()
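The question also asked for the same plot for every variable. One way to do that is to wrap the steps above in a loop, sketched below. The helper name plot_multiclass_pdp and its parameters are hypothetical, and iris with LogisticRegression again stands in for the original data and model.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.inspection import partial_dependence
from sklearn.linear_model import LogisticRegression


def plot_multiclass_pdp(model, X, features, feature_names, class_labels, log_y=False):
    """Hypothetical helper: one subplot per feature, one PDP line per class."""
    fig, axes = plt.subplots(
        1, len(features), figsize=(5 * len(features), 3), squeeze=False
    )
    for ax, feature in zip(axes[0], features):
        res = partial_dependence(model, X, features=[feature], kind="average")
        # Newer scikit-learn uses "grid_values", older releases use "values".
        grid_key = "grid_values" if "grid_values" in res else "values"
        grid = res[grid_key][0]
        # Each row of 'average' is the curve for one class.
        for class_curve, label in zip(res["average"], class_labels):
            ax.plot(grid, class_curve, label=str(label))
        if log_y:
            ax.set_yscale("log")
        ax.set_title(feature_names[feature])
    axes[0, -1].legend(title="class")
    fig.tight_layout()
    return fig, axes[0]


# Stand-in data and model (assumptions, not the original xgb_clf / X_test).
iris = load_iris()
clf = LogisticRegression(max_iter=1000).fit(iris.data, iris.target)
fig, axes = plot_multiclass_pdp(
    clf,
    iris.data,
    features=[0, 2, 3],
    feature_names=iris.feature_names,
    class_labels=iris.target_names,
)
```

With an interactive backend, call plt.show() afterwards, or use fig.savefig(...) to write the figure to a file. Passing log_y=True reproduces the log scale used above, which helps when one class dominates the others.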

If you want a stackplot, it’s actually easier, because you don’t have to plot each line separately:

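pd_results = partial_dependence(xgb_clf, X_test, features=['age'], kind="average", grid_resolution=30)
# `reasons` (defined elsewhere) holds the target labels, one per row of 'average'
plt.stackplot(pd_results['values'][0], pd_results['average'], labels=reasons)
plt.title('Age - stackplot')
plt.legend()
plt.show()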

I wanted to show the output but I can’t post images in answers yet.

Answered By: amestrian