How to zoom in a graph in Python using matplotlib or plotly?

Question:

I want to zoom in on the red section of the graph. I have the following code snippet. What changes do I need to make to my code, using either matplotlib or plotly, to achieve this? I would prefer the plotly library if possible.

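import matplotlib.pyplot as plt

# Visualize decision tree predictions
# (df, x and treePrediction are defined in earlier notebook cells)
predictions = treePrediction
valid = df[x.shape[0]:].copy()  # .copy() avoids pandas' SettingWithCopyWarning
valid["Predictions"] = predictions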
plt.figure(figsize=(12, 7))
plt.title("Apple's Stock Price Prediction Model(Decision Tree Regressor Model)")
plt.xlabel("Days")
plt.ylabel("Close Price USD ($)")
plt.plot(df["Mean"])
plt.plot(valid[["Mean", "Predictions"]])
plt.legend(["Original", "Valid", "Predictions"])
plt.show()
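Since the question also allows matplotlib: one minimal way to zoom there is to narrow the axis limits with plt.xlim and plt.ylim after plotting. The sketch below is self-contained, so it builds hypothetical stand-ins for df and valid from random data (the 300-row length, the 240-row split and the 1.0 padding are all assumptions); in the real notebook you would keep your own df and valid and only add the limit calls.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical stand-ins for the question's df and predictions:
# 300 days of a random-walk "Mean" price, last 60 rows used as the validation window
rng = np.random.default_rng(0)
df = pd.DataFrame({"Mean": 150 + rng.normal(0, 1, 300).cumsum()})
valid = df.iloc[240:].copy()
valid["Predictions"] = valid["Mean"] + rng.normal(0, 0.5, len(valid))

plt.figure(figsize=(12, 7))
plt.title("Apple's Stock Price Prediction Model (Decision Tree Regressor Model)")
plt.xlabel("Days")
plt.ylabel("Close Price USD ($)")
plt.plot(df["Mean"])
plt.plot(valid[["Mean", "Predictions"]])
plt.legend(["Original", "Valid", "Predictions"])

# Zoom: limit the x-axis to the rows of the validation window...
plt.xlim(valid.index[0], valid.index[-1])
# ...and fit the y-axis to the values in that window, because matplotlib's
# y autoscaling still uses all plotted data when only the x-limits change
window = valid[["Mean", "Predictions"]]
pad = 1.0  # hypothetical margin in price units
plt.ylim(window.min().min() - pad, window.max().max() + pad)

ax = plt.gca()  # handle on the zoomed axes
plt.show()
```

Interactively, the magnifier tool in matplotlib's toolbar does the same thing by hand; the xlim/ylim calls just make the zoom reproducible.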


Asked By: Anwesa Roy


Answers:

Referring to the published notebook, I reproduced the graph in question. Since the notebook used local CSV data, I retrieved Apple's stock price data from yfinance and used its closing price as df['Mean'].
In Plotly, I used graph_objects to add each line as a separate trace. I also added range selector buttons for choosing the zoom period; see this page for details. A 3-month ("3m") option has been added to the period selection.

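import plotly.graph_objects as go

# df and valid come from the question's code; the date-based range selector
# below assumes their index holds dates (e.g. a pandas DatetimeIndex)
fig = go.Figure()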
fig.add_trace(go.Scatter(mode='lines', x=df.index, y=df['Mean'], line_color='blue', name='Original'))
fig.add_trace(go.Scatter(mode='lines', x=valid.index, y=valid['Mean'], line_color='orange', name='Valid'))
fig.add_trace(go.Scatter(mode='lines', x=valid.index, y=valid['Predictions'], line_color='green', name='Predictions'))

fig.update_layout(
    autosize=True,
    height=600,
    title="Apple's Stock Price Prediction Model(Decision Tree Regressor Model)",
    xaxis_title="Days",
    yaxis_title="Close Price USD ($)",
    template='plotly_white'
)

# Add range selector buttons and a range slider (both need a date-typed x-axis)
fig.update_layout(
    xaxis=dict(
        rangeselector=dict(
            buttons=list([
                dict(count=1,
                     label="1m",
                     step="month",
                     stepmode="backward"),
                dict(count=3,
                     label="3m",
                     step="month",
                     stepmode="backward"),              
                dict(count=6,
                     label="6m",
                     step="month",
                     stepmode="backward"),
                dict(count=1,
                     label="YTD",
                     step="year",
                     stepmode="todate"),
                dict(count=1,
                     label="1y",
                     step="year",
                     stepmode="backward"),
                dict(step="all")
            ])
        ),
        rangeslider=dict(
            visible=True
        ),
        type="date"
    )
)

fig.show()


Clicking the "1m" button in the period selector zooms in on the most recent month.


The range slider at the bottom of the graph can be used to select any range.

Answered By: r-beginners