guydav commited on
Commit
3f6f50a
·
1 Parent(s): 52fc1db

Minor changes to AllowListImporter

Browse files
Files changed (1) hide show
  1. restrictedpython_code_eval.py +11 -14
restrictedpython_code_eval.py CHANGED
@@ -401,35 +401,32 @@ def _check_correctness(check_program, timeout, task_id, completion_id,
401
  return out_dict
402
 
403
 
404
- class LimitedSysModule:
405
- def __init__(self):
406
- self.maxsize = sys.maxsize
407
 
408
 
409
  class AllowListImporter:
410
  def __init__(self, allowed_imports: List[str]):
411
  self.allowed_imports = allowed_imports
412
-
 
 
 
 
 
413
  def __call__(self, name, globals=None, locals=None, fromlist=(), level=0):
414
  if name.startswith('.'):
415
  raise ImportError("Relative imports are not allowed.")
416
 
417
  if '.' in name:
418
- package_name, _ = name.split('.', 1)
419
 
420
  else:
421
  package_name = name
 
422
 
423
  if package_name == 'sys':
424
- limited_sys = LimitedSysModule()
425
- if name is None:
426
- return limited_sys
427
-
428
- if hasattr(limited_sys, name):
429
- return getattr(limited_sys, name)
430
-
431
- raise ImportError(f"Cannot import {name} from limited sys implementation.")
432
-
433
  if package_name in self.allowed_imports:
434
  return importlib.__import__(name, globals, locals, fromlist, level)
435
 
 
401
  return out_dict
402
 
403
 
404
+ ALLOWED_SYS_NAMES = ['maxsize']
 
 
405
 
406
 
407
  class AllowListImporter:
408
  def __init__(self, allowed_imports: List[str]):
409
  self.allowed_imports = allowed_imports
410
+ inner_sys = importlib.__import__('sys')
411
+ for key in list(dir(inner_sys)):
412
+ if key not in ALLOWED_SYS_NAMES:
413
+ delattr(inner_sys, key)
414
+ self.inner_sys = inner_sys
415
+
416
  def __call__(self, name, globals=None, locals=None, fromlist=(), level=0):
417
  if name.startswith('.'):
418
  raise ImportError("Relative imports are not allowed.")
419
 
420
  if '.' in name:
421
+ package_name, sub_name = name.split('.', 1)
422
 
423
  else:
424
  package_name = name
425
+ sub_name = None
426
 
427
  if package_name == 'sys':
428
+ return self.inner_sys
429
+
 
 
 
 
 
 
 
430
  if package_name in self.allowed_imports:
431
  return importlib.__import__(name, globals, locals, fromlist, level)
432