from __future__ import annotations import json from typing import Any, List, Dict, Optional, Literal, ClassVar from gradio.components import Component from gradio.data_classes import FileData, GradioModel from gradio.events import Events, EventListener import matplotlib.pyplot as plt import io import base64 # Import functions from msa.py from .msa import ( DrawMSA, DrawConsensusHisto, DrawSeqLogo, DrawAnnotation, DrawComplexMSA, GetColorMap, ) class MSAPlotData(GradioModel): msa: List[str] seq_names: Optional[List[str]] = None start: Optional[int] = None end: Optional[int] = None color_map: Optional[Dict[str, List[float]]] = None plot_type: Literal["msa", "consensus", "logo", "annotation", "complex"] = "msa" panels: List[str] = ["msa"] panel_height_ratios: Optional[List[float]] = None panel_params: Optional[List[Dict[str, Any]]] = None wrap: Optional[int] = None figsize: Optional[List[float]] = None annotations: Optional[List[List[Any]]] = None EVENTS: ClassVar[List[Events | EventListener]] = [] def __init__(self, data: Any = None, **kwargs): super().__init__(**kwargs) if data is not None: self.__dict__.update(data) @classmethod def get_events(cls) -> Dict[str, Any]: return {} # MSAPlotData has no events @classmethod def get_description(cls) -> str: return "Helper class for MSAPlot data" class MSAPlot(Component): """ Creates a Multiple Sequence Alignment (MSA) plot component. """ # EVENTS = { # "change": None, # "clear": None # } EVENTS = [ Events.change, EventListener("clear", doc="Triggered when the plot is cleared.") ] data_model = MSAPlotData def __init__( self, value: Any | None = None, *, label: str | None = None, every: float | None = None, show_label: bool | None = None, container: bool = True, scale: int | None = None, min_width: int = 160, visible: bool = True, elem_id: str | None = None, elem_classes: list[str] | str | None = None, render: bool = True, key: int | str | None = None, ): super().__init__( label=label, every=every, show_label=show_label, container=container, scale=scale, min_width=min_width, visible=visible, elem_id=elem_id, elem_classes=elem_classes, render=render, key=key, value=value, ) def preprocess(self, payload: MSAPlotData | None) -> MSAPlotData | None: return payload def postprocess(self, value: MSAPlotData) -> Dict[str, Any]: if value is None: return None fig, ax = plt.subplots(figsize=value.figsize or (10, 5)) color_map = value.color_map or GetColorMap(msa=value.msa) if value.plot_type == "msa": DrawMSA(value.msa, seq_names=value.seq_names, start=value.start, end=value.end, color_map=color_map, ax=ax) elif value.plot_type == "consensus": DrawConsensusHisto(value.msa, color_map=color_map, start=value.start, end=value.end, ax=ax) elif value.plot_type == "logo": DrawSeqLogo(value.msa, color_map=color_map, start=value.start, end=value.end, ax=ax) elif value.plot_type == "annotation": if value.annotations: DrawAnnotation(value.msa, value.annotations, color_map=color_map, start=value.start, end=value.end, ax=ax) else: raise ValueError("Annotations are required for annotation plot type") elif value.plot_type == "complex": panel_functions = { "msa": DrawMSA, "consensus": DrawConsensusHisto, "logo": DrawSeqLogo, "annotation": DrawAnnotation, } panels = [panel_functions[p] for p in value.panels] DrawComplexMSA( value.msa, panels=panels, seq_names=value.seq_names, panel_height_ratios=value.panel_height_ratios, panel_params=value.panel_params, color_map=color_map, start=value.start, end=value.end, wrap=value.wrap, figsize=value.figsize, ) buf = io.BytesIO() plt.savefig(buf, format='png') buf.seek(0) img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8') plt.close(fig) return { "type": "matplotlib", "plot": f"data:image/png;base64,{img_base64}", } def example_payload(self) -> Any: return MSAPlotData( msa=[ "ATGCATGC", "ATG-ATGC", "ATGCATGC", ], seq_names=["Seq1", "Seq2", "Seq3"], plot_type="complex", panels=["msa", "consensus", "logo"], ) def example_value(self) -> Any: return self.example_payload() @classmethod def get_events(cls) -> Dict[str, Any]: return {event.value if isinstance(event, Events) else event.name: event.doc for event in cls.EVENTS} @classmethod def get_description(cls) -> str: return "Creates a Multiple Sequence Alignment (MSA) plot component." def api_info(self) -> Dict[str, Any]: return { "type": "object", "properties": { "msa": { "type": "array", "items": {"type": "string"}, "description": "List of sequences in the multiple sequence alignment", }, "seq_names": { "type": "array", "items": {"type": "string"}, "description": "List of sequence names", }, "plot_type": { "type": "string", "enum": ["msa", "consensus", "logo", "annotation", "complex"], "description": "Type of plot to generate", }, "panels": { "type": "array", "items": {"type": "string"}, "description": "List of panels to include in a complex plot", }, }, "required": ["msa"], }