Calculating sentence embeddings with BERT overloads memory

Question:

I’m trying to compute sentence embeddings with BERT. After feeding a sentence into BERT, I apply mean pooling to the output and use the result as the sentence embedding.
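Mean pooling averages BERT's last-layer token vectors, usually skipping padding positions by using the attention mask. Here is a minimal sketch in PyTorch. It assumes the hidden states have shape (batch, seq_len, hidden) and the mask has shape (batch, seq_len). The function name masked_mean_pool is hypothetical.

```python
import torch


def masked_mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token vectors per sentence, ignoring padding positions.

    last_hidden_state: (batch, seq_len, hidden)
    attention_mask:    (batch, seq_len), 1 for real tokens, 0 for padding
    returns:           (batch, hidden)
    """
    # Broadcast the mask over the hidden dimension: (batch, seq_len, 1)
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
    # Sum only the vectors of real tokens
    summed = (last_hidden_state * mask).sum(dim=1)
    # Number of real tokens per sentence; clamp avoids division by zero
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts
```

With the transformers library, the token vectors are outputs[0] (outputs.last_hidden_state), and the mask comes from tokenizer(..., return_tensors="pt")["attention_mask"]. For a single unpadded sentence, this reduces to outputs[0][0].mean(dim=0). Note that outputs[1], which get_word_embedding uses, is the pooled [CLS] output rather than a mean over tokens.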

Problem

My code does compute the sentence embeddings, but it is very resource-intensive and overloads the machine. I don’t know what’s wrong and hope someone can help.

Load BERT

import torch
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

Get Embedding Function

# get the sentence embedding from BERT
def get_word_embedding(text:str):
    input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0)  # Batch size 1
    outputs = model(input_ids)
    # outputs[1] is the pooled [CLS] output (pooler_output);
    # the last hidden state is the first element, outputs[0]
    last_hidden_states = outputs[1]
    return last_hidden_states[0]

Data

Each text has at most 50 words. I compute the embedding of the entity name concatenated with its text.


Run code

entity_desc is my data.
This step overloads my computer every time I run it.
Please help!

I ran this on a Colab machine with 80 GB of RAM.

entity_embedding = {}
for i in range(len(entity_desc)):
    entity = entity_desc['entity'][i]
    text = entity_desc['text'][i]
    entity += ' ' + text
    entity_embedding[entity_desc['entity_id'][i]] = get_word_embedding(entity)
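A common way to keep memory flat in a loop like this is to run inference under torch.no_grad() and to process several texts per forward pass. The sketch below assumes a hypothetical encode callable that turns a list of strings into a (batch, hidden) tensor; for BERT, that would be tokenizer + model + pooling. The names embed_texts and encode are illustrative and are not part of any library.

```python
from typing import Callable, Dict, Hashable, List, Sequence

import torch


def embed_texts(
    ids: Sequence[Hashable],
    texts: Sequence[str],
    encode: Callable[[List[str]], torch.Tensor],  # hypothetical: strings -> (batch, hidden)
    batch_size: int = 32,
) -> Dict[Hashable, torch.Tensor]:
    """Embed texts in batches without recording autograd graphs."""
    result: Dict[Hashable, torch.Tensor] = {}
    # no_grad: no graph is built, so activations are freed after each batch
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch_ids = ids[start:start + batch_size]
            batch_texts = list(texts[start:start + batch_size])
            # Move results to host memory before storing them
            vectors = encode(batch_texts).cpu()
            for key, vec in zip(batch_ids, vectors):
                result[key] = vec
    return result
```

For BERT, encode could call tokenizer(batch, padding=True, truncation=True, return_tensors="pt"), pass the result to model(**inputs), and mean-pool outputs.last_hidden_state with the attention mask. If entity_desc is a pandas DataFrame (an assumption), the inputs could be entity_desc['entity_id'].tolist() and (entity_desc['entity'] + ' ' + entity_desc['text']).tolist().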
Asked By: edamame


Answers:

You might be storing the sentence embeddings on the GPU. Try detaching them and moving them to the CPU before returning them.

# get the sentence embedding from BERT
def get_word_embedding(text:str):
    input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0)  # Batch size 1
    outputs = model(input_ids)
    # outputs[1] is the pooled [CLS] output (pooler_output);
    # the last hidden state is the first element, outputs[0]
    last_hidden_states = outputs[1]
    # detach() drops the autograd graph; cpu() moves the tensor to host memory
    return last_hidden_states[0].detach().cpu()
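detach() is most likely what makes the difference here. Without torch.no_grad(), every forward pass records an autograd graph. The returned tensor keeps that graph alive for as long as the tensor is stored, including the intermediate activations saved for backward(). As a result, a dictionary of undetached outputs can grow by roughly one forward pass worth of activations per sentence. The small sketch below uses an nn.Linear layer as a stand-in for BERT, which is an assumption for illustration.

```python
import torch
import torch.nn as nn

# Stand-in for BERT (assumption for illustration): any module with
# trainable parameters records a graph the same way.
layer = nn.Linear(768, 768)
x = torch.randn(1, 768)

# Without no_grad, the output belongs to an autograd graph
out = layer(x)
print(out.requires_grad, out.grad_fn is not None)  # True True

# detach() returns a tensor that no longer references that graph
stored = out.detach().cpu()
print(stored.requires_grad, stored.grad_fn)  # False None

# Under torch.no_grad(), no graph is recorded at all
with torch.no_grad():
    out_ng = layer(x)
print(out_ng.requires_grad, out_ng.grad_fn)  # False None
```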
Answered By: joe32140

I fixed the problem.
The memory overload happened because I wasn’t moving the model and tensors to the GPU, so I made the following changes to the code:

import torch

# Use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# get the sentence embedding from BERT
def get_word_embedding(text:str):
    input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0)  # Batch size 1
    input_ids = input_ids.to(device)

    outputs = model(input_ids)
    # outputs[1] is the pooled [CLS] output (pooler_output);
    # the last hidden state is the first element, outputs[0]
    last_hidden_states = outputs[1]
    # detach() drops the autograd graph so it is not kept alive in memory
    return last_hidden_states[0].detach().to(device)

Answered By: edamame