Tonic's picture
Upload folder using huggingface_hub
7873945 verified
raw
history blame
6.77 kB
from __future__ import annotations
import json
from typing import Any, List, Dict, Optional, Literal, ClassVar
from gradio.components import Component
from gradio.data_classes import FileData, GradioModel
from gradio.events import Events, EventListener
import matplotlib.pyplot as plt
import io
import base64
# Import functions from msa.py
from .msa import (
DrawMSA,
DrawConsensusHisto,
DrawSeqLogo,
DrawAnnotation,
DrawComplexMSA,
GetColorMap,
)
class MSAPlotData(GradioModel):
    """Payload describing one MSA plot request.

    Holds the aligned sequences plus every rendering option that
    ``MSAPlot.postprocess`` forwards to the drawing helpers in ``.msa``.
    """

    # Aligned sequences, one string per sequence (required).
    msa: List[str]
    # Optional display names; presumably one per entry of ``msa`` — TODO confirm.
    seq_names: Optional[List[str]] = None
    # Optional start/end positions forwarded unchanged to the drawing helpers.
    start: Optional[int] = None
    end: Optional[int] = None
    # Optional character -> colour (list of floats) map; when omitted the
    # component falls back to GetColorMap(msa=...).
    color_map: Optional[Dict[str, List[float]]] = None
    # Which single plot to draw, or "complex" to stack several panels.
    plot_type: Literal["msa", "consensus", "logo", "annotation", "complex"] = "msa"
    # Panel names used only when plot_type == "complex".
    panels: List[str] = ["msa"]
    # Options forwarded to DrawComplexMSA for "complex" plots.
    panel_height_ratios: Optional[List[float]] = None
    panel_params: Optional[List[Dict[str, Any]]] = None
    wrap: Optional[int] = None
    # Figure size; the component uses (10, 5) when this is not set.
    figsize: Optional[List[float]] = None
    # Annotation rows; required when plot_type == "annotation".
    annotations: Optional[List[List[Any]]] = None
    EVENTS: ClassVar[List[Events | EventListener]] = []

    def __init__(self, data: Any = None, **kwargs: Any) -> None:
        """Build the model from keyword fields and/or a collection of fields.

        Args:
            data: Optional mapping (or iterable of ``(name, value)`` pairs, such
                as another model instance) of field values. Entries in ``data``
                take precedence over same-named keyword arguments.
            **kwargs: Field values passed as keywords.
        """
        if data is not None:
            # Merge *before* validating. Validating ``kwargs`` alone made
            # ``MSAPlotData({"msa": [...]})`` fail on the missing required
            # ``msa``, and writing ``data`` into ``__dict__`` afterwards skipped
            # field validation entirely.
            kwargs = {**kwargs, **dict(data)}
        super().__init__(**kwargs)

    @classmethod
    def get_events(cls) -> Dict[str, Any]:
        """Return an empty mapping: MSAPlotData has no events."""
        return {}

    @classmethod
    def get_description(cls) -> str:
        """Return a short human-readable description of this helper model."""
        return "Helper class for MSAPlot data"
class MSAPlot(Component):
    """
    Creates a Multiple Sequence Alignment (MSA) plot component.

    The component's value is an ``MSAPlotData`` payload. ``postprocess``
    renders it with the drawing helpers imported from ``.msa`` and returns the
    image as a base64-encoded PNG data URI.
    """

    EVENTS = [
        Events.change,
        EventListener("clear", doc="Triggered when the plot is cleared."),
    ]
    data_model = MSAPlotData

    def __init__(
        self,
        value: Any | None = None,
        *,
        label: str | None = None,
        every: float | None = None,
        show_label: bool | None = None,
        container: bool = True,
        scale: int | None = None,
        min_width: int = 160,
        visible: bool = True,
        elem_id: str | None = None,
        elem_classes: list[str] | str | None = None,
        render: bool = True,
        key: int | str | None = None,
    ):
        """Create the component; every argument is forwarded to ``Component``.

        Args:
            value: Initial value, typically an ``MSAPlotData`` payload.
            label, every, show_label, container, scale, min_width, visible,
            elem_id, elem_classes, render, key: Standard gradio component
                options, passed through unchanged.
        """
        super().__init__(
            label=label,
            every=every,
            show_label=show_label,
            container=container,
            scale=scale,
            min_width=min_width,
            visible=visible,
            elem_id=elem_id,
            elem_classes=elem_classes,
            render=render,
            key=key,
            value=value,
        )

    def preprocess(self, payload: MSAPlotData | None) -> MSAPlotData | None:
        """Return the incoming payload unchanged."""
        return payload

    def postprocess(self, value: MSAPlotData | None) -> Dict[str, Any] | None:
        """Render ``value`` to a PNG and wrap it for the frontend.

        Args:
            value: The plot request, or ``None`` for an empty component.

        Returns:
            ``None`` when ``value`` is ``None``; otherwise a dict whose
            ``"type"`` is ``"matplotlib"`` and whose ``"plot"`` is a
            ``data:image/png;base64,...`` URI.

        Raises:
            ValueError: if ``plot_type`` is ``"annotation"`` but no annotations
                are given, or a ``"complex"`` plot names an unknown panel.
        """
        if value is None:
            return None

        panel_functions = {
            "msa": DrawMSA,
            "consensus": DrawConsensusHisto,
            "logo": DrawSeqLogo,
            "annotation": DrawAnnotation,
        }
        # Validate the request before any figure exists, so a bad payload can
        # no longer leave an open figure behind.
        if value.plot_type == "annotation" and not value.annotations:
            raise ValueError("Annotations are required for annotation plot type")
        if value.plot_type == "complex":
            unknown = [p for p in value.panels if p not in panel_functions]
            if unknown:
                raise ValueError(
                    f"Unknown panel(s) {unknown}; expected any of {sorted(panel_functions)}"
                )

        color_map = value.color_map or GetColorMap(msa=value.msa)
        fig, ax = plt.subplots(figsize=value.figsize or (10, 5))
        try:
            if value.plot_type == "msa":
                DrawMSA(value.msa, seq_names=value.seq_names, start=value.start, end=value.end, color_map=color_map, ax=ax)
            elif value.plot_type == "consensus":
                DrawConsensusHisto(value.msa, color_map=color_map, start=value.start, end=value.end, ax=ax)
            elif value.plot_type == "logo":
                DrawSeqLogo(value.msa, color_map=color_map, start=value.start, end=value.end, ax=ax)
            elif value.plot_type == "annotation":
                DrawAnnotation(value.msa, value.annotations, color_map=color_map, start=value.start, end=value.end, ax=ax)
            elif value.plot_type == "complex":
                DrawComplexMSA(
                    value.msa,
                    panels=[panel_functions[p] for p in value.panels],
                    seq_names=value.seq_names,
                    panel_height_ratios=value.panel_height_ratios,
                    panel_params=value.panel_params,
                    color_map=color_map,
                    start=value.start,
                    end=value.end,
                    wrap=value.wrap,
                    figsize=value.figsize,
                )
            buf = io.BytesIO()
            # pyplot.savefig writes the *current* figure. DrawComplexMSA receives
            # its own figsize and may have made a different figure current
            # (not visible here — TODO confirm in msa.py).
            plt.savefig(buf, format="png")
        finally:
            # Look up the current figure while ``fig`` is still open (so gcf()
            # cannot create a fresh one), then close it and ``fig``, even if
            # drawing failed. Closing only ``fig`` leaked any separate figure.
            current = plt.gcf()
            plt.close(fig)
            if current is not fig:
                plt.close(current)

        img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
        return {
            "type": "matplotlib",
            "plot": f"data:image/png;base64,{img_base64}",
        }

    def example_payload(self) -> Any:
        """Return a small three-sequence "complex" plot request for docs/demos."""
        return MSAPlotData(
            msa=[
                "ATGCATGC",
                "ATG-ATGC",
                "ATGCATGC",
            ],
            seq_names=["Seq1", "Seq2", "Seq3"],
            plot_type="complex",
            panels=["msa", "consensus", "logo"],
        )

    def example_value(self) -> Any:
        """Return the same example as ``example_payload``."""
        return self.example_payload()

    @classmethod
    def get_events(cls) -> Dict[str, Any]:
        """Map each declared event's name to its documentation string."""
        events: Dict[str, Any] = {}
        for event in cls.EVENTS:
            # gradio's EventListener is a ``str`` subclass that stores its name
            # in ``event_name`` and has no ``name`` attribute (verify against
            # the installed gradio version); fall back to the string value.
            event_key = getattr(event, "event_name", None) or str(event)
            events[event_key] = getattr(event, "doc", "")
        return events

    @classmethod
    def get_description(cls) -> str:
        """Return a short human-readable description of the component."""
        return "Creates a Multiple Sequence Alignment (MSA) plot component."

    def api_info(self) -> Dict[str, Any]:
        """Return a JSON-schema description of the ``MSAPlotData`` payload."""
        return {
            "type": "object",
            "properties": {
                "msa": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "List of sequences in the multiple sequence alignment",
                },
                "seq_names": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "List of sequence names",
                },
                "start": {
                    "type": "integer",
                    "description": "Optional start position forwarded to the drawing functions",
                },
                "end": {
                    "type": "integer",
                    "description": "Optional end position forwarded to the drawing functions",
                },
                "color_map": {
                    "type": "object",
                    "additionalProperties": {"type": "array", "items": {"type": "number"}},
                    "description": "Optional character-to-colour map; a default is derived from the alignment when omitted",
                },
                "plot_type": {
                    "type": "string",
                    "enum": ["msa", "consensus", "logo", "annotation", "complex"],
                    "description": "Type of plot to generate",
                },
                "panels": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "List of panels to include in a complex plot",
                },
                "panel_height_ratios": {
                    "type": "array",
                    "items": {"type": "number"},
                    "description": "Optional relative panel heights for a complex plot",
                },
                "panel_params": {
                    "type": "array",
                    "items": {"type": "object"},
                    "description": "Optional per-panel keyword arguments for a complex plot",
                },
                "wrap": {
                    "type": "integer",
                    "description": "Optional wrap setting forwarded to the complex plot",
                },
                "figsize": {
                    "type": "array",
                    "items": {"type": "number"},
                    "description": "Optional figure size",
                },
                "annotations": {
                    "type": "array",
                    "items": {"type": "array"},
                    "description": "Annotations; required for the annotation plot type",
                },
            },
            "required": ["msa"],
        }