|
|
|
|
|
import argparse, connexion, os, sys, yaml, json, socket |
|
from netdissect.easydict import EasyDict |
|
from flask import send_from_directory, redirect |
|
from flask_cors import CORS |
|
|
|
|
|
from netdissect.serverstate import DissectionProject |
|
|
|
__author__ = 'Hendrik Strobelt, David Bau' |
|
|
|
CONFIG_FILE_NAME = 'dissect.json' |
|
projects = {} |
|
|
|
app = connexion.App(__name__, debug=False) |
|
|
|
|
|
def get_all_projects(): |
|
res = [] |
|
for key, project in projects.items(): |
|
|
|
res.append({ |
|
'project': key, |
|
'info': { |
|
'layers': [layer['layer'] for layer in project.get_layers()] |
|
} |
|
}) |
|
return sorted(res, key=lambda x: x['project']) |
|
|
|
def get_layers(project): |
|
return { |
|
'request': {'project': project}, |
|
'res': projects[project].get_layers() |
|
} |
|
|
|
def get_units(project, layer): |
|
return { |
|
'request': {'project': project, 'layer': layer}, |
|
'res': projects[project].get_units(layer) |
|
} |
|
|
|
def get_rankings(project, layer): |
|
return { |
|
'request': {'project': project, 'layer': layer}, |
|
'res': projects[project].get_rankings(layer) |
|
} |
|
|
|
def get_levels(project, layer, quantiles): |
|
return { |
|
'request': {'project': project, 'layer': layer, 'quantiles': quantiles}, |
|
'res': projects[project].get_levels(layer, quantiles) |
|
} |
|
|
|
def get_channels(project, layer): |
|
answer = dict(channels=projects[project].get_channels(layer)) |
|
return { |
|
'request': {'project': project, 'layer': layer}, |
|
'res': answer |
|
} |
|
|
|
def post_generate(gen_req): |
|
project = gen_req['project'] |
|
zs = gen_req.get('zs', None) |
|
ids = gen_req.get('ids', None) |
|
return_urls = gen_req.get('return_urls', False) |
|
assert (zs is None) != (ids is None) |
|
ablations = gen_req.get('ablations', []) |
|
interventions = gen_req.get('interventions', None) |
|
|
|
generated = projects[project].generate_images(zs, ids, interventions, |
|
return_urls=return_urls) |
|
return { |
|
'request': gen_req, |
|
'res': generated |
|
} |
|
|
|
def post_features(feat_req): |
|
project = feat_req['project'] |
|
ids = feat_req['ids'] |
|
masks = feat_req.get('masks', None) |
|
layers = feat_req.get('layers', None) |
|
interventions = feat_req.get('interventions', None) |
|
features = projects[project].get_features( |
|
ids, masks, layers, interventions) |
|
return { |
|
'request': feat_req, |
|
'res': features |
|
} |
|
|
|
def post_featuremaps(feat_req): |
|
project = feat_req['project'] |
|
ids = feat_req['ids'] |
|
layers = feat_req.get('layers', None) |
|
interventions = feat_req.get('interventions', None) |
|
featuremaps = projects[project].get_featuremaps( |
|
ids, layers, interventions) |
|
return { |
|
'request': feat_req, |
|
'res': featuremaps |
|
} |
|
|
|
@app.route('/client/<path:path>') |
|
def send_static(path): |
|
""" serves all files from ./client/ to ``/client/<path:path>`` |
|
|
|
:param path: path from api call |
|
""" |
|
return send_from_directory(args.client, path) |
|
|
|
@app.route('/data/<path:path>') |
|
def send_data(path): |
|
""" serves all files from the data dir to ``/dissect/<path:path>`` |
|
|
|
:param path: path from api call |
|
""" |
|
print('Got the data route for', path) |
|
return send_from_directory(args.data, path) |
|
|
|
|
|
@app.route('/') |
|
def redirect_home(): |
|
return redirect('/client/index.html', code=302) |
|
|
|
|
|
def load_projects(directory): |
|
""" |
|
searches for CONFIG_FILE_NAME in all subdirectories of directory |
|
and creates data handlers for all of them |
|
|
|
:param directory: scan directory |
|
:return: null |
|
""" |
|
project_dirs = [] |
|
|
|
search_depth = 2 + directory.count(os.path.sep) |
|
for root, dirs, files in os.walk(directory): |
|
if CONFIG_FILE_NAME in files: |
|
project_dirs.append(root) |
|
|
|
del dirs[:] |
|
elif root.count(os.path.sep) >= search_depth: |
|
del dirs[:] |
|
for p_dir in project_dirs: |
|
print('Loading %s' % os.path.join(p_dir, CONFIG_FILE_NAME)) |
|
with open(os.path.join(p_dir, CONFIG_FILE_NAME), 'r') as jf: |
|
config = EasyDict(json.load(jf)) |
|
dh_id = os.path.split(p_dir)[1] |
|
projects[dh_id] = DissectionProject( |
|
config=config, |
|
project_dir=p_dir, |
|
path_url='data/' + os.path.relpath(p_dir, directory), |
|
public_host=args.public_host) |
|
|
|
app.add_api('server.yaml') |
|
|
|
|
|
CORS(app.app, headers='Content-Type') |
|
|
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--nodebug", default=False) |
|
parser.add_argument("--address", default="127.0.0.1") |
|
parser.add_argument("--port", default="5001") |
|
parser.add_argument("--public_host", default=None) |
|
parser.add_argument("--nocache", default=False) |
|
parser.add_argument("--data", type=str, default='dissect') |
|
parser.add_argument("--client", type=str, default='client_dist') |
|
|
|
if __name__ == '__main__': |
|
args = parser.parse_args() |
|
for d in [args.data, args.client]: |
|
if not os.path.isdir(d): |
|
print('No directory %s' % d) |
|
sys.exit(1) |
|
args.data = os.path.abspath(args.data) |
|
args.client = os.path.abspath(args.client) |
|
if args.public_host is None: |
|
args.public_host = '%s:%d' % (socket.getfqdn(), int(args.port)) |
|
app.run(port=int(args.port), debug=not args.nodebug, host=args.address, |
|
use_reloader=False) |
|
else: |
|
args, _ = parser.parse_known_args() |
|
if args.public_host is None: |
|
args.public_host = '%s:%d' % (socket.getfqdn(), int(args.port)) |
|
load_projects(args.data) |
|
|