Question about filter overlapping for plotly-Dash in callback

Question:

I want to apply multiple filters to a data table in my dash app.

The app is structured in the following way:

  1. Choose a table
  2. Choose a column to add
  3. Add this column.
  4. Choose wanted values (categorical column) or wanted range value (numeric or date column)
  5. Apply filter

And you can go back to 2 to add another column filter.

It works with same type of filter. For example, if you want to filtered data just bewteen categorical data using ‘filter_cat’, filters work sequentially. (Firstly filtered data -> second filter apply -> secondly filtered data)

However, if I want to apply different types of filter (filter for categorical value and filter for numerical value and datetime value), it does not work sequentially.

Currently the following happens:

  1. Firstly filtered data by categorical column filter
  2. Apply numeric column filter, but it does not apply filter on firstly filtered data, but on original data.

Moreover, firstly applied filter does not work anymore after adding second filter.

from dash import Dash,html,dcc,Input,Output,State,Patch,MATCH,ALL,ctx
from dash.exceptions import PreventUpdate
import dash_ag_grid as dag
import pandas as pd
import plotly.express as px
import dash_bootstrap_components as dbc

app = Dash(__name__)
# Sample data for demonstration
data_table1 = pd.DataFrame({
    'Category1': ['A', 'B', 'C', 'A', 'B'],
    'Category2': ['X', 'Y', 'X', 'Y', 'Z'],
    'Numeric1': [10, 15, 8, 12, 6],
    'Numeric2': [100, 200, 150, 50, 300],
    'Date1': pd.to_datetime(['2023-09-01', '2023-09-02', '2023-09-03', '2023-09-04', '2023-09-05']),
    'Date2': pd.to_datetime(['2023-09-01 08:00', '2023-09-02 10:00', '2023-09-03 12:00', '2023-09-04 14:00', '2023-09-05 16:00'])
})


data_table2 = pd.DataFrame({
    'Category3': ['A', 'B', 'C', 'A', 'B'],
    'Category4': ['X', 'Y', 'X', 'Y', 'Z'],
    'Numeric3': [10, 15, 8, 12, 6],
    'Numeric4': [100, 200, 150, 50, 300],
    'Date3': pd.to_datetime(['2023-09-01', '2023-09-02', '2023-09-03', '2023-09-04', '2023-09-05']),
    'Date4': pd.to_datetime(['2023-09-10 08:00', '2023-09-12 10:00', '2023-09-13 12:00', '2023-09-14 14:00', '2023-09-15 16:00'])
})



rowClassRules = {
    # apply green to 2008
"rounded": True,
}
rowStyle={
"border-radius": "10px"
}
defaultColDef = {
    "resizable": True,
    "sortable": True, 
    "filter": True,
    "initialWidth": 200,
    "wrapHeaderText": True,
    "autoHeaderHeight": True,
    "headerClass": 'center-header', "cellStyle": {'textAlign': 'center'}
}


table_configs = {
    "table1": {
        "df": data_table1,
        "columns": data_table1.columns,
    },
    "table2": {
        "df": data_table2,
        "columns": data_table2.columns,
    },
}
def get_selected_dataframe(selected_table):
    if selected_table == "table1":
        return data_table1
    elif selected_table == "table2":
        return data_table2
    else:
        return pd.DataFrame()
    


list_table = ['table1','table2']

dropdown_table = dcc.Dropdown(
            options=[{'label': i, 'value': i} for i in list_table],
            value = 'table1',
            id="filter_table",
            # clearable=False,
            style={"marginBottom": 10},
            multi=False
        )

dropdown_var_filter = dcc.Dropdown(
                        id='filter_variable_to_show',
                        options=[],
                        persistence=True,
                        multi=True,
                        placeholder='Select a table...',)

second_filter = dcc.Dropdown(
                        id='second_filter',
                        options=[],
                        value=[],
                        multi=False,
                        persistence=True,
                        placeholder='Select a columns...',)



table_output = html.Div(id='table_output')

@app.callback(
    Output('update-rowdata-grid', 'rowData'),
    Input('apply_filter_btn','n_clicks'),
    State({'type': 'filter_cat',"table":ALL ,'index': ALL}, 'value'),
    State({'type': 'filter_cat',"table":ALL ,'index': ALL}, 'id'),
    State({'type': 'filter_num','table':ALL, 'index': ALL}, 'value'),
    State({'type': 'filter_num',"table":ALL ,'index': ALL}, 'id'),
    State({'type': 'filter_date','table':ALL, 'index': ALL}, 'start_date'),
    State({'type': 'filter_date','table':ALL, 'index': ALL}, 'end_date'),
    State({'type': 'filter_date','table':ALL, 'index': ALL}, 'id'),
    State('filter_table', 'value'),
    State('second_filter','value'),
    prevent_initial_call=True
    
)
def apply_filter(n_clicks,cat,cat_id,num,num_id,start_date,end_date,date_id,selected_table,selected_columns):
    df = get_selected_dataframe(selected_table)
    dff = df.copy()
    column_type = df[selected_columns].dtype
    if n_clicks > 0 :
        print(n_clicks)
        if column_type == 'object' and cat[0]:
            # Without for, we cannot assign proper unique value to each column.
            # For example, cat could have a [['X'],['A']]. Here, 'X' is from column 1 and 'A' is from column 2
            # To link each unique value to proper column, I should use cat_id, containing information about column
            # And we should iterate it using for loop. dff is updated for each column.
            print('cat_filter')
            for idx,value in enumerate(cat_id):
                dff = dff[dff[value['index']].isin(cat[idx])]
                
        if column_type in ['int64', 'float64'] and num[0]:
            # Same as cat. But it is composed of two element (min & max value). We have to extract both
            print('num_filter')
            for idx,value in enumerate(num_id):
                dff = dff[(dff[value['index']] >= num[idx][0]) & (dff[value['index']] <= num[idx][1])]
                
        if column_type == 'datetime64[ns]' and start_date and end_date:
            # Same as cat and num.
            print('date_filter')
            for idx,value in enumerate(date_id):
                dff = dff[(dff[value['index']] >= start_date[idx]) & (dff[value['index']] <= end_date[idx])]
    return dff.to_dict('records')

@app.callback(
        Output('second_filter', 'options',allow_duplicate=True),
        Input({"type": "filter_column", "index": ALL},'value'),
        Input({"type": "filter_column", "index": ALL},'id'),
        Input('filter_table','value'),
        prevent_initial_call='initial_duplicate'
    )
def update_filter(value,col_id,selected_table):
    df = get_selected_dataframe(selected_table)
    if value :
        return [{"label": col, "value": col} for col in df.drop(columns = list(value),axis=1).columns]
    else :
        return [{"label": col, "value": col} for col in df.columns]



@app.callback(
    Output('filter_container','children',allow_duplicate=True),
    Input('add_filter_btn','n_clicks'),
    State("second_filter", "value"),
    State('filter_table','value'),
    prevent_initial_call='initial_duplicate'
)
def add_filter(n_clicks,selected_columns,selected_table):
    patched_children = Patch()
    df = get_selected_dataframe(selected_table)
    columns = df.columns
    
    if n_clicks != None and selected_columns:
        column_type = df[selected_columns].dtype
        if column_type == 'object':
            unique_values = df[selected_columns].unique()
            new_filter = html.Div([
                    html.H4('Used filter'),
                    dcc.Dropdown(
                        id={"type": "filter_column", "index": n_clicks},
                        value=selected_columns,
                        options=[{"label": col, "value": col} for col in columns],
                        disabled=True,
                    ),
                    dcc.Dropdown(
                    id={"type": "filter_cat", "table": selected_table, "index": selected_columns},
                    options=[{"label": str(val), "value": val} for val in unique_values],
                    placeholder="Select a value",
                    multi=True,
            ),
                    dbc.Row(dbc.Button("X", id={"type": "remove_btn", "table": selected_table, "index": selected_columns}, color="primary"))
            ])
        elif column_type in ['int64', 'float64']:
        # For Integer & Float type, create slider filter
            min_val = df[selected_columns].min()
            max_val = df[selected_columns].max()
            new_filter = html.Div([
                html.H4('Used filter'),
                    dcc.Dropdown(
                        id={"type": "filter_column", "index": n_clicks},
                        value=selected_columns,
                        options=[{"label": col, "value": col} for col in columns],
                        disabled=True,
                    ),
                    dcc.RangeSlider(
                    id={"type": "filter_num", "table": selected_table, "index": selected_columns},
                    min = min_val,
                    max = max_val,
            ),
                    dbc.Row(dbc.Button("X", id={"type": "remove_btn", "table": selected_table, "index": selected_columns}, color="primary"))
            ])
        elif column_type == 'datetime64[ns]':
        # For Integer & Float type, create slider filter
            min_date = df[selected_columns].min()
            max_date = df[selected_columns].max()
            new_filter = html.Div([
                html.H4('Used filter'),
                    dcc.Dropdown(
                        id={"type": "filter_column", "index": n_clicks},
                        value=selected_columns,
                        options=[{"label": col, "value": col} for col in columns],
                        disabled=True,
                    ),
                    dcc.DatePickerRange(
                        id={"type": "filter_date", "table": selected_table, "index": selected_columns},
                        min_date_allowed=min_date,
                        max_date_allowed=max_date,
                        display_format='DD/MM/YYYY',
                        clearable=True,
                        start_date=min_date,
                        end_date=max_date
                        

            ),
                    dbc.Row(dbc.Button("X", id={"type": "remove_btn", "table": selected_table, "index": selected_columns}, color="primary"))
            ])    
        patched_children.append(new_filter)
        return patched_children
    
    return patched_children

# @app.callback(
#     Output('filter_container','children',allow_duplicate=True),
#     Input({"type": "remove_btn", "table": ALL, "index": ALL},'n_clicks'),
#     prevent_initial_call=True
# )
# def remove_param_filter(n_clicks):
#     if n_clicks :
#         return None

@app.callback(
            Output('second_filter', 'options',allow_duplicate=True),
            Output('filter_container','children',allow_duplicate=True),
            Output('update-rowdata-grid', 'rowData',allow_duplicate=True),
             Input('clear-button','n_clicks'),
             State('filter_table', 'value'),
             prevent_initial_call=True)
def reset_filters(n_clicks, selected_table):
    if n_clicks:
        df = get_selected_dataframe(selected_table)
        return [{"label": col, "value": col} for col in df.columns],None,df.to_dict('records')
    else:
        raise PreventUpdate




@app.callback(
    Output('filter_variable_to_show','options'),
    Input('filter_table', 'value'),
)
def filter_col(selected_table):
    df = get_selected_dataframe(selected_table)
    return [{"label": col, "value": col} for col in df.columns]


@app.callback(
        Output('second_filter', 'options'),
        Input('filter_variable_to_show','value'),
        Input('filter_table','value')
    )
def update_filter(value,selected_table):
    df = get_selected_dataframe(selected_table)
    if value :
        return value
    else :
        return [{"label": col, "value": col} for col in df.columns]
    
@app.callback(
        Output('filter_container','children',allow_duplicate=True),
        Output('table_output', 'children'),
        Input('filter_table', 'value'),
        Input('filter_variable_to_show','value'),
        prevent_initial_call='initial_duplicate'
    )
def update_table(value,selected_columns):
    config = table_configs.get(value)
    if config:
        df = config["df"]
        if selected_columns:
            df = df[selected_columns]
        table = dag.AgGrid(
            id = "update-rowdata-grid",
            rowData=df.to_dict('records'),
            defaultColDef=defaultColDef,
            columnDefs=[{'field':i} for i in df.columns],
            columnSize="autoSize",
            dashGridOptions={"pagination": True},
            className="ag-theme-alpine",
            rowClassRules=rowClassRules,
            rowStyle=rowStyle,
        )
        return None,table
filter_container=html.Div(id="filter_container", children=[])


filter=dbc.Card(
    [
        dbc.CardHeader(html.H3("Filter")),
        dbc.CardBody(
            [
                dbc.Row(
                    children=[second_filter,
                            filter_container,
                            html.Hr()],
                    style={"height": "80%"},  # Adjust the height as per your requirement
                )
            ]
        ),
        dbc.CardBody(
            [
                dbc.Row(
                    children=[
                        # Apply button to create filter. After that, I want to create new filter section.
                        dbc.Col(dbc.Button("Add", id="add_filter_btn", color="primary"),width=6,),
                        dbc.Col(dbc.Button("Apply", id="apply_filter_btn", color="primary"),width=6,),
                        # Clear button to remove all filters
                        dbc.Col(dbc.Button("Clear", id="clear-button", color="danger"),width=6,),
                        html.Hr(),
                        
                    ],
                    style={"height": "20%"},  # Adjust the height as per your requirement
                )
            ]
        ),])

app.layout = html.Div(children = [
    html.H3("Table selection"),
    dropdown_table,
    html.Hr(),
    html.H3("Variable To SHOW"),
    dropdown_var_filter,
    html.Hr(),
    filter,
    html.H3("Output Table"),
    table_output])


# graph = dcc.Graph(id="my-graph",figure={})






if __name__ == "__main__":
    app.run_server(debug=True)

The code is quite long because it is entire test app, but you can focus on just two callback functions: ‘add_filter’ and ‘apply_filter’

What I want to do is add overlapping multiple filtesr between different types of filters. So if I apply filter A => firstly filtered data. And after I apply filter B => secondly filtered data from filtered data.

How can I do that ?

Asked By: stat_man

||

Answers:

You don’t need any conditions on selected_columns, which is the column name of the last added filter. Just directly iterate through each type of filter, as you already do:


@app.callback(
    # ...
    State({'type': 'filter_date', 'table': ALL, 'index': ALL}, 'id'),
    State('filter_table', 'value'),
    # second filter removed (last filter column name)
    prevent_initial_call=True)
def apply_filter(n_clicks, cat, cat_id, num, num_id, 
                 start_date, end_date, date_id, selected_table):

    dff = get_selected_dataframe(selected_table).copy()

    if not n_clicks:
        raise PreventUpdate

    for idx, value in enumerate(cat_id):
        dff = dff[dff[value['index']].isin(cat[idx])]

    for idx, value in enumerate(num_id):
        dff = dff[
            (dff[value['index']] >= num[idx][0]) & (
                        dff[value['index']] <= num[idx][1])]

    for idx, value in enumerate(date_id):
        dff = dff[(dff[value['index']] >= start_date[idx]) & (
            dff[value['index']] <= end_date[idx])]

    return dff.to_dict('records')
Answered By: Dmitry
Categories: questions Tags: , ,
Answers are sorted by their score. The answer accepted by the question owner as the best is marked with
at the top-right corner.