from datetime import datetime, timedelta from io import BytesIO from itertools import starmap from json import dumps from more_itertools import windowed from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr from re import compile as compile_re from types import MappingProxyType from typing import Any, Dict, Literal, Mapping, Optional, Self, Sequence from webvtt import Caption, WebVTT from ctp_slack_bot.models.base import Chunk, Content CHUNK_FRAMES_OVERLAP = 1 CHUNK_FRAMES_WINDOW = 5 SPEAKER_SPEECH_TEXT_SEPARATOR = ": " ISO_DATE_TIME_PATTERN = compile_re(r"Start time: (\d{4}-\d{2}-\d{2}(?: \d{2}:\d{2}:\d{2}(?:Z|[+-]\d{2}:\d{2})?)?)") class WebVTTFrame(BaseModel): """Represents a WebVTT frame""" identifier: str start: timedelta end: timedelta speaker: Optional[str] = None speech: str model_config = ConfigDict(frozen=True) @classmethod def from_webvtt_caption(cls, caption: Caption, index: int) -> Self: identifier = caption.identifier if caption.identifier else str(index) start = timedelta(**caption.start_time.__dict__) end = timedelta(**caption.end_time.__dict__) match caption.text.split(SPEAKER_SPEECH_TEXT_SEPARATOR, 1): case [speaker, speech]: return cls(identifier=identifier, start=start, end=end, speaker=speaker, speech=speech) case [speech]: return cls(identifier=identifier, start=start, end=end, speech=speech) class WebVTTContent(Content): """Represents parsed WebVTT content.""" id: str metadata: Mapping[str, Any] = Field(default_factory=dict) start_time: Optional[datetime] frames: Sequence[WebVTTFrame] def get_id(self: Self) -> str: return self.id def get_chunks(self: Self) -> Sequence[Chunk]: windows = (tuple(filter(None, window)) for window in windowed(self.frames, CHUNK_FRAMES_WINDOW, step=CHUNK_FRAMES_WINDOW-CHUNK_FRAMES_OVERLAP)) return tuple(Chunk(text="\n\n".join(": ".join(filter(None, (frame.speaker, frame.speech))) for frame in frames), parent_id=self.get_id(), chunk_id=f"{frames[0].identifier}-{frames[-1].identifier}", metadata={ "start": self.start_time + frames[0].start if self.start_time else None, "end": self.start_time + frames[-1].end if self.start_time else None, "speakers": tuple(frame.speaker for frame in frames if frame.speaker) }) for frames in windows) def get_metadata(self: Self) -> Mapping[str, Any]: return MappingProxyType(self.metadata) @classmethod def __get_start_time(cls, web_vtt: WebVTT) -> Optional[datetime]: try: return next(datetime.fromisoformat(result[0]) for result in map(ISO_DATE_TIME_PATTERN.findall, web_vtt.header_comments) if result) except ValueError: return None @classmethod def from_bytes(cls, id: str, metadata: Mapping[str, Any], buffer: bytes) -> Self: web_vtt = WebVTT.from_buffer(BytesIO(buffer)) frames = tuple(WebVTTFrame.from_webvtt_caption(caption, index) for index, caption in enumerate(web_vtt.captions, 1)) return WebVTTContent(id=id, metadata=MappingProxyType(metadata), start_time=cls.__get_start_time(web_vtt), frames=frames)