# zsp / data.py
# Last commit ddc1bd3 by Massimo G. Totaro: "QOL and gradio upgrade"
# (source-viewer residue: raw | history | blame | 8.53 kB)
import dataframe_image as dfi
from math import ceil
import matplotlib.pyplot as plt
import pandas as pd
from re import match
import seaborn as sns
from model import Model
class Data:
    """Container for input and output data"""

    # Initialise empty model as static class member for efficiency: it is shared
    # by all instances and only replaced when a different model is requested.
    model = Model()
    # The 20 standard amino acids, in the order used for the TMS heatmap rows.
    AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"

    def parse_seq(self, src: str):
        """Parse and validate the input protein sequence into ``self.seq``.

        All whitespace (spaces, tabs, line breaks and carriage returns from
        pasted multi-line sequences) is removed and the result upper-cased.

        Raises:
            RuntimeError: if the sequence is empty or contains a character that
                is not in ``self.model.alphabet``.
        """
        self.seq = ''.join(src.split()).upper()
        if not self.seq:
            raise RuntimeError("Empty sequence")
        if not all(x in self.model.alphabet for x in self.seq):
            raise RuntimeError("Unrecognised characters in sequence")

    def _seq_index(self, resi: int) -> int:
        """Return the 0-based index of 1-based residue ``resi`` in ``self.seq``.

        Raises:
            RuntimeError: if ``resi`` is outside ``1..len(self.seq)`` (a plain
                ``resi - 1`` index would silently wrap 0 to the last residue).
        """
        if not 1 <= resi <= len(self.seq):
            raise RuntimeError(f"Residue {resi} is out of range 1-{len(self.seq)}")
        return resi - 1

    def _add_all_substitutions(self, resi: int, wt: str):
        """Append every substitution of wild-type ``wt`` at residue ``resi``."""
        # NOTE(review): if ``wt`` is not one of the 20 standard residues (e.g. 'X'),
        # 20 substitutions are produced, while the output reshaping assumes 19 per
        # residue — TODO confirm how non-standard residues should be handled.
        for mt in self.AMINO_ACIDS.replace(wt, ''):
            self.sub.append(f"{wt}{resi}{mt}")
            self.resi.append(resi)

    def parse_sub(self, trg: str):
        """Parse the substitution specification and select the running mode.

        ``self.mode`` becomes:

        * ``'MUT'`` for a single token as long as the sequence (a full mutant
          sequence, compared position by position with ``self.seq``) or for a
          list of ``X#Y`` substitutions (wild-type, 1-based residue, mutant);
        * ``'DMS'`` for a list of 1-based residue numbers, each expanded to the
          substitutions to all other standard amino acids;
        * ``'TMS'`` for anything else, including empty input: every residue of
          the sequence is expanded.

        Also sets ``self.trg`` (upper-cased, whitespace-split tokens),
        ``self.resi`` (residue number of each substitution) and ``self.sub``
        (single-column DataFrame, column ``'0'``, of substitution strings).

        Raises:
            RuntimeError: if a residue number is out of range, or an ``X#Y``
                wild-type residue does not match the sequence.
        """
        self.mode = None
        self.sub = list()
        self.trg = trg.strip().upper().split()
        self.resi = list()
        tokens = self.trg
        if len(tokens) == 1 and len(tokens[0]) == len(self.seq) and match(r'^\w+$', tokens[0]):
            # Single string of same length as sequence: seq vs seq mode
            self.mode = 'MUT'
            for resi, (wt, mt) in enumerate(zip(self.seq, tokens[0]), 1):
                if wt != mt:
                    self.sub.append(f"{wt}{resi}{mt}")
                    self.resi.append(resi)
        # ``\Z`` anchors the whole token; an unanchored match would accept
        # e.g. "12A" and crash later in int(). ``tokens and`` stops empty input
        # from vacuously counting as "all numbers".
        elif tokens and all(match(r'\d+\Z', x) for x in tokens):
            # All tokens are residue numbers: deep mutational scanning mode
            self.mode = 'DMS'
            for resi in map(int, tokens):
                self._add_all_substitutions(resi, self.seq[self._seq_index(resi)])
        elif tokens and all(match(r'[A-Z]\d+[A-Z]\Z', x) for x in tokens):
            # All tokens of the form X#Y: single substitution mode
            self.mode = 'MUT'
            self.sub = list(tokens)
            self.resi = [int(x[1:-1]) for x in tokens]
            for token, resi in zip(tokens, self.resi):
                actual = self.seq[self._seq_index(resi)]
                if actual != token[0]:
                    raise RuntimeError(f"Unrecognised input substitution {actual}{resi} /= {token[0]}{resi}")
        else:
            # Fallback: scan every residue of the sequence
            self.mode = 'TMS'
            for resi, wt in enumerate(self.seq, 1):
                self._add_all_substitutions(resi, wt)
        self.sub = pd.DataFrame(self.sub, columns=['0'])

    def __init__(self, src: str, trg: str, model_name: str = 'facebook/esm2_t33_650M_UR50D',
                 scoring_strategy: str = 'masked-marginals', out_file='out'):
        """Parse the inputs and prepare the output containers.

        Args:
            src: wild-type protein sequence.
            trg: substitution specification (see ``parse_sub``).
            model_name: model to score with. The shared model is reloaded only
                when this differs from the currently loaded one.
            scoring_strategy: stored on the instance; presumably read by
                ``Model.run_model`` — confirm against ``model.py``.
            out_file: path prefix of the ``.png`` and ``.csv`` outputs.

        Raises:
            RuntimeError: propagated from ``parse_seq`` / ``parse_sub``.
        """
        # Always set: used below as the score column name.
        self.model_name = model_name
        # Replace the shared model on the class, so later instances reuse the
        # loaded weights. Assigning on ``self`` only would force a reload per request.
        if Data.model.model_name != model_name:
            Data.model = Model(model_name)
        # Pin this instance to the model it was created with, so a later swap of
        # the shared model cannot change what this instance runs.
        self.model = Data.model
        self.parse_seq(src)
        self.offset = 0
        self.parse_sub(trg)
        self.scoring_strategy = scoring_strategy
        self.token_probs = None
        self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
        self.out_img = f'{out_file}.png'
        self.out_csv = f'{out_file}.csv'

    def parse_output(self) -> None:
        """Format the output data for visualisation and write the image and CSV.

        TMS: normalised heatmap image, plus a CSV with index and header.
        DMS / MUT: sorted table exported as a styled image, plus a CSV without
        index or header.

        Raises:
            RuntimeError: if ``self.mode`` is not ``'TMS'``, ``'DMS'`` or ``'MUT'``.
        """
        if self.mode == 'TMS':
            self.process_tms_mode()
            self.out.to_csv(self.out_csv, float_format='%.2f')
            return
        if self.mode == 'DMS':
            self.sort_by_residue_and_score()
        elif self.mode == 'MUT':
            self.sort_by_score()
        else:
            raise RuntimeError(f"Unrecognised mode {self.mode}")
        out_df = (self.out.style
                  .format(lambda x: f'{x:.2f}' if isinstance(x, float) else x)
                  .hide(axis=0).hide(axis=1)
                  .background_gradient(cmap="RdYlGn", vmax=8, vmin=-8))
        dfi.export(out_df, self.out_img, max_rows=-1, max_cols=-1, dpi=300)
        self.out.to_csv(self.out_csv, float_format='%.2f', index=False, header=False)

    def sort_by_score(self):
        """Sort substitutions by descending score."""
        self.out = self.out.sort_values(self.model_name, ascending=False)

    def sort_by_residue_and_score(self):
        """Lay DMS results out as one (substitution, score) column pair per residue.

        Rows are sorted by residue number, then by descending score. Each block
        of 19 rows becomes a column pair, and the resulting columns are
        renumbered ``0 .. 2 * n_blocks - 1``.
        """
        n_subs = len(self.AMINO_ACIDS) - 1
        ranked = (self.out.assign(resi=self.out['0'].str.extract(r'(\d+)', expand=False).astype(int))
                  .sort_values(['resi', self.model_name], ascending=[True, False])
                  .groupby(['resi'])
                  .head(n_subs)
                  .drop(['resi'], axis=1))
        n_blocks = ranked.shape[0] // n_subs
        self.out = (pd.concat([ranked.iloc[n_subs*x:n_subs*(x+1)].reset_index(drop=True)
                               for x in range(n_blocks)], axis=1)
                    .set_axis(range(n_blocks * 2), axis='columns'))

    def process_tms_mode(self):
        """Reshape TMS results into an amino-acid x residue matrix, normalise it and plot it."""
        self.out = self.assign_resi_and_group()
        self.out = self.concat_and_set_axis()
        peak = self.out.abs().max().max()
        # Scale into [-1, 1]. Skip when every score is 0, to avoid 0/0 -> NaN.
        if peak:
            self.out /= peak
        divs = self.calculate_divs()
        # Prefer the divisor closest to 60 columns per heatmap row.
        ncols = min(divs, key=lambda x: abs(x-60))
        nrows = ceil(self.out.shape[1]/ncols)
        ncols = self.adjust_ncols(ncols, nrows)
        self.plot_heatmap(ncols, nrows)

    def assign_resi_and_group(self):
        """Return ``self.out`` with a ``resi`` column, keeping at most 19 rows per residue."""
        return (self.out.assign(resi=self.out['0'].str.extract(r'(\d+)', expand=False).astype(int))
                .groupby(['resi'])
                .head(len(self.AMINO_ACIDS) - 1))

    def concat_and_set_axis(self):
        """Build the heatmap matrix: one row per amino acid, one column per residue.

        Each 19-row block gets its wild-type row (value 0) prepended, is sorted
        by substitution string (i.e. by mutant letter), and becomes a column
        labelled ``<wild-type><residue>``.
        """
        n_subs = len(self.AMINO_ACIDS) - 1
        return (pd.concat([(self.out.iloc[n_subs*x:n_subs*(x+1)]
                            .pipe(self.create_dataframe)
                            .sort_values(['0'], ascending=[True])
                            .drop(['resi', '0'], axis=1)
                            .set_axis(list(self.AMINO_ACIDS))
                            .astype(float)
                            ) for x in range(self.out.shape[0]//n_subs)]
                          , axis=1)
                .set_axis([f'{a}{i}' for i, a in enumerate(self.seq, 1)], axis='columns'))

    def create_dataframe(self, df):
        """Prepend a wild-type row (e.g. ``A12A``, all other values 0) to one residue block ``df``."""
        first = df.iloc[0, 0]
        wild_type = first[:-1] + first[0]
        row = pd.Series([wild_type] + [0] * (df.shape[1] - 1), index=df.columns).to_frame().T
        return pd.concat([row, df], axis=0, ignore_index=True)

    def calculate_divs(self):
        """Return divisors of the residue count within [30, 60], or ``[60]`` if none."""
        width = self.out.shape[1]
        return [x for x in range(30, min(width, 60) + 1) if width % x == 0] or [60]

    def adjust_ncols(self, ncols, nrows):
        """Return the number of columns per heatmap row when ``nrows`` rows are used.

        Shrinks ``ncols`` to the smallest width that still fits every residue in
        ``nrows`` rows, which balances the rows. It never goes below 46, or below
        ``ncols`` itself when that is already smaller. (The previous loop returned
        ``ncols + 1`` even when no shrinking happened, e.g. 61 + 59 columns for
        120 residues.)
        """
        total = self.out.shape[1]
        return max(ceil(total / nrows), min(ncols, 46))

    def plot_heatmap(self, ncols, nrows):
        """Draw the TMS heatmap(s) and save them to ``self.out_img``."""
        if nrows < 2:
            self.plot_single_heatmap()
        else:
            self.plot_multiple_heatmaps(ncols, nrows)
        plt.savefig(self.out_img, format='png', dpi=300)
        # pyplot keeps every figure alive until it is closed; without this,
        # each request leaks a figure in a long-running app.
        plt.close()

    def _draw_heatmap(self, data, ax):
        """Draw one heatmap of ``data`` on ``ax``, marking zero cells (the wild-type rows) with a dot."""
        sns.heatmap(data
                    , ax=ax
                    , cmap='RdBu'
                    , cbar=False
                    , square=True
                    , xticklabels=1
                    , yticklabels=1
                    , center=0
                    , annot=data.map(lambda x: ' ' if x != 0 else '·')
                    , fmt='s'
                    , annot_kws={'size': 'xx-large'})
        ax.tick_params(axis='y', labelrotation=0)
        ax.tick_params(axis='x', labelrotation=90)

    def plot_single_heatmap(self):
        """Plot the whole matrix as a single heatmap."""
        fig, ax = plt.subplots(figsize=(12, 6))
        self._draw_heatmap(self.out, ax)
        fig.tight_layout()

    def plot_multiple_heatmaps(self, ncols, nrows):
        """Plot the matrix as ``nrows`` stacked heatmaps of up to ``ncols`` residues each."""
        fig, axes = plt.subplots(nrows=nrows, figsize=(12, 6*nrows))
        for i, ax in enumerate(axes):
            self._draw_heatmap(self.out.iloc[:, i*ncols:(i+1)*ncols], ax)
        fig.tight_layout()

    def calculate(self):
        """Run the model on this instance, write the outputs and return ``self``."""
        self.model.run_model(self)
        self.parse_output()
        return self

    def csv(self):
        """Return the path of the output CSV file."""
        return self.out_csv

    def image(self):
        """Return the path of the output PNG image."""
        return self.out_img