Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Simplistic RPC implementation. | |
| Exposes all functions of a Server object. | |
| This code is for demonstration purposes only, and does not include certain | |
| security protections. It is not meant to be run on an untrusted network or | |
| in a production environment. | |
| """ | |
| import importlib | |
| import os | |
| import pickle | |
| import sys | |
| import _thread | |
| import traceback | |
| import socket | |
| import logging | |
LOG = logging.getLogger(__name__)

# default TCP port used by Client and run_server
PORT = 12032

# Modules whose globals RestrictedUnpickler is allowed to resolve.
# Arrays are rebuilt through ``<numpy core>.multiarray._reconstruct``; NumPy 1.x
# pickles reference ``numpy.core.multiarray`` while NumPy 2.x pickles reference
# ``numpy._core.multiarray``, so both are listed to accept either peer.
safe_modules = {
    'numpy',
    'numpy.core.multiarray',
    'numpy._core.multiarray',
}
class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that resolves globals only from whitelisted modules.

    Plain data (numbers, strings, bytes, tuples, lists, dicts) never goes
    through ``find_class`` and loads normally; any other global is looked up
    only when its module is listed in ``safe_modules``.

    NOTE(security): inside an allowed module *every* attribute can be
    resolved (and then invoked by the pickle stream), so this is a coarse
    filter, consistent with the module's "not for untrusted networks" caveat.
    """

    def find_class(self, module, name):
        if module not in safe_modules:
            # Anything outside the whitelist is rejected outright.
            raise pickle.UnpicklingError(
                "global '%s.%s' is forbidden" % (module, name))
        owner = importlib.import_module(module)
        return getattr(owner, name)
class FileSock:
    """Wrap a socket so that it is usable by pickle as a file object.

    Provides the ``write``, ``read`` and ``readline`` methods that
    ``pickle.dump`` and ``pickle.Unpickler`` call on their file argument.
    """

    def __init__(self, sock):
        self.sock = sock
        # number of read() calls made so far (debugging counter)
        self.nr = 0

    def write(self, buf):
        """Send all of ``buf``, in chunks of at most 512 KiB."""
        bs = 512 * 1024
        # memoryview slices hand each chunk to send() without copying it.
        view = memoryview(buf)
        ns = 0
        while ns < len(view):
            # send() may transmit fewer bytes than offered; advance by
            # what was actually sent.
            ns += self.sock.send(view[ns:ns + bs])

    def read(self, bs=512 * 1024):
        """Read up to ``bs`` bytes.

        Blocks until ``bs`` bytes have arrived; returns fewer only when the
        peer closes the connection (``b''`` if nothing was left).
        """
        self.nr += 1
        chunks = []
        nb = 0
        # Compare the number of bytes received (not the number of chunks)
        # against the requested size.
        while nb < bs:
            rb = self.sock.recv(bs - nb)
            if not rb:
                # peer closed the connection
                break
            chunks.append(rb)
            nb += len(rb)
        return b''.join(chunks)

    def readline(self):
        """Read one line, including its trailing newline, one byte at a time.

        Returns what was read so far (possibly ``b''``) if the stream ends
        before a newline.
        """
        line = bytearray()
        while True:
            c = self.read(1)
            line += c
            if not c or c == b'\n':
                return bytes(line)
class ClientExit(Exception):
    """Signals that the client connection ended.

    Raised by ``Server.one_function`` when the stream hits end-of-file while
    reading a request or sending a reply; ``Server.exec_loop`` catches it
    and stops serving.
    """
class ServerException(Exception):
    """Raised on the client side when a remote method failed.

    ``Client.get_result`` raises it with the error text sent back by the
    server (traceback followed by the exception message).
    """
class Server:
    """
    server protocol. Methods from classes that subclass Server can be called
    transparently from a client

    Each instance serves one connected socket: ``exec_loop`` repeatedly reads
    a pickled ``(fname, args)`` request, calls ``self.<fname>(*args)`` and
    pickles back ``(st, ret)``.
    """

    def __init__(self, s, logf=sys.stderr, log_prefix=''):
        """
        s: connected socket to serve requests on.
        logf: file-like object receiving log lines and tracebacks.
        log_prefix: tag included in every ``log`` line.
        """
        self.logf = logf
        self.log_prefix = log_prefix
        # connection
        self.conn = s
        self.fs = FileSock(s)

    def log(self, s):
        """Write ``s`` as one tagged line to ``self.logf``."""
        self.logf.write("Sever log %s: %s\n" % (self.log_prefix, s))

    def one_function(self):
        """
        Executes a single function with associated I/O.

        Protocol:
        - the arguments and results are serialized with the pickle protocol
        - client sends : (fname,args)
            fname = method name to call
            args = tuple of arguments
        - server sends result: (st,ret)
            st = None, or a string (traceback + message) describing the error
            ret = return value or None if st!=None

        Raises ClientExit when the request cannot be read because the stream
        ended.
        """
        try:
            (fname, args) = RestrictedUnpickler(self.fs).load()
        except EOFError:
            raise ClientExit("read args")
        self.log("executing method %s" % (fname,))
        st = None
        ret = None
        # NOTE(security): ``fname`` comes from the client and is resolved with
        # getattr, so any attribute of this object (private ones included) is
        # reachable. Only acceptable on a trusted network (see module doc).
        try:
            f = getattr(self, fname)
        except (AttributeError, TypeError):
            # TypeError: fname is not a string.
            # Errors travel as plain strings because the client's
            # RestrictedUnpickler cannot rebuild exception objects.
            st = "unknown method %s" % (fname,)
            self.log("unknown method")
        else:
            try:
                ret = f(*args)
            except Exception as e:
                # due to a bug (in mod_python?), ServerException cannot be
                # unpickled, so send the string and make the exception on
                # the client side
                st = "".join(traceback.format_tb(sys.exc_info()[2])) + str(e)
                self.log("exception in method")
                traceback.print_exc(50, self.logf)
                self.logf.flush()
        LOG.info("return")
        try:
            pickle.dump((st, ret), self.fs, protocol=4)
        except EOFError:
            raise ClientExit("function return")

    def exec_loop(self):
        """ main execution loop. Loops and handles exit states

        Serves requests until the client disconnects or communication fails,
        then closes the connection socket.
        """
        self.log("in exec_loop")
        try:
            while True:
                self.one_function()
        except ClientExit as e:
            self.log("ClientExit %s" % e)
        except socket.error as e:
            self.log("socket error %s" % e)
            traceback.print_exc(50, self.logf)
        except EOFError:
            self.log("EOF during communication")
            traceback.print_exc(50, self.logf)
        except BaseException:
            # unexpected. NOTE: in a thread started with
            # _thread.start_new_thread (as run_server does), SystemExit is
            # silently ignored, so this only ends the serving thread.
            traceback.print_exc(50, sys.stderr)
            sys.exit(1)
        finally:
            # Release the connection's file descriptor deterministically
            # instead of leaving it to garbage collection.
            self.conn.close()
        LOG.info("exit sever")

    def exec_loop_cleanup(self):
        """No-op hook for subclasses; nothing in this class calls it."""
        pass

    ###################################################################
    # spying stuff

    def get_ps_stats(self):
        """Return the text output of ``uptime`` and ``ps`` for this process
        and the run queue (runs a shell command)."""
        cmd = ("echo ============ `hostname` uptime:; uptime;" +
               "echo ============ self:; " +
               "ps -p %d -o pid,vsize,rss,%%cpu,nlwp,psr; " % os.getpid() +
               "echo ============ run queue:;" +
               "ps ar -o user,pid,%cpu,%mem,ni,nlwp,psr,vsz,rss,cputime,command")
        # The context manager closes the pipe and reaps the child process.
        with os.popen(cmd) as f:
            return f.read()
class Client:
    """
    Methods of the server object can be called transparently. Exceptions are
    re-raised.

    Accessing ``client.name`` for a name that is not a real attribute returns
    a function that calls ``name`` on the server with the given positional
    arguments; a failure on the server is raised as ServerException carrying
    the server-side error text.
    """

    def __init__(self, HOST, port=PORT, v6=False):
        """Connect to ``HOST:port`` over IPv4 (or IPv6 when ``v6``)."""
        socktype = socket.AF_INET6 if v6 else socket.AF_INET
        sock = socket.socket(socktype, socket.SOCK_STREAM)
        LOG.info("connecting to %s:%d, socket type: %s", HOST, port, socktype)
        try:
            sock.connect((HOST, port))
        except BaseException:
            # don't leak the descriptor when the connection cannot be made
            sock.close()
            raise
        self.sock = sock
        self.fs = FileSock(sock)

    def generic_fun(self, fname, args):
        """Send ``(fname, args)`` to the server and return its result."""
        pickle.dump((fname, args), self.fs, protocol=4)
        return self.get_result()

    def get_result(self):
        """Read one ``(st, ret)`` reply; raise ServerException if st is set."""
        (st, ret) = RestrictedUnpickler(self.fs).load()
        if st is not None:
            raise ServerException(st)
        return ret

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. Dunder names are
        # protocol probes (e.g. copy.deepcopy looks up ``__deepcopy__`` on the
        # instance), not remote calls; forwarding them would send unexpected
        # requests over the connection.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return lambda *x: self.generic_fun(name, x)
def run_server(new_handler, port=PORT, report_to_file=None, v6=False):
    """Listen on ``port`` and serve every accepted connection in a new thread.

    new_handler: callable taking the accepted connection socket and returning
        an object with an ``exec_loop()`` method (e.g. a Server subclass).
    port: TCP port to listen on.
    report_to_file: if not None, "hostname:port " is written to this file once
        the socket is listening.
    v6: use an IPv6 socket instead of IPv4.

    Loops forever; if an exception escapes the accept loop, the listening
    socket is closed before it propagates.
    """
    HOST = ''  # empty string: bind on all interfaces (INADDR_ANY)
    socktype = socket.AF_INET6 if v6 else socket.AF_INET
    s = socket.socket(socktype, socket.SOCK_STREAM)
    try:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        LOG.info("bind %s:%d", HOST, port)
        s.bind((HOST, port))
        s.listen(5)
        LOG.info("accepting connections")
        if report_to_file is not None:
            LOG.info('storing host+port in %s', report_to_file)
            with open(report_to_file, 'w') as f:
                f.write('%s:%d ' % (socket.gethostname(), port))
        while True:
            try:
                conn, addr = s.accept()
            except InterruptedError:
                # EINTR: retry (Python >= 3.5 normally retries this itself).
                # Other socket errors propagate.
                continue
            LOG.info('Connected to %s', addr)
            ibs = new_handler(conn)
            tid = _thread.start_new_thread(ibs.exec_loop, ())
            LOG.debug("Thread ID: %d", tid)
    finally:
        s.close()