TensorFlow: Remember LSTM state for next batch (stateful LSTM)

Question:

Given a trained LSTM model, I want to perform inference one timestep at a time, i.e. seq_length = 1 in the example below. After each timestep, the internal LSTM states (memory and hidden) need to be remembered for the next ‘batch’. At the very beginning of inference, the initial LSTM states init_c and init_h are computed from the input. They are then stored in an LSTMStateTuple object, which is passed to the LSTM. During training, this state is updated at every timestep. For inference, however, I want the state to persist between batches: the initial states only need to be computed once, at the very beginning, and after that the LSTM states should be saved after each ‘batch’ (n=1).

I found this related StackOverflow question: "Tensorflow, best way to save state in RNNs?". However, that approach only works with state_is_tuple=False, and this behavior is soon to be deprecated by TensorFlow (see rnn_cell.py). Keras seems to have a nice wrapper that makes stateful LSTMs possible, but I don’t know the best way to achieve this in TensorFlow. This issue on the TensorFlow GitHub is also related to my question: https://github.com/tensorflow/tensorflow/issues/2838

Does anyone have good suggestions for building a stateful LSTM model?

inputs  = tf.placeholder(tf.float32, shape=[None, seq_length, 84, 84], name="inputs")
targets = tf.placeholder(tf.float32, shape=[None, seq_length], name="targets")

num_lstm_layers = 2

with tf.variable_scope("LSTM") as scope:

    lstm_cell  = tf.nn.rnn_cell.LSTMCell(512, initializer=initializer, state_is_tuple=True)
    self.lstm  = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * num_lstm_layers, state_is_tuple=True)

    init_c = ...  # compute initial LSTM memory state using contents in placeholder 'inputs'
    init_h = ...  # compute initial LSTM hidden state using contents in placeholder 'inputs'
    self.state = [tf.nn.rnn_cell.LSTMStateTuple(init_c, init_h)] * num_lstm_layers

    outputs = []

    for step in range(seq_length):

        if step != 0:
            scope.reuse_variables()

        # CNN features, as input for LSTM
        x_t = ...  # elided: CNN feature tensor for this timestep

        # LSTM step through time
        output, self.state = self.lstm(x_t, self.state)
        outputs.append(output)
Asked By: verified.human

Answers:

"Tensorflow, best way to save state in RNNs?" was actually my original question. The code below is how I use the state tuples.

with tf.variable_scope('decoder') as scope:
    # Two stacked LSTM layers with output projections. The opening
    # parenthesis must sit on the same line as MultiRNNCell.
    rnn_cell = tf.nn.rnn_cell.MultiRNNCell([
        tf.nn.rnn_cell.LSTMCell(512, num_proj=256, state_is_tuple=True),
        tf.nn.rnn_cell.LSTMCell(512, num_proj=WORD_VEC_SIZE, state_is_tuple=True)
    ], state_is_tuple=True)

    state = [[tf.zeros((BATCH_SIZE, sz)) for sz in sz_outer] for sz_outer in rnn_cell.state_size]

    for t in range(TIME_STEPS):
        if t:
            last = y_[t - 1] if TRAINING else y[t - 1]
        else:
            last = tf.zeros((BATCH_SIZE, WORD_VEC_SIZE))

        # Argument order of TensorFlow < 1.0; from 1.0 on it is tf.concat(values, axis).
        y[t] = tf.concat(1, (y[t], last))
        y[t], state = rnn_cell(y[t], state)

        scope.reuse_variables()

Rather than using tf.nn.rnn_cell.LSTMStateTuple, I just create a list of lists, which works fine. In this example I am not saving the state. However, you could easily make the state out of variables and use assign to save the values.
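
The idea of keeping the state in variables and overwriting it after each step can be illustrated without TensorFlow. The following is only a sketch under stated assumptions: ToyLSTMCell and StatefulRunner are hypothetical names, the cell is a scalar LSTM with made-up weights, and the update of self.c and self.h plays the role that assign on state variables would play in a TensorFlow graph. It suggests that feeding one timestep per call, while the state persists between calls, gives the same outputs as processing the whole sequence at once.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


class ToyLSTMCell:
    """Scalar LSTM cell with fixed, made-up weights (illustration only)."""

    def __init__(self):
        # (input weight, recurrent weight, bias) per gate; arbitrary values.
        self.w = {
            "i": (0.5, 0.1, 0.0),
            "f": (0.4, -0.2, 1.0),
            "o": (0.3, 0.2, 0.0),
            "g": (0.9, 0.5, 0.0),
        }

    def _pre(self, gate, x, h):
        wx, wh, b = self.w[gate]
        return wx * x + wh * h + b

    def __call__(self, x, state):
        c, h = state
        i = sigmoid(self._pre("i", x, h))    # input gate
        f = sigmoid(self._pre("f", x, h))    # forget gate
        o = sigmoid(self._pre("o", x, h))    # output gate
        g = math.tanh(self._pre("g", x, h))  # candidate memory
        new_c = f * c + i * g
        new_h = o * math.tanh(new_c)
        return new_h, (new_c, new_h)


class StatefulRunner:
    """Keeps (c, h) between calls, like state variables updated via assign."""

    def __init__(self, cell):
        self.cell = cell
        self.reset()

    def reset(self):
        # Hypothetical zero initial state; the question computes it from the input.
        self.c, self.h = 0.0, 0.0

    def run(self, xs):
        outputs = []
        for x in xs:
            out, (self.c, self.h) = self.cell(x, (self.c, self.h))
            outputs.append(out)
        return outputs


sequence = [0.2, -1.0, 0.7, 0.1, 1.5]

# Whole sequence in a single call.
full = StatefulRunner(ToyLSTMCell()).run(sequence)

# One timestep per call; the runner carries the state across calls.
stepwise_runner = StatefulRunner(ToyLSTMCell())
stepwise = [stepwise_runner.run([x])[0] for x in sequence]
```

Calling reset() before a new, unrelated sequence corresponds to re-initializing the state variables; forgetting to do so would leak state from the previous sequence.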

Answered By: chasep255

I found it easiest to keep the whole state for all layers in a single placeholder.

init_state = np.zeros((num_layers, 2, batch_size, state_size))

...

state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])

Then unpack it and create a tuple of LSTMStateTuples before using the native TensorFlow RNN API.

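# tf.unpack was renamed to tf.unstack in TensorFlow 1.0.
per_layer_state = tf.unpack(state_placeholder, axis=0)
# Index 0 along the second axis holds c (memory), index 1 holds h (hidden).
rnn_tuple_state = tuple(
    tf.nn.rnn_cell.LSTMStateTuple(per_layer_state[idx][0], per_layer_state[idx][1])
    for idx in range(num_layers)
)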

Pass this tuple to the RNN API as the initial state:

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
cell = tf.nn.rnn_cell.MultiRNNCell([cell]*num_layers, state_is_tuple=True)
outputs, state = tf.nn.dynamic_rnn(cell, x_input_batch, initial_state=rnn_tuple_state)

The state returned by dynamic_rnn is then fed back in through the placeholder for the next batch.
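
The shape bookkeeping in this approach can be checked without building a graph. The sketch below assumes NumPy only: StateTuple is a hypothetical stand-in for LSTMStateTuple (which is likewise a named tuple with fields c and h), and fake_run is a made-up substitute for the session call that would return the new state. It shows the nested per-layer state being converted back into one [num_layers, 2, batch_size, state_size] array that matches the placeholder.

```python
import collections

import numpy as np

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple.
StateTuple = collections.namedtuple("StateTuple", ("c", "h"))

num_layers, batch_size, state_size = 2, 3, 4


def unpack_state(packed):
    """Split a [num_layers, 2, batch, size] array into per-layer (c, h) tuples."""
    return tuple(StateTuple(c=layer[0], h=layer[1]) for layer in packed)


def pack_state(state):
    """Stack per-layer (c, h) tuples back into one array for the next feed."""
    return np.array([[layer.c, layer.h] for layer in state])


def fake_run(state):
    """Made-up substitute for the session call: returns a changed state
    with the same nested structure."""
    return tuple(StateTuple(c=s.c + 1.0, h=s.h + 0.5) for s in state)


current = np.zeros((num_layers, 2, batch_size, state_size))
for _ in range(3):  # three consecutive "batches"
    new_state = fake_run(unpack_state(current))
    current = pack_state(new_state)  # value to feed into the placeholder next
```

Keeping the state in a single array means only one placeholder has to be fed per batch, regardless of the number of layers.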

Answered By: user1506145