Is it possible to freeze only certain embedding weights in the embedding layer in PyTorch?

Question:

When using GloVe embedding in NLP tasks, some words from the dataset might not exist in GloVe. Therefore, we instantiate random weights for these unknown words.

Would it be possible to freeze the weights obtained from GloVe, and train only the newly instantiated weights?

I am only aware that we can set:
model.embedding.weight.requires_grad = False

But this makes the new words untrainable as well.

Or are there better ways to extract the semantics of words?

Asked By: rcshon


Answers:

1. Divide embeddings into two separate objects

One approach would be to use two separate embeddings: one for the pretrained tokens, another for the ones to be trained.

The GloVe embedding should be frozen, while tokens that have no pretrained representation would be looked up in the trainable layer.

This can be done if you format your data so that tokens with a pretrained representation fall in a smaller index range than the tokens without a GloVe representation. Let’s say your pretrained indices are in the range [0, 300], while those without a representation are in [301, 500]. I would go with something along these lines:

import numpy as np
import torch


class YourNetwork(torch.nn.Module):
    def __init__(self, glove_embeddings: np.ndarray, how_many_tokens_not_present: int):
        super().__init__()
        # from_pretrained freezes these weights by default (freeze=True)
        self.pretrained_embedding = torch.nn.Embedding.from_pretrained(
            torch.from_numpy(glove_embeddings)
        )
        self.trainable_embedding = torch.nn.Embedding(
            how_many_tokens_not_present, glove_embeddings.shape[1]
        )
        # Rest of your network setup

    def forward(self, batch):
        # Tokens in the batch without a pretrained representation must have indices
        # GREATER THAN OR EQUAL TO the number of pretrained ones; adjust your data
        # creation function accordingly.
        mask = batch >= self.pretrained_embedding.num_embeddings

        # Clamp the out-of-range indices to a valid placeholder (0) for the pretrained
        # lookup. You may want to optimize this; you could probably get away without
        # the clone, though I'm not currently sure how.
        pretrained_batch = batch.clone()
        pretrained_batch[mask] = 0

        embedded_batch = self.pretrained_embedding(pretrained_batch)

        # Every token without a pretrained representation has to be shifted into the
        # [0, how_many_tokens_not_present) range of the trainable embedding.
        trainable_batch = batch - self.pretrained_embedding.num_embeddings
        # Zero out the ones which already have a pretrained embedding
        trainable_batch[~mask] = 0
        non_pretrained_embedded_batch = self.trainable_embedding(trainable_batch)

        # Finally, overwrite the placeholder rows produced by the pretrained embedding
        # with the trainable embeddings.
        embedded_batch[mask] = non_pretrained_embedded_batch[mask]

        # Rest of your code
        ...

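For completeness, here is a minimal usage sketch of the class above. The numbers (301 pretrained tokens of dimension 50, 200 extra tokens) and the random stand-in for the GloVe matrix are assumptions for illustration only:

import numpy as np
import torch

# Stand-in for real GloVe vectors: 301 pretrained tokens (indices 0-300), dim 50
glove_embeddings = np.random.rand(301, 50).astype(np.float32)

# 200 further tokens (indices 301-500) have no pretrained representation
model = YourNetwork(glove_embeddings, how_many_tokens_not_present=200)

# from_pretrained freezes the GloVe weights, so filtering on requires_grad
# leaves only the trainable embedding (plus the rest of the network)
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)

print(model.pretrained_embedding.weight.requires_grad)  # False
print(model.trainable_embedding.weight.requires_grad)   # True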

2. Zero gradients for specified tokens

This one is a bit tricky, but I think it’s pretty concise and easy to implement. So, if you obtain the indices of the tokens which have a GloVe representation, you can explicitly zero their gradient after backprop, so those rows will not get updated, while the rows for the new tokens keep their gradients and remain trainable.

import torch

embedding = torch.nn.Embedding(10, 3)
X = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])

values = embedding(X)
loss = values.mean()

# Use whatever loss you want
loss.backward()

# Let's say those indices in your embedding are pretrained (have GloVe representation)
indices = torch.LongTensor([2, 4, 5])

print("Before zeroing out gradient")
print(embedding.weight.grad)

print("After zeroing out gradient")
embedding.weight.grad[indices] = 0
print(embedding.weight.grad)

And the output of the second approach:

Before zeroing out gradient
tensor([[0.0000, 0.0000, 0.0000],
        [0.0417, 0.0417, 0.0417],
        [0.0833, 0.0833, 0.0833],
        [0.0417, 0.0417, 0.0417],
        [0.0833, 0.0833, 0.0833],
        [0.0417, 0.0417, 0.0417],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0417, 0.0417, 0.0417]])
After zeroing out gradient
tensor([[0.0000, 0.0000, 0.0000],
        [0.0417, 0.0417, 0.0417],
        [0.0000, 0.0000, 0.0000],
        [0.0417, 0.0417, 0.0417],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0417, 0.0417, 0.0417]])
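Note that in an actual training loop this zeroing has to happen after every backward() and before optimizer.step(). A sketch (not part of the original answer) that automates this with a gradient hook on the embedding weight, assuming the same pretrained indices as above:

import torch

embedding = torch.nn.Embedding(10, 3)
pretrained_indices = torch.LongTensor([2, 4, 5])

# The hook fires on every backward pass and zeroes the gradient rows
# belonging to pretrained tokens, so optimizer.step() never updates them.
def zero_pretrained_rows(grad):
    grad = grad.clone()
    grad[pretrained_indices] = 0
    return grad

embedding.weight.register_hook(zero_pretrained_rows)

optimizer = torch.optim.SGD(embedding.parameters(), lr=0.1)

X = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
loss = embedding(X).mean()

optimizer.zero_grad()
loss.backward()   # hook fires here
optimizer.step()  # pretrained rows stay untouched

Keep in mind that optimizers with momentum or weight decay can still move rows whose gradient is zero, so plain SGD (or a separate parameter group without those options) is the safer choice for this trick.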
Answered By: Szymon Maszke