# Spaces: Runtime error
from __future__ import annotations | |
import json | |
from typing import Any, List, Dict, Optional, Literal, ClassVar | |
from gradio.components import Component | |
from gradio.data_classes import FileData, GradioModel | |
from gradio.events import Events, EventListener | |
import matplotlib.pyplot as plt | |
import io | |
import base64 | |
# Import functions from msa.py | |
from .msa import ( | |
DrawMSA, | |
DrawConsensusHisto, | |
DrawSeqLogo, | |
DrawAnnotation, | |
DrawComplexMSA, | |
GetColorMap, | |
) | |
class MSAPlotData(GradioModel):
    """Payload model describing a single MSA plot request.

    The fields mirror the keyword arguments that ``MSAPlot.postprocess``
    forwards to the drawing helpers imported from ``msa.py``.
    """

    # Aligned sequences, one string per row.
    msa: List[str]
    # Optional row labels; presumably one per entry of ``msa`` (not validated here).
    seq_names: Optional[List[str]] = None
    # Optional column window, forwarded to the drawing helpers as ``start``/``end``.
    start: Optional[int] = None
    end: Optional[int] = None
    # Residue -> colour mapping; when omitted, postprocess derives one via GetColorMap.
    color_map: Optional[Dict[str, List[float]]] = None
    # Selects which drawing routine is used.
    plot_type: Literal["msa", "consensus", "logo", "annotation", "complex"] = "msa"
    # Panel names used only by the "complex" plot type.
    panels: List[str] = ["msa"]
    # Options forwarded to DrawComplexMSA for the "complex" plot type.
    panel_height_ratios: Optional[List[float]] = None
    panel_params: Optional[List[Dict[str, Any]]] = None
    wrap: Optional[int] = None
    # Forwarded to matplotlib as ``figsize``.
    figsize: Optional[List[float]] = None
    # Required when ``plot_type == "annotation"``.
    annotations: Optional[List[List[Any]]] = None

    EVENTS: ClassVar[List[Events | EventListener]] = []

    def __init__(self, data: Any = None, **kwargs):
        """Build the model from keyword fields and/or a ``data`` mapping.

        Args:
            data: Optional mapping (or iterable of ``(name, value)`` pairs)
                of field values. Its entries take precedence over ``kwargs``.
            **kwargs: Field values.

        Raises:
            Whatever the base model's validation raises for missing or
            invalid fields.
        """
        if data is not None:
            # Merge *before* validation. Updating ``__dict__`` afterwards skipped
            # validation and made ``data`` unusable for required fields such as ``msa``.
            kwargs = {**kwargs, **dict(data)}
        super().__init__(**kwargs)

    @classmethod
    def get_events(cls) -> Dict[str, Any]:
        """Return the event table for this model; it defines no events."""
        return {}

    @classmethod
    def get_description(cls) -> str:
        """Return a short human-readable description of this model."""
        return "Helper class for MSAPlot data"
class MSAPlot(Component):
    """
    Creates a Multiple Sequence Alignment (MSA) plot component.

    The component's value is an ``MSAPlotData`` (or an equivalent dict).
    ``postprocess`` renders it with matplotlib and sends it to the frontend
    as a base64-encoded PNG data URI.
    """

    EVENTS = [
        Events.change,
        EventListener("clear", doc="Triggered when the plot is cleared."),
    ]

    data_model = MSAPlotData

    def __init__(
        self,
        value: Any | None = None,
        *,
        label: str | None = None,
        every: float | None = None,
        show_label: bool | None = None,
        container: bool = True,
        scale: int | None = None,
        min_width: int = 160,
        visible: bool = True,
        elem_id: str | None = None,
        elem_classes: list[str] | str | None = None,
        render: bool = True,
        key: int | str | None = None,
    ):
        """Create the component.

        All arguments are forwarded unchanged to ``Component.__init__``.
        """
        super().__init__(
            label=label,
            every=every,
            show_label=show_label,
            container=container,
            scale=scale,
            min_width=min_width,
            visible=visible,
            elem_id=elem_id,
            elem_classes=elem_classes,
            render=render,
            key=key,
            value=value,
        )

    def preprocess(self, payload: MSAPlotData | None) -> MSAPlotData | None:
        """Pass the frontend payload through to the user function unchanged."""
        return payload

    def postprocess(
        self, value: MSAPlotData | Dict[str, Any] | None
    ) -> Dict[str, Any] | None:
        """Render ``value`` to a PNG and return it as a base64 data URI.

        Args:
            value: The plot description. A plain dict with ``MSAPlotData``'s
                keys is also accepted and is validated into the model first.

        Returns:
            ``None`` when ``value`` is ``None``. Otherwise a dict with
            ``"type": "matplotlib"`` and ``"plot"`` holding a
            ``data:image/png;base64,...`` URI.

        Raises:
            ValueError: if the "annotation" plot type has no annotations, or
                ``plot_type`` is not a recognised value.
            KeyError: if a "complex" plot names an unknown panel.
        """
        if value is None:
            return None
        if isinstance(value, dict):
            value = MSAPlotData(**value)

        color_map = value.color_map or GetColorMap(msa=value.msa)
        fig = None
        try:
            if value.plot_type == "complex":
                self._draw_complex(value, color_map)
                # DrawComplexMSA is given no Axes, so save the figure it left
                # current. The previous plt.savefig() call also saved the
                # current figure.
                fig = plt.gcf()
            else:
                fig, ax = plt.subplots(figsize=value.figsize or (10, 5))
                self._draw_single(value, color_map, ax)
            buf = io.BytesIO()
            fig.savefig(buf, format="png")
        finally:
            # Release the figure even when drawing fails, so repeated renders
            # do not accumulate open pyplot figures.
            if fig is not None:
                plt.close(fig)

        img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
        return {
            "type": "matplotlib",
            "plot": f"data:image/png;base64,{img_base64}",
        }

    @staticmethod
    def _draw_single(value: MSAPlotData, color_map: Any, ax: Any) -> None:
        """Draw a single-panel plot of ``value.plot_type`` onto ``ax``."""
        if value.plot_type == "msa":
            DrawMSA(value.msa, seq_names=value.seq_names, start=value.start, end=value.end, color_map=color_map, ax=ax)
        elif value.plot_type == "consensus":
            DrawConsensusHisto(value.msa, color_map=color_map, start=value.start, end=value.end, ax=ax)
        elif value.plot_type == "logo":
            DrawSeqLogo(value.msa, color_map=color_map, start=value.start, end=value.end, ax=ax)
        elif value.plot_type == "annotation":
            if not value.annotations:
                raise ValueError("Annotations are required for annotation plot type")
            DrawAnnotation(value.msa, value.annotations, color_map=color_map, start=value.start, end=value.end, ax=ax)
        else:
            # Only reachable if validation was bypassed; fail loudly instead of
            # silently emitting an empty image.
            raise ValueError(f"Unknown plot_type: {value.plot_type!r}")

    @staticmethod
    def _draw_complex(value: MSAPlotData, color_map: Any) -> None:
        """Draw a multi-panel figure with DrawComplexMSA from ``value.panels``."""
        panel_functions = {
            "msa": DrawMSA,
            "consensus": DrawConsensusHisto,
            "logo": DrawSeqLogo,
            "annotation": DrawAnnotation,
        }
        # An unknown panel name raises KeyError here, as it did before.
        panels = [panel_functions[p] for p in value.panels]
        DrawComplexMSA(
            value.msa,
            panels=panels,
            seq_names=value.seq_names,
            panel_height_ratios=value.panel_height_ratios,
            panel_params=value.panel_params,
            color_map=color_map,
            start=value.start,
            end=value.end,
            wrap=value.wrap,
            figsize=value.figsize,
        )

    def example_payload(self) -> Any:
        """Return a small three-sequence "complex" plot request."""
        return MSAPlotData(
            msa=[
                "ATGCATGC",
                "ATG-ATGC",
                "ATGCATGC",
            ],
            seq_names=["Seq1", "Seq2", "Seq3"],
            plot_type="complex",
            panels=["msa", "consensus", "logo"],
        )

    def example_value(self) -> Any:
        """Return the same example as ``example_payload``."""
        return self.example_payload()

    @classmethod
    def get_events(cls) -> Dict[str, Any]:
        """Map each event name in ``EVENTS`` to its documentation string."""
        events: Dict[str, Any] = {}
        for event in cls.EVENTS:
            # gradio's EventListener stores its name in ``event_name``; there
            # is no ``name`` attribute. Plain strings are handled via str().
            # NOTE(review): confirm against the installed gradio version.
            name = getattr(event, "event_name", None) or str(event)
            events[name] = getattr(event, "doc", "")
        return events

    @classmethod
    def get_description(cls) -> str:
        """Return a short human-readable description of the component."""
        return "Creates a Multiple Sequence Alignment (MSA) plot component."

    def api_info(self) -> Dict[str, Any]:
        """Describe the JSON payload accepted by this component (mirrors MSAPlotData)."""
        return {
            "type": "object",
            "properties": {
                "msa": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "List of sequences in the multiple sequence alignment",
                },
                "seq_names": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "List of sequence names",
                },
                "start": {
                    "type": "integer",
                    "description": "Start of the column window to plot",
                },
                "end": {
                    "type": "integer",
                    "description": "End of the column window to plot",
                },
                "color_map": {
                    "type": "object",
                    "additionalProperties": {"type": "array", "items": {"type": "number"}},
                    "description": "Mapping from residue character to colour",
                },
                "plot_type": {
                    "type": "string",
                    "enum": ["msa", "consensus", "logo", "annotation", "complex"],
                    "description": "Type of plot to generate",
                },
                "panels": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "List of panels to include in a complex plot",
                },
                "panel_height_ratios": {
                    "type": "array",
                    "items": {"type": "number"},
                    "description": "Relative panel heights for a complex plot",
                },
                "panel_params": {
                    "type": "array",
                    "items": {"type": "object"},
                    "description": "Per-panel keyword parameters for a complex plot",
                },
                "wrap": {
                    "type": "integer",
                    "description": "Wrap setting forwarded to the complex plot",
                },
                "figsize": {
                    "type": "array",
                    "items": {"type": "number"},
                    "description": "Matplotlib figure size",
                },
                "annotations": {
                    "type": "array",
                    "items": {"type": "array"},
                    "description": "Annotations required by the annotation plot type",
                },
            },
            "required": ["msa"],
        }