How to plot only lower half of scatterplot matrix using plotly figure_factory?

Question:

I saw this example using plotly's figure_factory:

import plotly.graph_objects as go
import plotly.figure_factory as ff

import numpy as np
import pandas as pd

df = pd.DataFrame(np.random.randn(20, 4),
                columns=['Column A', 'Column B', 'Column C', 'Column D'])

df['Fruit'] = pd.Series(['apple', 'apple', 'grape', 'apple', 'apple',
                                'grape', 'pear', 'pear', 'apple', 'pear',
                                'apple', 'apple', 'grape', 'apple', 'apple',
                                'grape', 'pear', 'pear', 'apple', 'pear'])


fig = ff.create_scatterplotmatrix(df, diag='histogram',index='Fruit',
                                  height=800, width=800)
fig.show()

Here is the resulting plot:

[Plot: full scatterplot matrix of Column A to Column D, with histograms on the diagonal]

But I want to know if it is possible to plot just the lower half, because the upper half shows the same information.
That is possible with plotly.graph_objects (thanks to showupperhalf=False), for example:

import plotly.graph_objects as go
import pandas as pd

df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/iris-data.csv')
index_vals = df['class'].astype('category').cat.codes

fig = go.Figure(data=go.Splom(
                dimensions=[dict(label='sepal length',
                                 values=df['sepal length']),
                            dict(label='sepal width',
                                 values=df['sepal width']),
                            dict(label='petal length',
                                 values=df['petal length']),
                            dict(label='petal width',
                                 values=df['petal width'])],
                showupperhalf=False, # hide the plots above the diagonal
                text=df['class'],
                marker=dict(color=index_vals,
                            showscale=False, # colors encode categorical variables
                            line_color='white', line_width=0.5)
                ))


fig.update_layout(
    title='Iris Data set',
    width=600,
    height=600,
)

fig.show()

And here is the result:

[Plot: Iris scatterplot matrix with the upper half hidden]

Is it possible to do the same with figure_factory?
Thanks in advance

Or is it possible to change the diagonal plots of a go.Splom scatterplot matrix to histograms, or box plots?

Asked By: Minecraft_Json


Answers:

Short answer:

You can blank out any subplot you’d like by replacing all traces within that subplot with an empty trace object such as go.Scatter().
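To know which traces sit "within that subplot", you can read each trace's xaxis reference (the Histogram trace in this answer carries 'xaxis': 'x', 'yaxis': 'y'). A minimal sketch, assuming ff.create_scatterplotmatrix lays its n x n grid out the way make_subplots numbers axes by default (row by row from the top-left, 'x' for the first cell); the helper names subplot_cell and in_upper_half are hypothetical:

```python
import re


def subplot_cell(axis_ref: str, n: int) -> tuple[int, int]:
    """Return the zero-based (row, col) of the subplot an axis reference points to.

    axis_ref is a trace's xaxis or yaxis value, e.g. 'x', 'x2' or 'y7'.
    Assumes an n x n grid whose axes are numbered row by row from the
    top-left (assumption: make_subplots' default numbering).
    """
    match = re.fullmatch(r"[xy](\d*)", axis_ref)
    if match is None:
        raise ValueError(f"not an axis reference: {axis_ref!r}")
    index = int(match.group(1) or 1)  # plain 'x' / 'y' means axis 1
    return divmod(index - 1, n)


def in_upper_half(axis_ref: str, n: int) -> bool:
    """True if the subplot lies strictly above the diagonal."""
    row, col = subplot_cell(axis_ref, n)
    return col > row


# 4 x 4 matrix: 'x2' is row 0, col 1 (upper); 'x5' is row 1, col 0 (lower); 'x6' is on the diagonal
print(in_upper_half("x2", 4), in_upper_half("x5", 4), in_upper_half("x6", 4))  # True False False
```

Selecting traces this way (in_upper_half(trace.xaxis, n)) avoids relying on the order of fig.data, under the numbering assumption stated above.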


The details:

go.Splom does not seem to have an option for histograms on the diagonal, and ff.create_scatterplotmatrix does not seem to have an option for hiding parts of the matrix. That does not mean that what you’re trying to do here is impossible, though.

The complete code snippet below uses a matrix…

show = array([[1., 0., 0., 0.],
              [1., 1., 0., 0.],
              [1., 1., 1., 0.],
              [1., 1., 1., 1.]])

…to decide whether or not the data for each subplot of your scatterplot matrix should be shown. If the value is 0, then a particular trace within that subplot…

Histogram({
    'marker': {'color': 'rgb(31, 119, 180)'},
    'showlegend': False,
    'x': [-1.4052399956918005, -1.8538677019305498, -0.0016298185761457061,
          -0.5268239747464603, -0.10652357762295094, 0.02921346151566477,
          -0.4581921321144815, 0.020997665069043978, 0.8734380864952216,
          -1.3008441288553083],
    'xaxis': 'x',
    'yaxis': 'y'
})

… is replaced by simply:

go.Scatter()

If this happens for all traces in a subplot, that subplot is blanked out, as you can see here:

[Plot: the same scatterplot matrix with the subplots above the diagonal blanked out]
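The row-by-row loop that expands the matrix into one flag per trace can be written as a single np.repeat call. A sketch, assuming (as the code does) that fig.data holds the traces subplot by subplot, row by row from the top-left, with one trace per unique 'Fruit' value in every subplot; the function name trace_visibility is hypothetical:

```python
import numpy as np


def trace_visibility(n_vars: int, n_groups: int) -> np.ndarray:
    """Return one 0/1 flag per trace of an n_vars x n_vars scatterplot matrix.

    Assumes traces are stored subplot by subplot, row by row, with
    n_groups traces (one per category) in each subplot.
    """
    # 1 on and below the diagonal, 0 in the upper triangle
    show = np.tril(np.ones((n_vars, n_vars)))
    # Flatten row by row, then give every group trace in a subplot the same flag
    return np.repeat(show.ravel(), n_groups)


flags = trace_visibility(4, 3)  # 4 columns, 3 fruits (apple, grape, pear)
print(len(flags))  # 48
print(flags[:12].astype(int).tolist())  # [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
```

The first twelve flags cover the first row of subplots: the three histograms on the diagonal are kept and the nine scatter traces to their right are blanked.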

Complete code:

import plotly.graph_objects as go
import plotly.figure_factory as ff
import numpy as np
import pandas as pd

cols = ['Column A', 'Column B', 'Column C', 'Column D']
# cols = ['Column A', 'Column B', 'Column C']
# cols = ['Column A', 'Column B', 'Column C', 'Column D', 'Column E', 'Column F']
df = pd.DataFrame(np.random.randn(20, len(cols)),
                columns=cols)

df['Fruit'] = pd.Series(['apple', 'apple', 'grape', 'apple', 'apple',
                                'grape', 'pear', 'pear', 'apple', 'pear',
                                'apple', 'apple', 'grape', 'apple', 'apple',
                                'grape', 'pear', 'pear', 'apple', 'pear'])

fig = ff.create_scatterplotmatrix(df, diag='histogram', index='Fruit',
                                  height=800, width=800)


# one trace per fruit in every subplot
ix = df['Fruit'].unique()

# only the outer axes carry titles, so collecting them gives the number of rows and columns
xs = []
ys = []
fig.for_each_xaxis(lambda x: xs.append(x.title.text) if x.title.text is not None else None)
fig.for_each_yaxis(lambda y: ys.append(y.title.text) if y.title.text is not None else None)

# build matrix to determine visibility of traces (1 on/below the diagonal, 0 above)
m = (len(xs), len(ys))
show = np.tril(np.ones(m))
newshows = []
newdata = []
for i in range(len(show)):
    # repeat each subplot's flag once per fruit trace in that subplot
    lst = list(np.repeat(show[i], len(ix)))
    newshows.extend(lst)

# replace existing data with empty data in the upper triangle
for t, s in enumerate(newshows):
    if s == 0:
        newdata.append(go.Scatter())
    else:
        newdata.append(fig.data[t])

# swap in the new list of traces
fig.data = []
for trace in newdata:
    fig.add_trace(trace)

fig.show()
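A variant of the same idea, sketched under stated assumptions: instead of swapping upper-triangle traces for go.Scatter() and rebuilding fig.data, set visible=False on them in place. It assumes the axes are numbered row by row from the top-left ('x' is the top-left subplot, 'x2' the one to its right), as make_subplots does by default; the function name hide_upper_half is hypothetical.

```python
def hide_upper_half(fig, n: int) -> None:
    """Hide every trace drawn strictly above the diagonal of an n x n grid.

    Variant of the answer: traces get visible=False instead of being
    replaced by an empty go.Scatter(), so fig.data keeps its order.
    Assumes row-by-row axis numbering from the top-left.
    """
    for trace in fig.data:
        ref = trace.xaxis or "x"  # an unset xaxis means the first subplot
        index = int(ref[1:] or 1)  # 'x' -> 1, 'x7' -> 7
        row, col = divmod(index - 1, n)
        if col > row:  # strictly above the diagonal
            trace.visible = False


# Usage with the figure from the complete code above (cols holds the column names):
# fig = ff.create_scatterplotmatrix(df, diag='histogram', index='Fruit', height=800, width=800)
# hide_upper_half(fig, len(cols))
# fig.show()
```

Because it matches on axis references rather than on positions in fig.data, this version does not depend on how many traces each subplot contains.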
Answered By: vestland