How to convert a torch tensor into a byte string?

Question:

I’m trying to serialize a torch tensor using protobuf and it seems using BytesIO along with torch.save() doesn’t work. I have tried:

import torch 
import io
x = torch.randn(size=(1,20))
buff = io.BytesIO()
torch.save(x, buff)
print(f'buffer: {buff.read()}')

to no avail as it results in b'' in the output! How should I be going about this?

Asked By: Hossein


Answers:

You need to seek to the beginning of the buffer before reading:

import torch 
import io
x = torch.randn(size=(1,20))
buff = io.BytesIO()
torch.save(x, buff)
buff.seek(0)  # <--  this is what you were missing
print(f'buffer: {buff.read()}')

gives you this magnificent output:

buffer: b'PK\x03\x04\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10\x00\x12\x00archive/data.pklFB\x0e\x00ZZZZZZZZZZZZZZ\x80\x02ctorch._utils\n_rebuild_tensor_v2\nq\x00((X\x07\x00\x00\x00storageq\x01ctorch\nFloatStorage\nq\x02X\x0f\x00\x00\x00140417054790352q\x03X\x03\x00\x00\x00cpuq\x04K\x14tq\x05QK\x00K\x01K\x14\x86q\x06K\x14K\x01\x86q\x07\x89ccollections\nOrderedDict\nq\x08)Rq\ttq\nRq\x0b.PK\x07\x08...'
(output truncated; the rest is the remainder of the ZIP archive, including the raw tensor data)
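A minimal sketch of why the first attempt printed b'': torch.save leaves the stream position at the end of the data it wrote, so read() has nothing left to return until you rewind. It then round-trips the bytes through torch.load, which is what you would do after pulling them back out of a protobuf bytes field. The names payload and restored are illustrative only.

```python
import io

import torch

x = torch.randn(size=(1, 20))
buff = io.BytesIO()
torch.save(x, buff)

# torch.save leaves the position at the end of what it wrote,
# so read() at this point returns an empty byte string.
print(buff.tell() == len(buff.getbuffer()))  # True
print(buff.read())                           # b''

# Rewind, then read the whole serialized tensor.
buff.seek(0)
payload = buff.read()  # bytes, e.g. for a protobuf `bytes` field

# To deserialize, wrap the byte string in a new BytesIO and load it.
restored = torch.load(io.BytesIO(payload))
print(torch.equal(x, restored))              # True
```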
Answered By: Shai

Use the BytesIO.getvalue method.

Apart from seek-ing and read-ing, you can also use the getvalue method of the io.BytesIO object. It returns the entire contents of the buffer regardless of the current stream position, so no seek is needed:

import torch
import io
x = torch.randn(size=(1,20))
buff = io.BytesIO()
torch.save(x, buff)
print(f'buffer: {buff.getvalue()}')

buffer: b'PK\x03\x04\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10\x00\x12\x00archive/data.pklFB\x0e\x00ZZZZZZZZZZZZZZ\x80\x02ctorch._utils\n_rebuild_tensor_v2\nq\x00((X\x07\x00\x00\x00storageq\x01ctorch\nFloatStorage\nq\x02X\x01\x00\x00\x000q\x03X\x03\x00\x00\x00cpuq\x04K\x14tq\x05QK\x00K\x01K\x14\x86q\x06K\x14K\x01\x86q\x07\x89ccollections\nOrderedDict\nq\x08)Rq\ttq\nRq\x0b.PK\x07\x08...'
(output truncated; the rest is the remainder of the ZIP archive, including the raw tensor data)


getvalue works the same way for io.StringIO objects, except that it returns the stored str instead of bytes.
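A small standard-library-only sketch of the behaviour described above: read() starts from the current position, while getvalue() returns everything written so far and leaves the position where it was. io.StringIO is shown alongside for the str case.

```python
import io

buff = io.BytesIO()
buff.write(b"hello, ")
buff.write(b"tensor")

# The position is at the end after writing, so read() returns nothing...
print(buff.read())      # b''
# ...but getvalue() still returns the full contents.
print(buff.getvalue())  # b'hello, tensor'
# getvalue() did not move the position.
print(buff.tell())      # 13

# io.StringIO behaves the same way, but stores and returns str.
text = io.StringIO()
text.write("hello")
print(text.getvalue())  # hello
```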

Answered By: heemayl

Instead of using BytesIO directly, you could use pickle.dumps/pickle.loads.

According to this discuss thread and the linked PR discussion, PyTorch's custom pickling handler ultimately uses torch.save anyway, but it needs to serialize fewer objects, resulting in a 469-byte string versus 811 bytes for the BytesIO approaches (the exact sizes depend on the PyTorch version). I'm not sure which is faster, though.

import pickle, torch

x = torch.randn(size=(1,20))

pickled = pickle.dumps(x)

print(f"length of {len(pickled)} vs 811 for the BytesIO approaches")
# => length of 469 vs 811 for the BytesIO approaches

print(pickled)
# => b'\x80\x04\x95\xca\x01\x00\x00\x00\x00\x00\x00\x8c\x0ctorch._utils\x94\x8c\x12_rebuild_tensor_v2\x94\x93\x94(\x8c\rtorch.storage\x94\x8c\x10_load_from_bytes\x94\x93\x94B?\x01\x00\x00\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00protocol_versionq\x01M\xe9\x03X\r\x00\x00\x00little_endianq\x02\x88...' (output truncated)

print(pickle.loads(pickled))
# => tensor([[ 0.7662,  0.7422,  0.1888,  ..., -0.2035,  1.0845, -0.9637]])
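A sketch for checking both encodings on your own installation instead of relying on the fixed 469 vs 811 figures: it measures the sizes at runtime, confirms both byte strings restore the same tensor, and times each approach with timeit from the standard library. The helper names via_torch_save and via_pickle are hypothetical, and the timings it prints are only rough, machine-specific numbers.

```python
import io
import pickle
import timeit

import torch

x = torch.randn(size=(1, 20))

def via_torch_save(t):
    # Hypothetical helper: torch.save into an in-memory buffer.
    buff = io.BytesIO()
    torch.save(t, buff)
    return buff.getvalue()

def via_pickle(t):
    # Hypothetical helper: standard pickle, which calls the tensor's own reduce hook.
    return pickle.dumps(t)

saved = via_torch_save(x)
pickled = via_pickle(x)
print(f"torch.save: {len(saved)} bytes, pickle.dumps: {len(pickled)} bytes")

# Both byte strings should restore an identical tensor.
print(torch.equal(torch.load(io.BytesIO(saved)), x))
print(torch.equal(pickle.loads(pickled), x))

# Rough timing; results vary by machine and PyTorch version.
for name, fn in [("torch.save", via_torch_save), ("pickle.dumps", via_pickle)]:
    seconds = timeit.timeit(lambda: fn(x), number=1000)
    print(f"{name}: {seconds:.3f} s per 1000 calls")
```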
Answered By: micimize