TensorFlow model not training if a Dense layer is used as the first layer

Question:

I am facing a weird problem. I am training my TF model using a custom training loop. If I use a Dense layer as my first layer, the model does not seem to train (I am training on the flattened MNIST dataset). If I add a Flatten layer on top of my already flattened dataset, the model trains fine.

Note – The reason I am applying a Flatten layer to an already flattened dataset is to show that using only a Dense layer as the first layer does not work. If I use a Conv2D layer on non-flattened data, the model also trains perfectly. For some reason, the issue seems to lie in the Dense layers.

Can’t seem to find the issue.

TensorFlow version – 2.9.1

Python version – 3.8.6

Model that works

import tensorflow as tf
from tensorflow import keras

class CustomModel(keras.Model):
    def __init__(self, num_classes, name = None):
        super().__init__(name = name)

        self._flatten = tf.keras.layers.Flatten()
        self._dense1 = tf.keras.layers.Dense(64)
        self._dense2 = tf.keras.layers.Dense(num_classes)

    @tf.function
    def call(self, X, training=False):
        X = self._flatten(X)
        X = tf.nn.relu(self._dense1(X))
        return self._dense2(X)

Model that does not work

class CustomModel(keras.Model):
    def __init__(self, num_classes, name = None):
        super().__init__(name = name)

        self._dense1 = tf.keras.layers.Dense(64)
        self._dense2 = tf.keras.layers.Dense(num_classes)

    @tf.function
    def call(self, X, training=False):
        X = tf.nn.relu(self._dense1(X))
        return self._dense2(X)

Dataset Used

import tensorflow_datasets as tfds

(ds_train, ds_test), ds_info = tfds.load(
    "mnist",
    split = ["train", "test"],
    shuffle_files = True,
    as_supervised = True,
    with_info = True
)

def normalize_img(image, label):
    return tf.cast(image, tf.float32) / 255.0, label

def flatten_img(image, label):
    return tf.reshape(image, [-1, 28 * 28]), label

AUTOTUNE = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 64

# # Train Dataset
ds_train = ds_train.map(normalize_img, num_parallel_calls = AUTOTUNE)
ds_train = ds_train.map(flatten_img, num_parallel_calls = AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples)
ds_train = ds_train.batch(BATCH_SIZE)
ds_train = ds_train.prefetch(AUTOTUNE)

# # Test Dataset
ds_test = ds_test.map(normalize_img, num_parallel_calls = AUTOTUNE)
ds_test = ds_test.map(flatten_img, num_parallel_calls = AUTOTUNE)
ds_test = ds_test.batch(BATCH_SIZE)
ds_test = ds_test.prefetch(AUTOTUNE)

Custom Training Loop

from tqdm import tqdm

model = CustomModel(10)

num_epochs = 5
optimizer = keras.optimizers.Adam()
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
acc_metric = keras.metrics.SparseCategoricalAccuracy()

@tf.function
def train_epoch(x, y):
    with tf.GradientTape() as tape:
        y_pred = model(x, training = True)
        loss = loss_fn(y, y_pred)
    
    # Getting Gradients
    gradients = tape.gradient(loss, model.trainable_weights)

    # Back Prop
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
    acc_metric.update_state(y, y_pred)

    return loss

# Training Loop
for epoch in range(num_epochs):
    print(f"\nStart of Training Epoch {epoch + 1}")

    for batch_idx, (x_batch, y_batch) in tqdm(enumerate(ds_train), total=len(ds_train)):
        loss = train_epoch(x_batch, y_batch)

    print(f"Accuracy :- {acc_metric.result()}, Loss :- {loss}")
    acc_metric.reset_states()

[Screenshot: Flatten Layer Model Stats]

[Screenshot: No Flatten Layer Model Stats]

Asked By: Prithwiraj Mitra


Answers:

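There are two problems in your preprocessing. First, as Frightera pointed out, the images only need to be divided by 255 once; normalize_img already does that, so don't scale them again. Second, flatten_img adds one dimension too many: reshaping a (28, 28, 1) image with [-1, 28 * 28] produces a tensor of shape (1, 784) instead of (784,). After batching, the model therefore receives (64, 1, 784), and because a Dense layer only operates on the last axis, the logits come out as (64, 1, 10) instead of the (64, 10) that the loss and accuracy metric expect for labels of shape (64,). That's why the model needs a Flatten() layer in order to work: Flatten() collapses the extra axis back to (64, 784). Change flatten_img to this and it will work just fine: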

def flatten_img(image, label):
    return tf.reshape(image, [28 * 28]), label  # (28, 28, 1) -> (784,)
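To make the shape argument concrete, here is a minimal NumPy sketch that traces the shapes through the pipeline. It is an illustration under stated assumptions, not TensorFlow code: the batching step is emulated with np.stack, and the hypothetical helper dense() stands in for a Keras Dense layer using random weights. It mirrors the documented Keras behavior that, for inputs of rank greater than 2, Dense takes the dot product along the last axis only.

```python
import numpy as np

BATCH_SIZE = 64
rng = np.random.default_rng(0)

# One MNIST-like image, shaped as tfds yields it: (28, 28, 1)
image = rng.random((28, 28, 1), dtype=np.float32)

# Reshape as in the question: [-1, 784] keeps a leading axis of size 1
flat_bad = image.reshape(-1, 28 * 28)
# Reshape as in the answer: [784] gives a plain vector
flat_good = image.reshape(28 * 28)

# Emulate ds.batch(64) by stacking 64 examples along a new first axis
batch_bad = np.stack([flat_bad] * BATCH_SIZE)    # (64, 1, 784)
batch_good = np.stack([flat_good] * BATCH_SIZE)  # (64, 784)

def dense(x, units):
    """Hypothetical stand-in for a Dense layer: x @ W + b over the LAST axis."""
    w = rng.standard_normal((x.shape[-1], units)).astype(np.float32)
    b = np.zeros(units, dtype=np.float32)
    return x @ w + b

# Dense(64) -> ReLU -> Dense(10), as in the model without Flatten
logits_bad = dense(np.maximum(dense(batch_bad, 64), 0.0), 10)
logits_good = dense(np.maximum(dense(batch_good, 64), 0.0), 10)

print(flat_bad.shape, batch_bad.shape, logits_bad.shape)
print(flat_good.shape, batch_good.shape, logits_good.shape)

# What Flatten() does: keep the batch axis, merge all remaining axes
flattened = batch_bad.reshape(batch_bad.shape[0], -1)
print(flattened.shape)
```

Running it prints (1, 784) (64, 1, 784) (64, 1, 10) for the original reshape, (784,) (64, 784) (64, 10) for the fixed one, and (64, 784) after the Flatten-style reshape: the stray axis survives both Dense layers untouched, and only a Flatten (or the corrected reshape) removes it.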
Answered By: MikeElmwood