Is there a way to generate nn.Embedding efficiently using for loop?

Question:

I’m a PyTorch newbie and I wonder if I can generate nn.Embedding layers efficiently using a for loop.

class Example(nn.Module):
    def __init__(self):
        super().__init__()  # required before registering submodules
        self.A_embed_dim = 3
        self.B_embed_dim = 3
        self.C_embed_dim = 5

        self.A_embedding = nn.Embedding(
            df.A.max() + 1, self.A_embed_dim
        )
        self.B_embedding = nn.Embedding(
            df.B.max() + 1, self.B_embed_dim
        )
        self.C_embedding = nn.Embedding(
            df.C.max() + 1, self.C_embed_dim
        )

In this case only 3 columns exist, so it is easy to write the embeddings out by hand. But if there are more columns in the DataFrame (e.g., 16 columns A to P), the code gets long and doesn’t look clean. Is there a way to create multiple nn.Embedding layers using a for loop?

Asked By: Dang


Answers:

Yes, you can use nn.ModuleList or nn.ModuleDict to do so.

ModuleList:

# embedding_args holds (vocab_size, dim) pairs, e.g. [(5, 10), (2, 8)]
self.embeddings = nn.ModuleList(
    [
        nn.Embedding(vocab_size, dim)
        for vocab_size, dim in embedding_args
    ]
)
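If you want to look embeddings up by column name rather than position, nn.ModuleDict works the same way. Here is a minimal self-contained sketch; the column names, vocab sizes, and dimensions are illustrative, not taken from the question:

```python
import torch
import torch.nn as nn

# Hypothetical per-column (vocab_size, embed_dim) specs; in practice you
# would derive vocab_size from df[col].max() + 1 as in the question.
embed_specs = {"A": (10, 3), "B": (7, 3), "C": (20, 5)}

class Example(nn.Module):
    def __init__(self, specs):
        super().__init__()
        # One nn.Embedding per column, addressable by name.
        self.embeddings = nn.ModuleDict(
            {col: nn.Embedding(vocab_size, dim)
             for col, (vocab_size, dim) in specs.items()}
        )

    def forward(self, batch):
        # batch: dict mapping column name -> LongTensor of indices;
        # concatenate the per-column embeddings along the last dim.
        return torch.cat(
            [self.embeddings[col](idx) for col, idx in batch.items()], dim=-1
        )

model = Example(embed_specs)
batch = {
    "A": torch.tensor([1, 2]),
    "B": torch.tensor([0, 3]),
    "C": torch.tensor([5, 6]),
}
out = model(batch)
print(out.shape)  # torch.Size([2, 11]) -- 3 + 3 + 5 concatenated dims
```

Wrapping the embeddings in nn.ModuleList or nn.ModuleDict (rather than a plain list or dict) is what registers their parameters with the module, so they show up in `model.parameters()` and move with `model.to(device)`.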
Answered By: joe32140