Why do CNN models sometimes predict just one class out of all the others?

Question:

I am relatively new to the deep learning landscape, so please don’t be as mean as Reddit! It seems like a general question, so I won’t include my code here, as it doesn’t seem necessary (if it is, here’s the link to the Colab).

A bit about the data: you can find it here. It is a downsized version of the original 82 GB dataset.

Once I trained my CNN on this, it predicted ‘No Diabetic Retinopathy’ (No DR) every single time, leading to an accuracy of 73%. Is the reason for this just the vast number of No DR images, or something else? I have no idea! The 5 classes I have for prediction are ["Mild", "Moderate", "No DR", "Proliferative DR", "Severe"].

It’s probably just bad code; I was hoping you guys could help.

Asked By: Divith


Answers:

I was about to comment:

A more rigorous approach would be to start measuring your dataset balance: how many images of each class do you have? This will likely give an answer to your question.

But I couldn’t help looking at the link you gave. Kaggle already gives you an overview of the dataset:

[Image: Kaggle dataset overview showing the class distribution]

Quick calculation: 25,810 / 35,126 * 100 ≈ 73%. That’s interesting: you said you had an accuracy of 73%. Your model is learning on an imbalanced dataset, with the No DR class being heavily over-represented; 25k out of 35k is enormous. My hypothesis is that your model keeps predicting that class, which means that on average you’ll end up with an accuracy of 73%.
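
You can verify those counts in a couple of lines of pandas; a quick sketch, assuming the dataset’s labels CSV (named trainLabels.csv here) with its level column:

    import pandas as pd

    # Number of images per class; level 0 (No DR) dominates in this dataset
    labels = pd.read_csv('trainLabels.csv')
    print(labels['level'].value_counts().sort_index())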

What you should do is balance your dataset, for example by only allowing 35,126 - 25,810 = 9,316 examples from the No DR class to appear during an epoch. Even better, balance your dataset over all classes so that each class appears only n times per epoch.
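
A minimal sketch of that per-epoch balancing with PyTorch’s WeightedRandomSampler (train_dataset and train_labels are placeholders for your own dataset and its integer class ids):

    import torch
    from torch.utils.data import DataLoader, WeightedRandomSampler

    # train_labels: one integer class id (0-4) per training image (placeholder)
    train_labels = torch.as_tensor(train_labels)
    class_counts = torch.bincount(train_labels)                # images per class
    sample_weights = 1.0 / class_counts[train_labels].float()  # rarer class -> larger weight

    # Each epoch now draws all classes with roughly equal probability
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(train_labels), replacement=True)
    loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)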

Answered By: Ivan

As Ivan already noted, you have a class imbalance problem. This can be resolved via:

  1. Online hard negative mining: at each iteration, after computing the loss, sort all the elements in the batch belonging to the "No DR" class and keep only the worst k. Then estimate the gradient using only these worst k and discard all the rest (a sketch follows this list).
    See, e.g.:
    Abhinav Shrivastava, Abhinav Gupta and Ross Girshick, Training Region-based Object Detectors with Online Hard Example Mining (CVPR 2016)

  2. Focal loss: a modification of the "vanilla" cross-entropy loss can be used to tackle class imbalance (a sketch follows this list).
    See, e.g.:
    Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He and Piotr Dollár, Focal Loss for Dense Object Detection (ICCV 2017)

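A rough sketch of option 1, assuming raw logits and integer targets, with "No DR" at index 2 of the sorted class list and k left as a free parameter:

    import torch
    import torch.nn.functional as F

    NO_DR = 2  # index of "No DR" in ["Mild", "Moderate", "No DR", "Proliferative DR", "Severe"]

    def ohem_loss(logits, targets, k=8):
        # Per-sample losses, no reduction yet
        losses = F.cross_entropy(logits, targets, reduction='none')
        no_dr = targets == NO_DR
        # Keep all minority-class samples, but only the k hardest "No DR" ones
        hard_no_dr, _ = losses[no_dr].topk(min(k, int(no_dr.sum())))
        return torch.cat([losses[~no_dr], hard_no_dr]).mean()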

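And a common formulation of option 2 (gamma=2 is the default from the paper; no per-class alpha weighting here):

    import torch
    import torch.nn.functional as F

    def focal_loss(logits, targets, gamma=2.0):
        # Plain per-sample cross entropy
        ce = F.cross_entropy(logits, targets, reduction='none')
        pt = torch.exp(-ce)  # model's probability for the true class
        # (1 - pt)^gamma shrinks the loss of already well-classified examples
        return ((1.0 - pt) ** gamma * ce).mean()
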
Related posts: this and this.

Answered By: Shai

Detecting diabetic retinopathy is a difficult problem, and the dataset reflects reality: most of the exams performed in real life result in no DR. But there are some alternatives to address your problem.

1 – The first thing is to use extra training data. In the Kaggle competition, the teams used extra data to train their models; choosing the right dataset can reduce the imbalance.

2 – Another alternative is data augmentation.
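
For example, a simple torchvision pipeline (the specific transforms are illustrative, not tuned for fundus images):

    from torchvision import transforms

    # Random flips/rotations multiply the effective number of minority-class images
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(degrees=15),
        transforms.ToTensor(),
    ])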

3 – You are using CrossEntropyLoss in your code; you can pass class weights to it to help balance your model’s updates:

    import numpy as np
    import torch
    from sklearn.utils import class_weight

    data_label = data_label.reset_index()
    # One 'balanced' (inverse-frequency) weight per class, levels 0-4
    class_weights = class_weight.compute_class_weight(
        class_weight='balanced',
        classes=np.array([0, 1, 2, 3, 4]),
        y=data_label['level'].values,
    )
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(config.DEVICE)

    # Errors on under-represented classes now contribute more to the loss
    criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

Answered By: Thiago Rainmaker