Plotly Dash: change networkx node colours based on user input?

Question:

After creating the minimal working example below, I tried to change the color of the nodes in the graph based on user input. Specifically, I have n lists of colors (one color per node), and I would like the user to be able to step through the node color lists (ideally both forwards and backwards). In essence, I use the colors to show which neurons are firing.

MWE

"""Generates a graph in dash."""


import dash
import dash_core_components as dcc
import dash_html_components as html
import networkx as nx
import plotly.graph_objs as go

# Create graph G
G = nx.DiGraph()
G.add_nodes_from([0, 1, 2])
G.add_edges_from(
    [
        (0, 1),
        (0, 2),
    ],
    weight=6,
)

# Create a x,y position for each node
pos = {
    0: [0, 0],
    1: [1, 2],
    2: [2, 0],
}
# Set the position attribute with the created positions.
for node in G.nodes:
    G.nodes[node]["pos"] = list(pos[node])

# add color to node points (one colour per node, in G.nodes order)
colour_set_I = ["rgb(31, 119, 180)", "rgb(255, 127, 14)", "rgb(44, 160, 44)"]
colour_set_II = ["rgb(10, 20, 30)", "rgb(255, 255, 0)", "rgb(0, 255, 255)"]

# Create nodes. The coordinate lists are built once and passed to the
# trace, instead of being grown with ``node_trace["x"] += ...``; each such
# ``+=`` rebuilds the whole coordinate sequence (quadratic in node count).
node_trace = go.Scatter(
    x=[G.nodes[node]["pos"][0] for node in G.nodes()],
    y=[G.nodes[node]["pos"][1] for node in G.nodes()],
    text=[],
    mode="markers",
    hoverinfo="text",
    marker=dict(size=30, color=colour_set_I),
)


# Create Edges. Each edge contributes its two end points followed by None;
# the None breaks the line so that consecutive edges are not joined.
edge_x = []
edge_y = []
for source, target in G.edges():
    x0, y0 = G.nodes[source]["pos"]
    x1, y1 = G.nodes[target]["pos"]
    edge_x += [x0, x1, None]
    edge_y += [y0, y1, None]

edge_trace = go.Scatter(
    x=edge_x,
    y=edge_y,
    line=dict(width=0.5, color="#888"),
    hoverinfo="none",
    mode="lines",
)

################### START OF DASH APP ###################
# Passing __name__ lets Dash locate the app's assets folder.
app = dash.Dash(__name__)


fig = go.Figure(
    data=[edge_trace, node_trace],
    layout=go.Layout(
        xaxis=dict(showgrid=True, zeroline=True, showticklabels=True),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    ),
)

app.layout = html.Div(
    [
        html.Div(dcc.Graph(id="Graph", figure=fig)),
    ]
)

if __name__ == "__main__":
    app.run_server(debug=True)

When you save this as graph.py and run it with: python graph.py, you can open a browser, go to 127.0.0.1:8050 and see:
[Screenshot: the graph rendered with colour_set_I]
For colour_set_I, and:
[Screenshot: the graph rendered with colour_set_II]
For colour_set_II.

Question

How can I add a slider (e.g. from 0 to n), or "next >" and "< back" buttons, that load the next/previous colour list into the nodes?

Hacky

I noticed when I change the graph.py file line:

marker=dict(size=30, color=colour_set_I),

to:

marker=dict(size=30, color=colour_set_II),

it automatically updates the node colors, because the app runs with debug=True and Dash reloads it when the file is saved. However, typing the frame index into the colour-set name and pressing Ctrl+S is rather laborious, even though I thoroughly enjoy keyboard control over clicking.

Asked By: a.t.

||

Answers:

You can use the dash_core_components.Slider component to add a slider to your app that lets the user change the color of the nodes. You then add a callback function that updates the marker color based on the slider value. The slider's value itself keeps track of which color set is currently displayed, so no separate server-side state variable is needed.

from dash.dependencies import Input, Output

app = dash.Dash(__name__)

# Index of the colour set shown when the page first loads. After that the
# slider's ``value`` is the only state: Dash passes it to the callback on
# every change, so this module-level variable is never updated.
color_set_index = 0
# One list of node colours per slider position (one colour per node, in
# the same order as the nodes of node_trace).
color_sets = [colour_set_I, colour_set_II]

app.layout = html.Div(
    [
        dcc.Slider(
            id='color-set-slider',
            min=0,
            max=len(color_sets) - 1,
            value=color_set_index,
            marks={i: str(i) for i in range(len(color_sets))},
            # step=None restricts the selectable values to the marks.
            step=None
        ),
        html.Div(dcc.Graph(id="Graph", figure=fig))
    ]
)

@app.callback(
    Output("Graph", "figure"),
    [Input("color-set-slider", "value")]
)
def update_color(color_set_index):
    """Return a copy of ``fig`` with the nodes recoloured.

    Args:
        color_set_index: Current slider value; an index into ``color_sets``.

    Returns:
        A new figure. The module-level ``fig`` is left untouched, because
        it is shared by every session of the Dash server and mutating it
        would leak one user's selection into other users' pages.
    """
    updated_fig = go.Figure(fig)
    # data[1] is node_trace (the figure was built with [edge_trace, node_trace]).
    updated_fig.data[1].marker.color = color_sets[color_set_index]
    return updated_fig

if __name__ == "__main__":
    app.run_server(debug=True)
    
Answered By: Yacine

Here is an example that also updates the colours of the directed edges. Note that it draws the edges as arrow annotations instead of line traces.

"""Generates a graph in dash."""

from typing import List, Tuple

import dash
import networkx as nx
import plotly
import plotly.graph_objs as go
from dash import dcc, html
from dash.dependencies import Input, Output

# Build the directed graph: three nodes and two weighted edges out of node 0.
G = nx.DiGraph()
G.add_nodes_from([0, 1, 2])
G.add_edges_from([(0, 1), (0, 2)], weight=6)

# Fixed (x, y) layout, keyed by node name.
pos = {0: [0, 0], 1: [1, 2], 2: [2, 0]}
# Store a private copy of each node's position as its "pos" attribute.
for node_name in G.nodes:
    G.nodes[node_name]["pos"] = list(pos[node_name])

# Node colour palettes: one colour per node, in G.nodes order.
colour_set_I = ["rgb(31, 119, 180)", "rgb(255, 127, 14)", "rgb(44, 160, 44)"]
colour_set_II = ["rgb(10, 20, 30)", "rgb(255, 255, 0)", "rgb(0, 255, 255)"]

# Create edge colour lists.
# TODO: make edge colour function of edges


def set_edge_colours() -> List[List[str]]:
    """Return the hard-coded edge colours, one inner list per colour set.

    In every inner list, index 0 colours edge (0, 1) and index 1 colours
    edge (0, 2). A fresh list is built on each call.
    """
    first_set = ["rgb(31, 119, 180)", "rgb(255, 127, 14)"]
    second_set = ["rgb(10, 20, 30)", "rgb(255, 255, 0)"]
    return [first_set, second_set]


def get_edge_colour(
    t: int,
    edge: Tuple[int, int],
    edge_colours: List[List[str]],
    edges: Tuple[Tuple[int, int], ...] = ((0, 1), (0, 2)),
) -> str:
    """Return the colour of ``edge`` in colour set ``t``.

    Args:
        t: Index of the colour set (e.g. the slider value).
        edge: The (source, target) edge whose colour is requested.
        edge_colours: One list of edge colours per colour set; position k
            of an inner list colours ``edges[k]``.
        edges: Edge order used by ``edge_colours``. Defaults to the two
            edges of the example graph; pass e.g. ``tuple(G.edges())`` to
            support an arbitrary graph.

    Returns:
        The colour string stored at ``edge_colours[t][k]``.

    Raises:
        ValueError: If ``edge`` is not one of ``edges``.

    TODO: support duplicate edges between nodes.
    """
    try:
        position = tuple(edges).index(edge)
    except ValueError:
        raise ValueError(f"Error, edge{edge} not found.") from None
    return edge_colours[t][position]


# Load edge colour list.
edge_colours: List[List[str]] = set_edge_colours()

# Create nodes. Coordinates are collected in G.nodes order (the order the
# colour lists follow) and passed in directly, instead of being grown with
# ``+=`` on the trace, which rebuilds the sequence once per node.
node_trace = go.Scatter(
    x=[G.nodes[node]["pos"][0] for node in G.nodes()],
    y=[G.nodes[node]["pos"][1] for node in G.nodes()],
    text=[],
    mode="markers",
    hoverinfo="text",
    marker=dict(size=30, color=colour_set_I),
)

# Create figure. Edges are drawn as arrow annotations (one per edge, in
# G.edges() order) rather than as a line trace, so their direction shows.
fig = go.Figure(
    data=[node_trace],
    layout=go.Layout(
        height=700,  # height of image in pixels.
        width=1000,  # Width of image in pixels.
        annotations=[
            dict(
                ax=G.nodes[edge[0]]["pos"][0],  # starting x.
                ay=G.nodes[edge[0]]["pos"][1],  # starting y.
                axref="x",
                ayref="y",
                x=G.nodes[edge[1]]["pos"][0],  # ending x.
                y=G.nodes[edge[1]]["pos"][1],  # ending y.
                xref="x",
                yref="y",
                arrowwidth=5,  # Width of arrow.
                arrowcolor="red",  # Placeholder; overwritten by update_color.
                arrowsize=0.8,  # (1 gives head 3 times as wide as arrow line)
                showarrow=True,
                arrowhead=1,  # the arrowshape (index).
                hoverlabel=plotly.graph_objs.layout.annotation.Hoverlabel(
                    bordercolor="red"
                ),
                hovertext="sometext",
                text="sometext",
            )
            for edge in G.edges()
        ],
    ),
)


# Start Dash app.
app = dash.Dash(__name__)


@app.callback(Output("Graph", "figure"), [Input("color-set-slider", "value")])
def update_color(color_set_index: int) -> go.Figure:
    """Return a copy of ``fig`` recoloured for the selected colour set.

    Args:
        color_set_index: The slider value; indexes both ``edge_colours``
            and ``color_sets``.

    Returns:
        A new figure with the edge-arrow and node colours set. The
        module-level ``fig`` is not modified, because it is shared by all
        sessions of the Dash server.
    """
    updated_fig = go.Figure(fig)

    # Annotation i is the arrow of the i-th edge of G.edges(): that is the
    # order in which the annotations were created.
    for i, edge in enumerate(G.edges()):
        updated_fig.layout.annotations[i].arrowcolor = get_edge_colour(
            t=color_set_index, edge=edge, edge_colours=edge_colours
        )

    # Update the node colour (data[0] is node_trace).
    updated_fig.data[0].marker.color = color_sets[color_set_index]
    return updated_fig


# Colour set shown when the page first loads. Afterwards the slider's
# ``value`` is the single source of truth and is passed to update_color.
initial_color_set_index = 0
# One node-colour list per slider position. This must be defined before the
# call below, because update_color reads it as a module-level name.
color_sets = [colour_set_I, colour_set_II]
# Pre-colour the figure so that the first render matches the slider.
fig = update_color(initial_color_set_index)

app.layout = html.Div(
    [
        dcc.Slider(
            id="color-set-slider",
            min=0,
            max=len(color_sets) - 1,
            value=initial_color_set_index,
            marks={i: str(i) for i in range(len(color_sets))},
            step=None,  # only the values in ``marks`` are selectable
        ),
        html.Div(dcc.Graph(id="Graph", figure=fig)),
    ]
)

if __name__ == "__main__":
    app.run_server(debug=True)
    

Yields:
[Screenshot of the resulting app]

And here is a more elaborate example. Its edge labels are rotated to follow their edges, even when the image is stretched. It also draws recursive edges with labels. Both kinds of edge are recoloured based on the user's selection:

"""Generates a graph in dash."""

from typing import Dict, List, Tuple

import dash
import networkx as nx
import numpy as np
import plotly
import plotly.graph_objs as go
from dash import dcc, html
from dash.dependencies import Input, Output

# Canvas size in pixels; also used to correct edge-label angles for a
# non-square canvas.
pixel_width = 1000
pixel_height = 1000
# Radius of the circles drawn above nodes to represent recursive edges.
recursive_edge_radius = 0.1
# Build the directed graph: three nodes and two weighted edges out of node 0.
G = nx.DiGraph()
G.add_nodes_from([0, 1, 2])
G.add_edges_from([(0, 1), (0, 2)], weight=6)

# Fixed (x, y) layout, keyed by node name.
pos = {0: [0, 0], 1: [1, 2], 2: [2, 0]}
# Store a private copy of each node's position as its "pos" attribute.
for node_name in G.nodes:
    G.nodes[node_name]["pos"] = list(pos[node_name])

# Node colour palettes: one colour per node.
colour_set_I = ["rgb(31, 119, 180)", "rgb(255, 127, 14)", "rgb(44, 160, 44)"]
colour_set_II = ["rgb(10, 20, 30)", "rgb(255, 255, 0)", "rgb(0, 255, 255)"]

# Create edge colour lists.
# TODO: make edge colour function of edges


def set_edge_colours() -> List[List[str]]:
    """Return the hard-coded edge colours as one list per colour set.

    Position 0 of each inner list is the colour of edge (0, 1); position 1
    is the colour of edge (0, 2). New lists are created on every call.
    """
    colour_pairs = (
        ("rgb(31, 119, 180)", "rgb(255, 127, 14)"),
        ("rgb(10, 20, 30)", "rgb(255, 255, 0)"),
    )
    return [list(pair) for pair in colour_pairs]


def get_edge_colour(
    t: int,
    edge: Tuple[int, int],
    edge_colours: List[List[str]],  # pylint: disable = W0621
    edges: Tuple[Tuple[int, int], ...] = ((0, 1), (0, 2)),
) -> str:
    """Return the colour of ``edge`` in colour set ``t``.

    Args:
        t: Index of the colour set (e.g. the slider value).
        edge: The (source, target) edge whose colour is requested.
        edge_colours: One list of edge colours per colour set; position k
            of an inner list colours ``edges[k]``.
        edges: Edge order used by ``edge_colours``. Defaults to the two
            edges of the example graph; pass e.g. ``tuple(G.edges())`` to
            support an arbitrary graph.

    Returns:
        The colour string stored at ``edge_colours[t][k]``.

    Raises:
        ValueError: If ``edge`` is not one of ``edges``.

    TODO: support duplicate edges between nodes.
    """
    try:
        position = tuple(edges).index(edge)
    except ValueError:
        raise ValueError(f"Error, edge{edge} not found.") from None
    return edge_colours[t][position]


# pylint: disable = W0621
def get_edge_arrows(G: nx.DiGraph) -> List[Dict]:
    """Build one arrow annotation for every directed edge of ``G``.

    Each arrow runs from the source node's "pos" to the target node's "pos",
    both in data coordinates. ``arrowcolor`` is a red placeholder that the
    Dash callback overwrites with the selected colour set.

    Args:
        G: Graph whose nodes carry a two-element "pos" attribute.

    Returns:
        One annotation dict per edge, in ``G.edges`` order.
    """
    arrows: List[Dict] = []
    for source, target in G.edges:
        tail = G.nodes[source]["pos"]
        head = G.nodes[target]["pos"]
        arrow = dict(
            # Tail (ax, ay) and head (x, y) are both in axis data units.
            ax=tail[0],
            ay=tail[1],
            axref="x",
            ayref="y",
            x=head[0],
            y=head[1],
            xref="x",
            yref="y",
            arrowwidth=5,  # line width of the arrow
            arrowcolor="red",  # placeholder, recoloured by the callback
            arrowsize=0.8,  # relative head size (1 = 3x the line width)
            showarrow=True,
            arrowhead=1,  # index of the arrow-head shape
            hoverlabel=plotly.graph_objs.layout.annotation.Hoverlabel(
                bordercolor="red"
            ),
            hovertext="sometext",
            text="sometext",
        )
        arrows.append(arrow)
    return arrows


# pylint: disable = W0621
def get_edge_labels(G: nx.DiGraph) -> List[go.layout.Annotation]:
    """Return one text label per directed edge, placed at the edge midpoint.

    Each label is rotated to follow its edge as drawn on the
    ``pixel_width`` x ``pixel_height`` canvas (module-level settings), and
    its bottom is anchored at the midpoint so the text sits on the edge.
    Node positions are read from the ``G`` argument, not from the
    module-level graph.

    Args:
        G: Graph whose nodes carry a two-element "pos" attribute.

    Returns:
        Label annotations, in ``G.edges`` order.
    """
    # Aspect-ratio correction for horizontal distances; algebraically equal
    # to 1 - (pixel_height - pixel_width) / pixel_height.
    x_stretch = pixel_width / pixel_height
    annotations = []
    for left_node_name, right_node_name in G.edges:
        left_pos = G.nodes[left_node_name]["pos"]
        right_pos = G.nodes[right_node_name]["pos"]
        dx = (right_pos[0] - left_pos[0]) * x_stretch
        dy = right_pos[1] - left_pos[1]
        annotations.append(
            go.layout.Annotation(
                x=(right_pos[0] + left_pos[0]) / 2,
                y=(right_pos[1] + left_pos[1]) / 2,
                xref="x",
                yref="y",
                text="dict Text",
                align="center",
                showarrow=False,
                yanchor="bottom",
                # Negated so the counter-clockwise math angle becomes a
                # clockwise rotation (presumably plotly's textangle sense).
                textangle=-np.rad2deg(np.arctan2(dy, dx)),
            )
        )
    return annotations


# pylint: disable = W0621
def get_recursive_edge_labels(
    G: nx.DiGraph, radius: float
) -> List[go.layout.Annotation]:
    """Return one "recur" text label per node, placed above the node.

    The recursive-edge circle drawn above each node by add_recursive_edges
    is really an oval: ``2 * radius`` wide and ``radius`` tall, resting on
    the node's position. Its top therefore lies one ``radius`` above the
    node, which is where the bottom of the label is anchored.

    Args:
        G: Graph whose nodes carry a two-element "pos" attribute.
        radius: The radius used to draw the recursive-edge circles.

    Returns:
        Label annotations, in ``G.nodes`` order.
    """
    annotations = []
    for node in G.nodes:
        x, y = G.nodes[node]["pos"]
        annotations.append(
            go.layout.Annotation(
                x=x,
                y=y + radius,
                xref="x",
                yref="y",
                text="recur",
                align="center",
                showarrow=False,
                # Anchor the label's bottom at y so it sits on top of the oval.
                yanchor="bottom",
            )
        )
    return annotations


def get_edge_mid_point(edge: Tuple[int, int]) -> Tuple[float, float]:
    """Return the (x, y) midpoint of ``edge`` in data coordinates.

    Positions are read from the module-level graph ``G``.

    Args:
        edge: (source, target) node names of an edge of ``G``.

    Returns:
        The midpoint as floats (the coordinates are halved with ``/``).
    """
    left_pos = G.nodes[edge[0]]["pos"]
    right_pos = G.nodes[edge[1]]["pos"]
    mid_x = (right_pos[0] + left_pos[0]) / 2
    mid_y = (right_pos[1] + left_pos[1]) / 2
    return mid_x, mid_y


def get_stretched_edge_angle(
    edge: Tuple[int, int],
    pixel_height: int,
    pixel_width: int,
) -> float:
    """Return the on-screen angle of ``edge`` in degrees, clockwise-positive.

    The data-space horizontal distance is scaled by the canvas aspect ratio
    ``pixel_width / pixel_height`` so the angle follows the edge as drawn on
    a non-square image. The result is ``-atan2(dy, dx)`` in degrees, so an
    edge pointing up and to the right yields a negative angle. Positions are
    read from the module-level graph ``G``.

    NOTE(review): the correction is exact only when both axes show the
    same data range.

    Args:
        edge: (source, target) node names of an edge of ``G``.
        pixel_height: Height of the figure in pixels (must be non-zero).
        pixel_width: Width of the figure in pixels.

    Returns:
        The angle in degrees, in (-180, 180].

    Raises:
        ZeroDivisionError: If ``pixel_height`` is 0.
    """
    left_pos = G.nodes[edge[0]]["pos"]
    right_pos = G.nodes[edge[1]]["pos"]
    # Aspect-ratio factor; algebraically the same as
    # 1 - (pixel_height - pixel_width) / pixel_height.
    dx = (right_pos[0] - left_pos[0]) * (pixel_width / pixel_height)
    dy = right_pos[1] - left_pos[1]
    return -np.rad2deg(np.arctan2(dy, dx))


def get_pure_edge_angle(edge: Tuple[int, int]) -> float:
    """Return the data-space angle of ``edge`` in degrees, clockwise-positive.

    No correction for the canvas aspect ratio is applied. The result is
    ``-atan2(dy, dx)`` in degrees, so an edge pointing up and to the right
    yields a negative angle. Positions are read from the module-level
    graph ``G``.

    Args:
        edge: (source, target) node names of an edge of ``G``.

    Returns:
        The angle in degrees, in [-180, 180].
    """
    left_pos = G.nodes[edge[0]]["pos"]
    right_pos = G.nodes[edge[1]]["pos"]
    dx = right_pos[0] - left_pos[0]
    dy = right_pos[1] - left_pos[1]
    return -np.rad2deg(np.arctan2(dy, dx))


def get_annotations(G: nx.DiGraph) -> List[Dict]:
    """Collect every annotation of the figure, in a fixed order.

    The order matters: the edge arrows come first, because update_color
    recolours annotation i as the arrow of the i-th edge. The edge labels
    follow, and then the recursive-edge labels.
    """
    return [
        *get_edge_arrows(G),
        *get_edge_labels(G),
        *get_recursive_edge_labels(G, radius=recursive_edge_radius),
    ]


def add_recursive_edges(G: nx.DiGraph, fig: go.Figure, radius: float) -> None:
    """Draw a recursive-edge circle above every node of ``G`` onto ``fig``.

    Each circle shape's bounding box spans ``2 * radius`` horizontally and
    ``radius`` vertically, resting on the node's position. Its initial line
    colour is the node's entry in ``colour_set_I`` (indexed by node name);
    update_color changes it afterwards. Shapes are added in ``G.nodes``
    order.
    """
    for node_name in G.nodes:
        centre_x, base_y = G.nodes[node_name]["pos"]
        fig.add_shape(
            type="circle",
            xref="x",
            yref="y",
            x0=centre_x - radius,
            y0=base_y,
            x1=centre_x + radius,
            y1=base_y + radius,
            line_color=colour_set_I[node_name],
        )


# Load edge colour list.
edge_colours: List[List[str]] = set_edge_colours()


# Create nodes. Coordinates are collected in G.nodes order (the order the
# colour lists follow) and passed in directly, instead of being grown with
# ``+=`` on the trace, which rebuilds the sequence once per node.
node_trace = go.Scatter(
    x=[G.nodes[node]["pos"][0] for node in G.nodes()],
    y=[G.nodes[node]["pos"][1] for node in G.nodes()],
    text=[],
    mode="markers",
    hoverinfo="text",
    marker=dict(size=30, color=colour_set_I),
)

# Create figure. Edges are drawn as annotations (arrows and labels), not
# as traces; the node trace is the only data trace.
fig = go.Figure(
    data=[node_trace],
    layout=go.Layout(
        height=pixel_height,  # height of image in pixels.
        width=pixel_width,  # Width of image in pixels.
        annotations=get_annotations(G),
    ),
)
# Recursive edges are layout shapes, one per node, in G.nodes order.
add_recursive_edges(G=G, fig=fig, radius=recursive_edge_radius)

# Start Dash app.
app = dash.Dash(__name__)


@app.callback(Output("Graph", "figure"), [Input("color-set-slider", "value")])
def update_color(color_set_index: int) -> go.Figure:
    """Return a copy of ``fig`` recoloured for the selected colour set.

    Recolours the edge arrows, the node markers and the recursive-edge
    circles.

    Args:
        color_set_index: The slider value; indexes both ``edge_colours``
            and ``color_sets``.

    Returns:
        A new figure. The module-level ``fig`` is not modified, because it
        is shared by all sessions of the Dash server.
    """
    updated_fig = go.Figure(fig)
    node_colours = color_sets[color_set_index]

    # The first len(G.edges) annotations are the edge arrows, created in
    # G.edges order (get_annotations puts them first).
    for i, edge in enumerate(G.edges()):
        updated_fig.layout.annotations[i].arrowcolor = get_edge_colour(
            t=color_set_index, edge=edge, edge_colours=edge_colours
        )

    # Update the node colour (data[0] is node_trace).
    updated_fig.data[0].marker.color = node_colours

    # Shape i is the recursive-edge circle of the i-th node (shapes are
    # added in G.nodes order). Colour it like its node, by position, so
    # that non-integer node names also work.
    for i, _node_name in enumerate(G.nodes):
        updated_fig.layout.shapes[i].line.color = node_colours[i]

    return updated_fig


# Colour set shown when the page first loads. Afterwards the slider's
# ``value`` is the single source of truth and is passed to update_color.
initial_color_set_index = 0
# One node-colour list per slider position. This must be defined before the
# call below, because update_color reads it as a module-level name.
color_sets = [colour_set_I, colour_set_II]
# Pre-colour the figure so that the first render matches the slider.
fig = update_color(initial_color_set_index)

app.layout = html.Div(
    [
        dcc.Slider(
            id="color-set-slider",
            min=0,
            max=len(color_sets) - 1,
            value=initial_color_set_index,
            marks={i: str(i) for i in range(len(color_sets))},
            step=None,  # only the values in ``marks`` are selectable
        ),
        html.Div(dcc.Graph(id="Graph", figure=fig)),
    ]
)

if __name__ == "__main__":
    app.run_server(debug=True)

Yields:
[Screenshot of the resulting app]

Answered By: a.t.