Why is unpickling a custom class with a restricted Unpickler forbidden although it is allowed in find_class?
Question:
I’m required to run some code repeatedly to train a model. I found that using pickle for saving my object after one iteration of the code was useful, and I could load it and use it in my second iteration.
But as pickle has the security issue, I wanted to use the restricted_loads option. However I can’t seem to get it working for custom classes. Here’s a smaller block of code where I get the same error:
import builtins
import io
import os
import pickle
safe_builtins = {
'range',
'complex',
'set',
'frozenset',
'slice',
}
allow_classes = {
'__main__.Shape'
}
class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
# Only allow safe classes from builtins.
if module == "builtins" and name in safe_builtins | allow_classes:
return getattr(builtins, name)
# Forbid everything else.
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
(module, name))
def restricted_loads(s):
"""Helper function analogous to pickle.loads()."""
return RestrictedUnpickler(io.BytesIO(s)).load()
class Person:
def __init__(
self,
name: str,
age: int,
):
self.name = name
self.age = age
class Shape:
def __init__(
self,
name: Person,
n: int = 50,
):
self.person = Person(
name = name,
age = "10",
)
self.n = n
s = Shape(
name = "name1",
n = 30,
)
filepath = os.path.join(os.getcwd(), "temp.pkl")
with open(filepath, 'wb') as outp:
pickle.dump(s, outp, -1)
with open(filepath, 'rb') as inp:
x = restricted_loads(inp.read())
Error:
UnpicklingError Traceback (most recent call last)
Cell In[20], line 63
60 pickle.dump(s, outp, -1)
62 with open(filepath, 'rb') as inp:
---> 63 x = restricted_loads(inp.read())
Cell In[20], line 30, in restricted_loads(s)
28 def restricted_loads(s):
29 """Helper function analogous to pickle.loads()."""
---> 30 return RestrictedUnpickler(io.BytesIO(s)).load()
Cell In[20], line 25, in RestrictedUnpickler.find_class(self, module, name)
23 return getattr(builtins, name)
24 # Forbid everything else.
---> 25 raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
26 (module, name))
UnpicklingError: global '__main__.Shape' is forbidden
Answers:
You have only allowed classes which come from the module builtins
.
But __main__.Shape
is a class with the name Shape
in the module __main__
, not a class with the name __main__.Shape
in the module builtins
.
So an obvious fix would be to change
if module == "builtins" and name in safe_builtins | allow_classes:
return getattr(builtins, name)
to
if module == "builtins" and name in safe_builtins:
return getattr(builtins, name)
elif module == "__main__" and name == "Shape":
return Shape
I’m required to run some code repeatedly to train a model. I found that using pickle for saving my object after one iteration of the code was useful, and I could load it and use it in my second iteration.
But as pickle has the security issue, I wanted to use the restricted_loads option. However I can’t seem to get it working for custom classes. Here’s a smaller block of code where I get the same error:
import builtins
import io
import os
import pickle
safe_builtins = {
'range',
'complex',
'set',
'frozenset',
'slice',
}
allow_classes = {
'__main__.Shape'
}
class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
# Only allow safe classes from builtins.
if module == "builtins" and name in safe_builtins | allow_classes:
return getattr(builtins, name)
# Forbid everything else.
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
(module, name))
def restricted_loads(s):
"""Helper function analogous to pickle.loads()."""
return RestrictedUnpickler(io.BytesIO(s)).load()
class Person:
def __init__(
self,
name: str,
age: int,
):
self.name = name
self.age = age
class Shape:
def __init__(
self,
name: Person,
n: int = 50,
):
self.person = Person(
name = name,
age = "10",
)
self.n = n
s = Shape(
name = "name1",
n = 30,
)
filepath = os.path.join(os.getcwd(), "temp.pkl")
with open(filepath, 'wb') as outp:
pickle.dump(s, outp, -1)
with open(filepath, 'rb') as inp:
x = restricted_loads(inp.read())
Error:
UnpicklingError Traceback (most recent call last)
Cell In[20], line 63
60 pickle.dump(s, outp, -1)
62 with open(filepath, 'rb') as inp:
---> 63 x = restricted_loads(inp.read())
Cell In[20], line 30, in restricted_loads(s)
28 def restricted_loads(s):
29 """Helper function analogous to pickle.loads()."""
---> 30 return RestrictedUnpickler(io.BytesIO(s)).load()
Cell In[20], line 25, in RestrictedUnpickler.find_class(self, module, name)
23 return getattr(builtins, name)
24 # Forbid everything else.
---> 25 raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
26 (module, name))
UnpicklingError: global '__main__.Shape' is forbidden
You have only allowed classes which come from the module builtins
.
But __main__.Shape
is a class with the name Shape
in the module __main__
, not a class with the name __main__.Shape
in the module builtins
.
So an obvious fix would be to change
if module == "builtins" and name in safe_builtins | allow_classes:
return getattr(builtins, name)
to
if module == "builtins" and name in safe_builtins:
return getattr(builtins, name)
elif module == "__main__" and name == "Shape":
return Shape