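# Hugging Face file-viewer header captured with this source (not part of the program):
# uploaded by Tonic via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 7873945 (verified), 13.4 kB.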
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.transforms as transforms
from matplotlib.transforms import Affine2D
import math
from matplotlib.patheffects import RendererBase
import matplotlib.patheffects as PathEffects
import seaborn as sns
# Calculate the font size whose glyphs span one grid unit along the x or y axis.
def CalculateFontsize(xlim, ylim, ax, fig, rows, cols, unit_scale=1):
    """Return a font size that makes one character fill one cell of a grid.

    The region ``xlim`` x ``ylim`` (in ``ax`` data units) is divided into
    ``rows`` x ``cols`` cells. The result is the on-screen cell width, or 0.8
    times the cell height, whichever is smaller, multiplied by ``unit_scale``.

    Args:
        xlim: (low, high) horizontal extent of the gridded region, data units.
        ylim: (low, high) vertical extent of the gridded region, data units.
        ax: Axes the grid is drawn in; its current limits and window extent
            define the data-to-pixel scale.
        fig: unused; kept so existing callers keep working.
        rows: number of grid cells stacked vertically.
        cols: number of grid cells side by side.
        unit_scale: extra multiplier applied to the result.

    Returns:
        The size as a float. NOTE(review): the value is measured in display
        pixels, while matplotlib interprets ``fontsize`` in points
        (pixels = points * dpi / 72); the callers appear to be tuned around
        this — confirm before changing units.
    """
    axXlim = ax.get_xlim()
    axYlim = ax.get_ylim()
    # Axes size in display pixels. (Converting to inches and multiplying by
    # the figure dpi again, as this code used to, is the identity.)
    bbox = ax.get_window_extent()
    ax_width, ax_height = bbox.width, bbox.height
    fontsize_x = unit_scale * ax_width / (axXlim[1] - axXlim[0]) / cols * (xlim[1] - xlim[0])
    fontsize_y = 0.8 * unit_scale * ax_height / (axYlim[1] - axYlim[0]) / rows * (ylim[1] - ylim[0])
    # The smaller of the two keeps a character inside its cell in both directions.
    return min(fontsize_x, fontsize_y)
def GetStartEnd(msa, start, end):
    """Resolve an inclusive column range [start, end] for ``msa``.

    ``None`` means the first column (for ``start``) or the last column (for
    ``end``). Negative values count back from the alignment length, as in
    Python indexing. Values are not clamped, so ``end`` may lie beyond the
    last column; callers handle that.

    Args:
        msa: sequence of aligned sequences; the length of ``msa[0]`` is used.
        start: first column (0-based) or None.
        end: last column (0-based, inclusive) or None.

    Returns:
        (start, end) as 0-based inclusive column indices.
    """
    length = len(msa[0])
    if start is None:
        start = 0
    elif start < 0:
        start += length
    if end is None:
        end = length - 1
    elif end < 0:
        end += length
    return start, end
def GetColorMap(preset = None, msa=None, color_order = None, palette = None):
    """Build a symbol -> RGB color mapping for drawing an alignment.

    Args:
        preset: "dna", "nuc" or "nucleotide" selects a fixed A/C/G/T scheme
            (``msa``, ``color_order`` and ``palette`` are then ignored). Any
            other value assigns colors from ``palette``.
        msa: iterable of sequences scanned for symbols needing a color; may be
            None.
        color_order: optional iterable of symbols that receive the first
            palette colors, in this order. Symbols found in ``msa`` but not
            listed get the following colors in order of first appearance.
        palette: anything accepted by ``seaborn.color_palette`` (None = the
            current matplotlib color cycle).

    Returns:
        dict mapping each symbol to an RGB color. The gap symbols '-' and '.'
        are always present and always white.
    """
    color_map = {}
    if preset in ["dna", "nuc", "nucleotide"]:
        color_map['A'] = [0, 1, 0]
        color_map['C'] = [1, 165/255, 0]
        color_map['G'] = [1, 0, 0]
        color_map['T'] = [0.5, 0.5, 1]
    else:
        # Palette slot per symbol, assigned exactly as before: color_order
        # position first, then first-appearance order in msa. Gap symbols also
        # take a slot here (and are overridden with white below), which keeps
        # the colors of all other symbols unchanged.
        slot = {}
        if color_order is not None:
            for i, c in enumerate(color_order):
                slot[c] = i
        if msa is not None:
            for seq in msa:
                for c in seq:
                    if c not in slot:
                        slot[c] = len(slot)
        if slot:
            needed = max(slot.values()) + 1
            colors = sns.color_palette(palette)
            if needed > len(colors):
                # The default palette has only about 10 colors, while protein
                # alignments have 20+ symbols; indexing past the end used to
                # raise IndexError. Ask seaborn for enough colors instead.
                colors = sns.color_palette(palette, n_colors=needed)
            for c, i in slot.items():
                color_map[c] = colors[i]
    color_map['-'] = [1, 1, 1]  # white
    color_map['.'] = [1, 1, 1]  # white
    return color_map
def _GetAdaptiveXTicks(start, end, length):
    """Return (ticks, labels) for the x axis of an MSA panel covering [start, end].

    Ticks are panel-relative offsets (column - start) and labels are 1-based
    alignment columns. The step is the largest of 1, 5, 10, 20, 50, 100, 500,
    1000, 5000 and 10000 for which five steps still fit in the region (1 if
    none does). Ticks go on columns whose 1-based number is a multiple of the
    step, plus the first column and the last drawn column. Columns at or
    beyond ``length`` are not ticked.
    """
    width = end - start + 1
    candidateSteps = [1, 5, 10, 20, 50, 100, 500, 1000, 5000, 10000]
    fitting = [s for s in candidateSteps if s * 5 <= width]
    # Previously a region of >= 50000 columns fell back to step 1, i.e. one
    # tick per column; it now uses the largest candidate.
    step = fitting[-1] if fitting else 1
    lastCol = min(length - 1, end)
    ticks = []
    tickLabels = []
    # First column at or after `start` whose 1-based number is a multiple of step.
    tickStart = (start // step) * step + step - 1
    if tickStart != start:
        ticks.append(0)
        tickLabels.append(start + 1)
    for col in range(tickStart, lastCol + 1, step):
        ticks.append(col - start)
        tickLabels.append(col + 1)
    if not tickLabels or tickLabels[-1] != lastCol + 1:
        ticks.append(lastCol - start)
        tickLabels.append(lastCol + 1)
    return ticks, tickLabels


def DrawMSA(msa, seq_names = None, start = None, end = None,
        axlim = None, color_map = None, palette=None, ax=None, fig=None,
        show_char=True):
    """Draw columns [start, end] of an alignment as a grid of colored cells.

    Each residue becomes a Rectangle filled with ``color_map[residue]``.
    Optionally the residue character is drawn in black with a white stroke.
    Sequence ``i`` is drawn in row ``i`` counted from the bottom.

    Args:
        msa: sequence of equally long aligned sequences (e.g. list of str).
        seq_names: y tick labels, one per sequence; row indices when None.
        start, end: inclusive column range (see GetStartEnd).
        axlim: None to draw in axes-fraction coordinates filling ``ax``, or
            ((x0, x1), (y0, y1)) to draw in data coordinates. NOTE(review):
            only the spans x1 - x0 and y1 - y0 are used; cells always start at
            data coordinate 0.
        color_map: symbol -> color; built with GetColorMap(msa, palette) when
            not given.
        palette: seaborn palette used when ``color_map`` is not given.
        ax, fig: target Axes/Figure (defaults: current Axes and its Figure).
        show_char: draw the residue characters on top of the cells.

    Returns:
        (ax, color_map)
    """
    ax = ax or plt.gca()
    fig = fig or ax.get_figure()
    height = len(msa)
    length = len(msa[0])
    start, end = GetStartEnd(msa, start, end)
    width = end - start + 1
    if axlim is None:
        fontsize = CalculateFontsize(ax.get_xlim(), ax.get_ylim(), ax, fig, height, width)
        lengthUnit = 1 / width
        heightUnit = 1 / height
        transform = ax.transAxes
    else:
        fontsize = CalculateFontsize(axlim[0], axlim[1], ax, fig, height, width)
        lengthUnit = (axlim[0][1] - axlim[0][0]) / width
        heightUnit = (axlim[1][1] - axlim[1][0]) / height
        transform = ax.transData
    color_map = color_map or GetColorMap(msa=msa, color_order=None, palette=palette)

    # Cell borders and the character outline scale with the font, capped at 2.
    linewidth = min(2, fontsize / 50)
    for i, seq in enumerate(msa):
        for j, c in enumerate(seq[start:end + 1]):
            if show_char:
                text = ax.text(x=(j + 0.5) * lengthUnit, y=(i + 0.5) * heightUnit, s=c,
                               color="black", va="center_baseline", ha="center",
                               fontsize=fontsize, transform=transform)
                text.set_path_effects([PathEffects.withStroke(linewidth=linewidth,
                                                              foreground='w')])
            ax.add_patch(patches.Rectangle(xy=(j * lengthUnit, i * heightUnit),
                                           width=lengthUnit, height=heightUnit,
                                           facecolor=color_map[c], linewidth=linewidth,
                                           edgecolor="white", transform=transform))

    if axlim is None:
        # Data coordinates are used only for the ticks: column k of the panel
        # is centred on data x == k, sequence i on data y == i.
        ax.set_xlim(-0.5, width - 0.5)
        ax.set_ylim(-0.5, height - 0.5)

    ticks, tickLabels = _GetAdaptiveXTicks(start, end, length)
    ax.set_xticks(ticks, tickLabels)

    # `is None` (not `== None`) so numpy arrays / pandas Series work as names.
    rowTicks = list(range(height))
    if seq_names is None:
        rowLabels = list(rowTicks)
    else:
        rowLabels = [seq_names[i] for i in rowTicks]
    ax.set_yticks(rowTicks, rowLabels)

    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(False)
    return ax, color_map
def GetConsensus(msa, start = None, end = None):
    """Count symbols per column and pick the most frequent one.

    For each column in the inclusive range [start, end] (see GetStartEnd),
    this builds a dict symbol -> count in order of first appearance. The
    consensus is the symbol with the highest count; on ties, the one seen
    first wins. Columns past the end of the alignment yield '' and {"": 0}.

    Returns:
        (consensus, compositions): two lists with one entry per column.
    """
    start, end = GetStartEnd(msa, start, end)
    n_columns = len(msa[0])
    consensus = []
    compositions = []
    for col in range(start, end + 1):
        if col >= n_columns:
            consensus.append('')
            compositions.append({"": 0})
            continue
        counts = {}
        for seq in msa:
            symbol = seq[col]
            counts[symbol] = counts.get(symbol, 0) + 1
        # max() returns the first maximal key, i.e. the earliest-seen symbol on ties.
        consensus.append(max(counts, key=counts.get))
        compositions.append(counts)
    return consensus, compositions
def SetConsensusAxTicks(consensus, consensusComposition, ax):
    """Label x ticks 0..len(consensus)-1 with each column's consensus symbol.

    A column whose consensus symbol has the same count as some other symbol
    in that column (a tie) is labeled 'X' instead.
    """
    positions = list(range(len(consensus)))
    labels = []
    for pos in positions:
        symbol = consensus[pos]
        counts = consensusComposition[pos]
        tied = any(other != symbol and n == counts[symbol]
                   for other, n in counts.items())
        labels.append('X' if tied else symbol)
    ax.set_xticks(positions, labels)
def DrawConsensusHisto(msa, color = (0, 0, 1), color_map = None, start = None, end = None, ax=None):
    """Draw one bar per column giving the consensus symbol's frequency.

    Bar height is count(consensus) / column size, on a 0..1 y axis. Bars are
    shaded from white (frequency 0) toward ``color`` (frequency 1); only the
    first three components of ``color`` are used. Columns past the end of the
    alignment get height 0. X ticks show the consensus symbols (see
    SetConsensusAxTicks).

    Args:
        msa: aligned sequences.
        color: RGB of a fully conserved column (default blue). Now a tuple
            instead of a shared mutable list default.
        color_map: accepted for signature compatibility with the other panel
            functions; not used.
        start, end: inclusive column range (see GetStartEnd).
        ax: target Axes (default: current Axes).
    """
    ax = ax or plt.gca()
    consensus, consensusComposition = GetConsensus(msa, start, end)
    binHeight = []
    colors = []
    for symbol, composition in zip(consensus, consensusComposition):
        total = sum(composition.values())
        frac = composition[symbol] / total if total != 0 else 0
        binHeight.append(frac)
        # Linear blend between white and `color`, weighted by frequency.
        colors.append([1 - frac * (1 - channel) for channel in color[:3]])
    ax.bar(x=list(range(len(consensus))), height=binHeight, color=colors, width=0.95)
    ax.set(ylim=(0, 1))
    ax.set(xlim=(-0.5, len(consensus) - 0.5))
    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.get_yaxis().set_ticks([])
    SetConsensusAxTicks(consensus, consensusComposition, ax)
# Adapted from pyseqlogo: a path effect that stretches glyphs along one axis.
class Scale(PathEffects.AbstractPathEffect):
    """Path effect that scales a path by (sx, sy) before it is drawn.

    The scaling is applied in the path's own coordinates, ahead of the
    existing drawing transform. DrawSeqLogo uses ``Scale(1.0, s)`` to stretch
    letters vertically (presumably about the text anchor on the baseline).

    Subclasses matplotlib's documented base for custom path effects. The
    previous version subclassed ``RendererBase`` without initializing it,
    although it is not a renderer.

    Args:
        sx: horizontal scale factor.
        sy: vertical scale factor; None means the same as ``sx``
            (``Affine2D.scale`` semantics).
    """

    def __init__(self, sx, sy=None):
        super().__init__()
        self._sx = sx
        self._sy = sy

    def draw_path(self, renderer, gc, tpath, affine, rgbFace=None):
        # `A + B` applies A first, then B: scale in path space, then place.
        affine = Affine2D().scale(self._sx, self._sy) + affine
        renderer.draw_path(gc, tpath, affine, rgbFace)
def CalculateEntropy(count):
    """Return the Shannon entropy, in bits, of a collection of counts.

    Zero counts contribute nothing (0 * log 0 is taken as 0), whereas they
    used to raise ``ValueError: math domain error``. An empty collection, or
    one whose counts sum to 0, has entropy 0. Any iterable is accepted; it is
    read once.
    """
    counts = list(count)
    total = sum(counts)
    if total == 0:
        return 0
    return sum(-c / total * math.log2(c / total) for c in counts if c)
def DrawSeqLogo(msa, color_map, alphabet_size = None, start = None, end = None, ax=None):
    """Draw a sequence logo for columns [start, end] of ``msa``.

    Column i gets a stack of letters whose total height is
    R_i = log2(s) - (H_i + e_n) bits, clipped at 0. Here s is the alphabet
    size, H_i is the column's Shannon entropy, and
    e_n = (s - 1) / (2 ln2 * n) is the small-sample correction for n
    sequences (https://en.wikipedia.org/wiki/Sequence_logo). Each letter's
    height is proportional to its frequency. The most frequent letter sits at
    the bottom. Letters use ``color_map``; '.' and '-' are drawn in black.
    The axis is turned off.

    Args:
        msa: aligned sequences.
        color_map: symbol -> color, expected to include '-' and '.'.
        alphabet_size: s above. When None, uses ``len(color_map) - 2``
            (everything but the two gap symbols). This argument used to be
            ignored.
        start, end: inclusive column range (see GetStartEnd).
        ax: target Axes (default: current Axes).
    """
    ax = ax or plt.gca()
    fig = ax.get_figure()
    start, end = GetStartEnd(msa, start, end)
    if alphabet_size is None:
        alphabet_size = len(color_map) - 2  # 2 for "-" and "."
    consensus, consensusComposition = GetConsensus(msa, start, end)

    # Small-sample correction uses the number of sequences (letters per
    # column), not the alignment length as before.
    correction = (alphabet_size - 1) / (2 * math.log(2) * len(msa))
    r = []
    for composition in consensusComposition:
        entropy = CalculateEntropy(list(composition.values()))
        # Negative heights would flip glyphs below the baseline; clip them.
        r.append(max(0.0, math.log2(alphabet_size) - (entropy + correction)))

    width = end - start + 1
    ax.set_xlim(-0.5, width - 0.5)
    top = max(r, default=0)
    ax.set_ylim(0, top if top > 0 else 1)
    fontsize = CalculateFontsize(ax.get_xlim(), ax.get_ylim(), ax, fig, 1, width)

    for i, composition in enumerate(consensusComposition):
        total = sum(composition.values())
        if total == 0:
            continue
        baseline = 0
        for symbol, count in sorted(composition.items(), key=lambda kv: kv[1], reverse=True):
            text = ax.text(x=i, y=baseline, s=symbol, fontsize=fontsize,
                           va="baseline", ha="center",
                           color=(color_map[symbol] if symbol not in ['.', '-'] else [0, 0, 0]),
                           transform=ax.transData)
            height = count / total * r[i]
            # Measure the unscaled glyph in data units, then stretch it so its
            # top lands exactly `height` above the current baseline.
            tbox = text.get_window_extent().transformed(ax.transData.inverted())
            scale = height / (tbox.y1 - baseline)
            baseline += height
            text.set_path_effects([Scale(1.0, scale)])
    ax.axis('off')
    SetConsensusAxTicks(consensus, consensusComposition, ax)
# Add bracket annotations for named column ranges of the alignment.
def DrawAnnotation(msa, annotations, color_map=None,start = None, end = None, ax=None):
    """Draw labeled brackets marking annotated column ranges.

    The x axis uses absolute alignment columns (start - 0.5 .. end + 0.5) and
    the y axis spans 0..1. Each annotation draws vertical bars at its first
    and last column and the name centred between them. Horizontal arms join
    the bars to the label where there is room; they are skipped where they
    would run through a label wider than the bracket. Everything is clipped
    to the axes. The axis is turned off.

    Args:
        msa: aligned sequences (used only to resolve start/end).
        annotations: iterable of entries like ['a', 0, 3]: columns 0..3
            (0-based, both inclusive) are labeled 'a'. Only the first three
            items of each entry are read.
        color_map: accepted for signature compatibility; not used.
        start, end: inclusive column range (see GetStartEnd).
        ax: target Axes (default: current Axes).
    """
    ax = ax or plt.gca()
    fig = ax.get_figure()
    start, end = GetStartEnd(msa, start, end)
    ax.set_xlim(start - 0.5, end + 0.5)
    ax.set_ylim(0, 1)
    fontsize = CalculateFontsize(ax.get_xlim(), ax.get_ylim(), ax, fig, 1, end - start + 1)
    for a in annotations:
        name, first, last = a[0], a[1], a[2]
        text = ax.text(x=(first + last) / 2, y=0.5, s=name, fontsize=fontsize,
                       va="center", ha="center", color="black", clip_on=True)
        # Label extent in data coordinates, so the arms can stop at its edges.
        tbox = text.get_window_extent().transformed(ax.transData.inverted())
        ax.plot([first, first], [0, 1], color="black", clip_on=True)
        ax.plot([last, last], [0, 1], color="black", clip_on=True)
        if tbox.x0 > first:
            ax.plot([first, tbox.x0], [0.5, 0.5], color="black", clip_on=True)
        if tbox.x1 < last:
            ax.plot([tbox.x1, last], [0.5, 0.5], color="black", clip_on=True)
    ax.axis('off')
# Draw a multi-panel MSA figure, optionally wrapped into several row chunks.
def DrawComplexMSA(msa, panels=(), seq_names = None, panel_height_ratios=None, panel_params=None,
        color_map=None, start=None, end=None, wrap=None, figsize=None):
    """Draw several panels (MSA, logo, histogram, annotation, ...) per chunk.

    Columns [start, end] are split into chunks of ``wrap`` columns. For each
    chunk, every function in ``panels`` is called on its own Axes, stacked
    vertically, as::

        func(msa, color_map=color_map, start=chunk_start,
             end=chunk_start + wrap - 1, ax=axes[k], **params)

    Args:
        msa: aligned sequences.
        panels: sequence of panel functions such as DrawMSA, DrawSeqLogo,
            DrawConsensusHisto or DrawAnnotation. A tuple by default, replacing
            the shared mutable list default.
        seq_names: passed as ``seq_names`` to DrawMSA panels. This overrides
            any ``seq_names`` entry in ``panel_params``.
        panel_height_ratios: relative panel heights, repeated for every chunk.
            Defaults: DrawMSA -> number of sequences, DrawAnnotation -> 0.5,
            anything else -> 1.
        panel_params: optional list aligned with ``panels``. Each entry is a
            dict of extra keyword arguments for that panel, or None.
        color_map: shared symbol -> color map; GetColorMap(msa=msa) if not given.
        start, end: inclusive column range (see GetStartEnd).
        wrap: columns per chunk; the whole range when None.
        figsize: forwarded to ``plt.subplots``.

    Returns:
        1-D array of Axes: all panels of the first chunk, then the next chunk,
        and so on. This now also holds when there is a single panel and chunk.

    NOTE(review): every chunk, including the last, is drawn as exactly
    ``wrap`` columns so widths stay consistent. If ``end`` is before the last
    alignment column and not chunk-aligned, columns past ``end`` are drawn
    too — confirm whether that is intended.
    """
    color_map = color_map or GetColorMap(msa=msa)
    start, end = GetStartEnd(msa, start, end)
    wrap = wrap or (end - start + 1)
    chunks = math.ceil((end - start + 1) / wrap)
    if panel_height_ratios is None:
        panel_height_ratios = []
        for p in panels:
            if p is DrawMSA:
                panel_height_ratios.append(len(msa))
            elif p is DrawAnnotation:
                panel_height_ratios.append(0.5)
            else:
                panel_height_ratios.append(1)
    # list() first, so a numpy array is repeated rather than multiplied elementwise.
    height_ratios = list(panel_height_ratios) * chunks
    # squeeze=False: a 1x1 grid otherwise returns a bare Axes, which cannot be indexed.
    _, axes = plt.subplots(len(panels) * chunks, 1, constrained_layout=True,
                           figsize=figsize, height_ratios=height_ratios, squeeze=False)
    axes = axes[:, 0]
    axidx = 0
    for chunkStart in range(start, end + 1, wrap):
        for j, func in enumerate(panels):
            extraParam = {}
            if panel_params is not None:
                extraParam = dict(panel_params[j] or {})
            if func is DrawMSA:
                extraParam['seq_names'] = seq_names
            func(msa, color_map=color_map, start=chunkStart, end=chunkStart + wrap - 1,
                 ax=axes[axidx], **extraParam)
            axidx += 1
    return axes