Spaces:
Runtime error
Runtime error
| from abc import abstractmethod | |
| from typing import List, Dict | |
| from speakers.load.serializable import Serializable | |
| from speakers.processors import ProcessorData, BaseProcessor | |
| from speakers.server.model.flow_data import PayLoad | |
| import logging | |
| class FlowData(Serializable): | |
| """ | |
| 当前runner的任务参数 | |
| """ | |
| def type(self) -> str: | |
| """Type of the Message, used for serialization.""" | |
| def lc_serializable(self) -> bool: | |
| """Whether this class is Processor serializable.""" | |
| return True | |
| class Runner(Serializable): | |
| """ runner的任务id""" | |
| task_id: str | |
| flow_data: FlowData | |
| def type(self) -> str: | |
| """Type of the Runner Message, used for serialization.""" | |
| return 'runner' | |
| def lc_serializable(self) -> bool: | |
| """Whether this class is Processor serializable.""" | |
| return True | |
| # Define a base class for tasks | |
| class BaseTask: | |
| """ | |
| 基础任务处理器由任务管理器创建,用于执行runner flow 的任务,子类实现具体的处理流程 | |
| 此类定义了流程runner task的生命周期 | |
| """ | |
| def __init__(self, preprocess_dict: Dict[str, BaseProcessor]): | |
| self._progress_hooks = [] | |
| self._add_logger_hook() | |
| self._preprocess_dict = preprocess_dict | |
| self.logger = logging.getLogger('base_task_runner') | |
| def from_config(cls, cfg=None): | |
| return cls(preprocess_dict={}) | |
| def _add_logger_hook(self): | |
| """ | |
| 默认的任务日志监听者 | |
| :return: | |
| """ | |
| LOG_MESSAGES = { | |
| 'dispatch_voice_task': 'dispatch_voice_task', | |
| 'saved': 'Saving results', | |
| } | |
| LOG_MESSAGES_SKIP = { | |
| 'skip-no-text': 'No text regions with text! - Skipping', | |
| } | |
| LOG_MESSAGES_ERROR = { | |
| 'error': 'task error', | |
| } | |
| async def ph(task_id: str, runner_stat: str, state: str, finished: bool = False): | |
| if state in LOG_MESSAGES: | |
| self.logger.info(LOG_MESSAGES[state]) | |
| elif state in LOG_MESSAGES_SKIP: | |
| self.logger.warn(LOG_MESSAGES_SKIP[state]) | |
| elif state in LOG_MESSAGES_ERROR: | |
| self.logger.error(LOG_MESSAGES_ERROR[state]) | |
| self.add_progress_hook(ph) | |
| def add_progress_hook(self, ph): | |
| """ | |
| 注册监听器 | |
| :param ph: 监听者 | |
| :return: | |
| """ | |
| self._progress_hooks.append(ph) | |
| async def report_progress(self, task_id: str, runner_stat: str, state: str, finished: bool = False): | |
| """ | |
| 任务通知监听器 | |
| :param task_id: 任务id | |
| :param runner_stat: 任务执行位置 | |
| :param state: 状态 | |
| :param finished: 是否完成 | |
| :return: | |
| """ | |
| for ph in self._progress_hooks: | |
| await ph(task_id, runner_stat, state, finished) | |
| def prepare(cls, payload: PayLoad) -> Runner: | |
| """ | |
| 预处理 | |
| Args: | |
| payload (PayLoad): runner flow data | |
| Raises: | |
| NotImplementedError: This method should be overridden by subclasses. | |
| """ | |
| raise NotImplementedError | |
| async def dispatch(cls, runner: Runner): | |
| """ | |
| 当前runner task具体flow data | |
| Args: | |
| runner (ProcessorData): runner flow data | |
| Raises: | |
| NotImplementedError: This method should be overridden by subclasses. | |
| """ | |
| raise NotImplementedError | |
| def complete(cls, runner: Runner): | |
| """ | |
| 后置处理 | |
| Args: | |
| runner (Runner): runner flow data | |
| Raises: | |
| NotImplementedError: This method should be overridden by subclasses. | |
| """ | |
| raise NotImplementedError | |