RuntimeError: Method requires being in cross-replica context, use get_replica_context().merge_call() while using tf.distribute.MirroredStrategy

Question:

I’m trying to adapt a model to run on multiple GPUs using mirrored_strategy.
I was able to reproduce my issue with a simpler model, which is at
https://colab.research.google.com/drive/16YlKuzdluryVRmcM680tjtLWfPjt5qhS

But here is the important part of the code:

def loss_object(target_y, pred_y):
    pred_ssum = tf.math.reduce_sum(tf.math.square(pred_y))
    target_ssum = tf.math.reduce_sum(tf.math.square(target_y))
    mul_sum = tf.math.reduce_sum(tf.math.multiply(pred_y, target_y))
    return tf.math.divide(-2 * mul_sum, tf.math.add(pred_ssum, target_ssum))

EPOCHS = 10



model = MyModel()

optimizer = tf.keras.optimizers.RMSprop(lr=2e-5)

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

@tf.function
def distributed_train_step(images, labels):
    per_replica_losses = mirrored_strategy.experimental_run_v2(train_step, args=(images, labels,))
    return mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                        axis=None)

@tf.function
def distributed_test_step(images, labels):
    return mirrored_strategy.experimental_run_v2(test_step, args=(images, labels,))

@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
    #train_loss(loss)
    train_accuracy.update_state(labels, predictions)

@tf.function
def test_step(images, labels):
    predictions = model(images, training=False)
    t_loss = loss_object(labels, predictions)

    test_loss.update_state(t_loss)
    test_accuracy.update_state(labels, predictions)

for epoch in range(EPOCHS):
    # Reset the metrics at the start of the next epoch
    total_loss = 0.0
    num_batches = 0

    for images, labels in train_ds:
        #train_step(images, labels)
        total_loss += distributed_train_step(images, labels)
        num_batches += 1
    train_loss = total_loss/num_batches

    for test_images, test_labels in test_ds:
        #test_step(test_images, test_labels)
        distributed_test_step(test_images, test_labels)

    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(template.format(epoch+1, train_loss, train_accuracy.result()*100, test_loss.result(), test_accuracy.result()*100))

    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

All of the code above is inside mirrored_strategy.scope().
The model simply takes a (4,4,4) cube with constant values and passes it through 3D CNN and 3D CNN transpose layers to get the same (4,4,4) cube back as output.
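Roughly, the model in the notebook looks like this (a sketch reconstructed from the traceback below; the exact layer arguments are assumptions):

class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        with mirrored_strategy.scope():
            self.cnn_down_1 = layers.Conv3D(1, (2, 2, 2), strides=2, padding='same')
            self.cnn_up_1 = layers.Conv3DTranspose(1, (2, 2, 2), strides=2, padding='same')

    def call(self, inputs):
        # call() re-enters the strategy scope (see the traceback below)
        with mirrored_strategy.scope():
            x = self.cnn_down_1(inputs)
            return self.cnn_up_1(x)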

However, I get an error saying

RuntimeError                              Traceback (most recent call last)
<ipython-input-19-93fb783af116> in <module>()
     65         for images, labels in train_ds:
     66             #train_step(images, labels)
---> 67             total_loss += distributed_train_step(images, labels)
     68             num_batches += 1
     69         train_loss = total_loss/num_batches

8 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    966           except Exception as e:  # pylint: disable=broad-except
    967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
    969             else:
    970               raise

RuntimeError: in user code:

    <ipython-input-19-93fb783af116>:32 distributed_train_step  *
        per_replica_losses = mirrored_strategy.experimental_run_v2(train_step, args=(images, labels,))
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/mirrored_strategy.py:770 _call_for_each_replica  *
        fn, args, kwargs)
    <ipython-input-19-93fb783af116>:43 train_step  *
        predictions = model(images, training=True)
    <ipython-input-14-cb5f0d1313e2>:9 call  *
        with mirrored_strategy.scope():
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:291 __enter__
        self._context.strategy.extended)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:214 _require_cross_replica_or_default_context_extended
        raise RuntimeError("Method requires being in cross-replica context, use "

    RuntimeError: Method requires being in cross-replica context, use get_replica_context().merge_call()

Has anybody faced a similar problem? I would be grateful if somebody could provide me with a suggestion.

Asked By: jongsung park


Answers:

As per the discussion, the model was indeed the cause of this error. The corrected code below is a working version for this problem.

In your dataset, changing the datatype from int to float will prevent a TypeError later on.

from __future__ import absolute_import, division, print_function, unicode_literals
!pip install tf-nightly
#%tensorflow_version 2.x
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import datasets, layers, models, Model
import numpy as np

mirrored_strategy = tf.distribute.MirroredStrategy()

def train_gen():
    for i in range(10):
      yield tf.constant(i, shape=(4,4,4,1)), tf.constant(i, shape=(4,4,4,1))

def test_gen():
    for i in range(2):
      yield tf.constant(i+10, shape=(4,4,4,1)), tf.constant(i+10, shape=(4,4,4,1))

BATCH_SIZE_PER_REPLICA = 2
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

train_ds = tf.data.Dataset.from_generator(
    train_gen,
    output_types=(tf.float32, tf.float32),    # using float as your datatype
    output_shapes=((4,4,4,1), (4,4,4,1))
)

test_ds = tf.data.Dataset.from_generator(
    test_gen,
    output_types=(tf.float32, tf.float32),      # using float as your datatype
    output_shapes=((4,4,4,1), (4,4,4,1))
)

train_ds = train_ds.batch(GLOBAL_BATCH_SIZE)
test_ds = test_ds.batch(GLOBAL_BATCH_SIZE)

In your model, the mirrored_strategy.scope() inside __init__ and call is what causes the error you are encountering. Removing it, as in the code below, solves the problem.

class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        #with mirrored_strategy.scope():
        self.cnn_down_1 = layers.Conv3D(1, (2, 2, 2), strides=2, padding='same')
        self.cnn_up_1 = layers.Conv3DTranspose(1, (2, 2, 2), strides=2, padding='same')

    def call(self, inputs):
         #with mirrored_strategy.scope():
            x = self.cnn_down_1(inputs)
            return self.cnn_up_1(x) 

assert tf.distribute.get_replica_context() is not None  # default

In the code below, removing @tf.function from the train_step and test_step functions is necessary; they are already traced as part of the @tf.function-decorated distributed_train_step and distributed_test_step that call them.

with mirrored_strategy.scope():
    #assert tf.distribute.get_replica_context() is not None  # default
    # The custom loss_object defined just below would shadow this Keras loss,
    # so it is left commented out here.
    #loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    #    from_logits=True,
    #    reduction=tf.keras.losses.Reduction.NONE)
    def loss_object(target_y, pred_y):
        pred_ssum = tf.math.reduce_sum(tf.math.square(pred_y))
        target_ssum = tf.math.reduce_sum(tf.math.square(target_y))
        mul_sum = tf.math.reduce_sum(tf.math.multiply(pred_y, target_y))
        return tf.math.divide(-2 * mul_sum, tf.math.add(pred_ssum, target_ssum))

    EPOCHS = 10



    model = MyModel()

    optimizer = tf.keras.optimizers.RMSprop(lr=2e-5)

    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

    #@tf.function
    def train_step(images, labels):
        with tf.GradientTape() as tape:
            predictions = model(images, training=True)
            loss = loss_object(labels, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        #train_loss(loss)
        train_accuracy.update_state(labels, predictions)
        return loss

    #@tf.function
    def test_step(images, labels):
        predictions = model(images, training=False)
        t_loss = loss_object(labels, predictions)

        test_loss.update_state(t_loss)
        test_accuracy.update_state(labels, predictions)


    @tf.function
    def distributed_train_step(images, labels):
      assert tf.distribute.get_replica_context() is None
      per_replica_losses = mirrored_strategy.experimental_run_v2(train_step, args=(images, labels,))
      return mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                          axis=None)

    @tf.function
    def distributed_test_step(images, labels):
        return mirrored_strategy.experimental_run_v2(test_step, args=(images, labels,))


    for epoch in range(EPOCHS):
        # Reset the metrics at the start of the next epoch
        #train_loss.reset_states()
        total_loss = 0.0
        num_batches = 0

        for images, labels in train_ds:
            #train_step(images, labels)
            total_loss += distributed_train_step(images, labels)
            num_batches += 1
        train_loss = total_loss/num_batches

        for test_images, test_labels in test_ds:
            #test_step(test_images, test_labels)
            distributed_test_step(test_images, test_labels)

        template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
        print(template.format(epoch+1, train_loss, train_accuracy.result()*100, test_loss.result(), test_accuracy.result()*100))

        train_accuracy.reset_states()
        test_loss.reset_states()
        test_accuracy.reset_states()

This removes the error and the code now runs correctly. Hope this helps.
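As a side note, in newer TF 2.x releases Strategy.experimental_run_v2 was renamed to Strategy.run, so the distributed steps above would become (same pattern, only the method name changes):

@tf.function
def distributed_train_step(images, labels):
    # run() replaces experimental_run_v2() from TF 2.2 onwards
    per_replica_losses = mirrored_strategy.run(train_step, args=(images, labels,))
    return mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                                    axis=None)

@tf.function
def distributed_test_step(images, labels):
    return mirrored_strategy.run(test_step, args=(images, labels,))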

Answered By: TF_Support

I also had a similar error saying
RuntimeError: `apply_gradients() cannot be called in cross-replica context. Use tf.distribute.Strategy.run to enter replica context.` The error went away when I removed @tf.function, as @TF_Support said in his answer.
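For context, that message typically means optimizer.apply_gradients was reached outside of a replica context. A minimal sketch of the usual fix, reusing the names from the code above (not from the original post), is to call the training step through the strategy rather than directly:

# Calling train_step(images, labels) directly runs apply_gradients in
# cross-replica context and raises the error above.
# Running it through the strategy enters replica context first:
per_replica_losses = mirrored_strategy.run(train_step, args=(images, labels))
loss = mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)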

Answered By: stic-lab