Increasing memory usage with Keras's fit() method

Question:

Using TF 2.11.0 with a GPU on Colab.

I am seeing system memory usage grow with every batch when I let the fit() method run (the code below only checks it once per epoch; a per-batch variant is sketched after the class).
Here are the memory-monitoring callback and a very basic CycleGAN class:

import os
import gc
import keras
import psutil
import tensorflow as tf
from keras.callbacks import Callback

class MemoryUsageCallback(Callback):
  '''Print memory usage (RSS) on epoch begin; collect garbage and clear the Keras session on epoch end'''

  def on_epoch_begin(self, epoch, logs=None):
    print('**Epoch {}**'.format(epoch))
    print('Memory usage on epoch begin: {}'.format(psutil.Process(os.getpid()).memory_info().rss))

  def on_epoch_end(self, epoch, logs=None):
    gc.collect()
    tf.keras.backend.clear_session()


class CycleGan(keras.Model):
    def __init__(
        self,
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
        lambda_cycle=10,
    ):
        super(CycleGan, self).__init__()
        self.m_gen = monet_generator
        self.p_gen = photo_generator
        self.m_disc = monet_discriminator
        self.p_disc = photo_discriminator
        self.lambda_cycle = lambda_cycle
        
    def compile(
        self,
        m_gen_optimizer,
        p_gen_optimizer,
        m_disc_optimizer,
        p_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
    ):
        super(CycleGan, self).compile()
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        
    def train_step(self, batch_data):
        real_monet, real_photo = batch_data
        
        with tf.GradientTape(persistent=True) as tape:
            # Pass the images through the generators
            fake_monet = self.m_gen(real_photo, training=True)
            fake_monet_resized = tf.image.resize(fake_monet, [256, 256])
            cycled_photo = self.p_gen(fake_monet_resized, training=True)

            fake_photo = self.p_gen(real_monet, training=True)
            fake_photo_resized = tf.image.resize(fake_photo, [256, 256])
            cycled_monet = self.m_gen(fake_photo_resized, training=True)

            # resize original images for disc
            real_monet = tf.image.resize(real_monet, [320, 320])
            real_photo = tf.image.resize(real_photo, [320, 320])

            # Calculate the discriminators' outputs (averaged over spatial dimensions)
            disc_real_monet = tf.reduce_mean(self.m_disc(real_monet, training=True), axis=[1,2])
            disc_real_photo = tf.reduce_mean(self.p_disc(real_photo, training=True), axis=[1,2])
            disc_fake_monet = tf.reduce_mean(self.m_disc(fake_monet, training=True), axis=[1,2])
            disc_fake_photo = tf.reduce_mean(self.p_disc(fake_photo, training=True), axis=[1,2])

            # Calculate cycle loss
            cycle_loss = self.cycle_loss_fn(real_monet, cycled_monet, self.lambda_cycle) + self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle)

            # evaluates total generator loss
            total_monet_gen_loss = self.gen_loss_fn(disc_fake_monet) + cycle_loss
            total_photo_gen_loss = self.gen_loss_fn(disc_fake_photo) + cycle_loss

            # evaluates discriminator loss
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

        # Calculate the gradients for the generators and discriminators
        monet_generator_gradients = tape.gradient(total_monet_gen_loss,
                                                  self.m_gen.trainable_variables)
        photo_generator_gradients = tape.gradient(total_photo_gen_loss,
                                                  self.p_gen.trainable_variables)
        monet_discriminator_gradients = tape.gradient(monet_disc_loss,
                                                      self.m_disc.trainable_variables)
        photo_discriminator_gradients = tape.gradient(photo_disc_loss,
                                                      self.p_disc.trainable_variables)
        
        # Apply the gradients to the optimizer
        self.m_gen_optimizer.apply_gradients(zip(monet_generator_gradients,
                                                 self.m_gen.trainable_variables))
        self.p_gen_optimizer.apply_gradients(zip(photo_generator_gradients,
                                                 self.p_gen.trainable_variables))
        self.m_disc_optimizer.apply_gradients(zip(monet_discriminator_gradients,
                                                  self.m_disc.trainable_variables))
        self.p_disc_optimizer.apply_gradients(zip(photo_discriminator_gradients,
                                                  self.p_disc.trainable_variables))
        return {
            "monet_gen_loss": total_monet_gen_loss,
            "photo_gen_loss": total_photo_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss
        }
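
Since the growth happens per batch, the same RSS check can also be hooked into on_train_batch_end. A minimal sketch of such a (hypothetical) variant:

class BatchMemoryUsageCallback(Callback):
  '''Hypothetical per-batch variant: print RSS after every training batch'''

  def on_train_batch_end(self, batch, logs=None):
    rss = psutil.Process(os.getpid()).memory_info().rss
    print('Memory usage after batch {}: {}'.format(batch, rss))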

And I train the model like this:

cycle_gan_model = CycleGan(
      monet_generator, photo_generator, monet_discriminator, photo_discriminator)
cycle_gan_model.compile(
    m_gen_optimizer = monet_generator_optimizer,
    p_gen_optimizer = photo_generator_optimizer,
    m_disc_optimizer = monet_discriminator_optimizer,
    p_disc_optimizer = photo_discriminator_optimizer,
    gen_loss_fn = generator_loss,
    disc_loss_fn = discriminator_loss,
    cycle_loss_fn = calc_cycle_loss,
    )
callbacks = [MemoryUsageCallback()]
zipped_dataset = tf.data.Dataset.zip((monet_dataset, photos_dataset))
cycle_gan_model.fit(
    zipped_dataset,
    epochs=20,
    steps_per_epoch=150,
    callbacks=callbacks
)

Note that the datasets I am using come from generators I defined earlier, which I am not so sure work well >.<

The memory usage at the start of each epoch keeps increasing:

3195469824->3974574080->4476375040->4954685440->…7341445120->…

It doesn't seem to stop.

I tried:

  1. Changing my activation functions from ReLU to LeakyReLU.
  2. Checking that run_eagerly=True (see the sketch after this list).
  3. Putting my activation functions in separate layers.
  4. Adding a custom callback that collects garbage and clears the Keras backend at the end of each epoch.
  5. Lowering the batch size (right now it's 2).
  6. Playing around with the layers of the models.
  7. Changing the optimizers to Adafactor, as it is more memory efficient.
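
For reference on item 2: keras.Model exposes run_eagerly as a settable property, so the flag can be checked (and forced) on the compiled model directly. A minimal sketch:

cycle_gan_model.run_eagerly = True   # force the custom train_step to run eagerly
print(cycle_gan_model.run_eagerly)   # confirm the effective value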

I expected the memory usage to stabilize after the first epoch, but instead it kept increasing.

—UPDATE—

Changed the train_step to

    def train_step(self, batch_data):
        return {
            "monet_gen_loss": 0,
            "photo_gen_loss": 0,
            "monet_disc_loss": 0,
            "photo_disc_loss": 0
        }

and yet I am still seeing the memory increase, so I suspect it might be the datasets.

I saved the images to my machine as .jpg files and then used these functions to build the datasets:

import random
import numpy as np
from pathlib import Path
from PIL import Image

def image_loader(folder_path, batch_size):
    folder_path = Path(folder_path)
    while True:
        image_files = [file for file in folder_path.iterdir() if file.suffix == ".jpg"]
        # Shuffle the list of image files
        random.shuffle(image_files)
        for i in range(0, len(image_files), batch_size):
            batch_files = image_files[i:i+batch_size]
            images = [np.array(Image.open(str(file))) for file in batch_files]
            yield np.array(images)


BATCH_SIZE = 2 
photos_loader = image_loader('/content/photos/',BATCH_SIZE)


def augment_element(image):
  image = tf.convert_to_tensor(image)

  # flips and rotations
  image = tf.image.random_flip_left_right(image)
  image = tf.image.rot90(image, k=random.randint(0,3))

  # Crops the image randomly and resizes it back
  size = tf.random.uniform(shape=[], minval=100, maxval=256, dtype=tf.int32)
  image = tf.image.random_crop(image, size=[size, size, 3])
  image = tf.image.resize(image, [256, 256])

  return image.numpy()



def augmented_image_loader(folder_path, batch_size):
    folder_path = Path(folder_path)
    image_files = [file for file in folder_path.iterdir() if file.suffix == ".jpg"]
    # Load every image in the folder into memory once, then sample batches forever
    images = [np.array(Image.open(str(file))) for file in image_files]
    while True:
        random_images = random.choices(images, k=batch_size)
        augmented_images = [augment_element(image) for image in random_images]
        yield np.array(augmented_images)



augmented_monet_loader = augmented_image_loader('/content/monet/',BATCH_SIZE)
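
As a side note, the same augmentations can be written with TF ops only, which would let them run inside a Dataset.map call instead of bouncing between tensors and NumPy. A sketch for a single [H, W, 3] image (assuming inputs of at least 256x256, as in augment_element above; augment_element_graph is a hypothetical name):

def augment_element_graph(image):
    # Same flips / rotation / crop / resize as augment_element, but TF ops only
    image = tf.image.random_flip_left_right(image)
    image = tf.image.rot90(image, k=tf.random.uniform([], 0, 4, dtype=tf.int32))
    size = tf.random.uniform([], minval=100, maxval=256, dtype=tf.int32)
    image = tf.image.random_crop(image, size=[size, size, 3])
    return tf.image.resize(image, [256, 256])

(The loaders above yield whole batches, so this is only a sketch of the direction, not a drop-in replacement.)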

and finally

monet_dataset = tf.data.Dataset.from_generator(
        lambda: augmented_image_loader('/content/monet/', BATCH_SIZE),
        output_types=tf.float32,
        output_shapes=(tf.TensorShape([BATCH_SIZE, None, None, 3]))
    )
photos_dataset = tf.data.Dataset.from_generator(
        lambda: image_loader('/content/photos/', BATCH_SIZE),
        output_types=tf.float32,
        output_shapes=(tf.TensorShape([BATCH_SIZE, None, None, 3]))
    )


def normalize(image):
    image = image - 127.5
    image = image / 127.5
    return image


monet_dataset = monet_dataset.cache().map(normalize)
photos_dataset = photos_dataset.cache().map(normalize)

But then again, running the following code doesn't seem to increase my memory usage (except for the first iteration):

for i in range(1000):
  data = next(iter(zipped_dataset))
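
(Worth noting: because iter(zipped_dataset) is created inside the loop, every pass restarts the dataset and only ever fetches the first batch. A version that actually advances through the pipeline reuses a single iterator:)

it = iter(zipped_dataset)  # one iterator, so next() keeps advancing
for i in range(1000):
  data = next(it)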

It's my first time working with tf.data datasets (this is part of a uni assignment), so I am a complete rookie in that department. Any help would be appreciated! 😀

Asked By: River


Answers:

I noticed you wrote:

monet_dataset = monet_dataset.cache().map(normalize)
photos_dataset = photos_dataset.cache().map(normalize)

It's important to note that the cache() method stores every element of the dataset in memory as it is produced, so it is only suitable when the dataset is finite and small enough to fit in memory, the elements are expensive to generate, and you expect to iterate over the same data multiple times. Your datasets are built on generators that loop forever (while True), so the cache keeps growing with every new batch that is drawn, and so does your process memory.

Replace it with

monet_dataset = monet_dataset.map(normalize)
photos_dataset = photos_dataset.map(normalize)

And it should solve the memory issue.
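
If you still want the input pipeline to stay ahead of the GPU, a prefetch at the end is a cheap, bounded alternative. A sketch using the dataset names from the question:

AUTOTUNE = tf.data.AUTOTUNE
# map + prefetch keep only a small, bounded buffer in memory,
# unlike cache(), which accumulates every element ever produced
monet_dataset = monet_dataset.map(normalize, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)
photos_dataset = photos_dataset.map(normalize, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)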

Answered By: IdanSi