Spaces:
Runtime error
Runtime error
from datetime import datetime, timedelta | |
from io import BytesIO | |
from itertools import starmap | |
from json import dumps | |
from more_itertools import windowed | |
from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr | |
from re import compile as compile_re | |
from types import MappingProxyType | |
from typing import Any, Dict, Literal, Mapping, Optional, Self, Sequence | |
from webvtt import Caption, WebVTT | |
from ctp_slack_bot.models.base import Chunk, Content | |
# Number of frames shared between two consecutive chunk windows.
CHUNK_FRAMES_OVERLAP = 1
# Maximum number of frames grouped into a single chunk window.
CHUNK_FRAMES_WINDOW = 5
# Delimiter between the speaker name and the speech inside a caption's text.
SPEAKER_SPEECH_TEXT_SEPARATOR = ": "
# Matches "Start time: <YYYY-MM-DD>[ <HH:MM:SS>[Z|±HH:MM]]" and captures the timestamp part.
ISO_DATE_TIME_PATTERN = compile_re(r"Start time: (\d{4}-\d{2}-\d{2}(?: \d{2}:\d{2}:\d{2}(?:Z|[+-]\d{2}:\d{2})?)?)")
class WebVTTFrame(BaseModel): | |
"""Represents a WebVTT frame""" | |
identifier: str | |
start: timedelta | |
end: timedelta | |
speaker: Optional[str] = None | |
speech: str | |
model_config = ConfigDict(frozen=True) | |
def from_webvtt_caption(cls, caption: Caption, index: int) -> Self: | |
identifier = caption.identifier if caption.identifier else str(index) | |
start = timedelta(**caption.start_time.__dict__) | |
end = timedelta(**caption.end_time.__dict__) | |
match caption.text.split(SPEAKER_SPEECH_TEXT_SEPARATOR, 1): | |
case [speaker, speech]: | |
return cls(identifier=identifier, start=start, end=end, speaker=speaker, speech=speech) | |
case [speech]: | |
return cls(identifier=identifier, start=start, end=end, speech=speech) | |
class WebVTTContent(Content):
    """Represents parsed WebVTT content."""

    id: str
    metadata: Mapping[str, Any] = Field(default_factory=dict)
    start_time: Optional[datetime]
    frames: Sequence[WebVTTFrame]

    def get_id(self: Self) -> str:
        """Return the identifier of this content."""
        return self.id

    def get_chunks(self: Self) -> Sequence[Chunk]:
        """Group the frames into overlapping chunks.

        Each chunk covers up to ``CHUNK_FRAMES_WINDOW`` consecutive frames, and
        successive chunks advance by ``CHUNK_FRAMES_WINDOW - CHUNK_FRAMES_OVERLAP``
        frames.

        Returns:
            A tuple of chunks. Each chunk's text joins its frames with blank
            lines, prefixing the speech with the speaker when one is present.
            Its metadata holds ``"start"`` and ``"end"`` (``start_time`` plus
            the first frame's start / last frame's end, or ``None`` when
            ``start_time`` is unset) and ``"speakers"`` (the speaker of every
            frame that has one, in frame order, duplicates included).
        """
        step = CHUNK_FRAMES_WINDOW - CHUNK_FRAMES_OVERLAP
        chunks = []
        for window in windowed(self.frames, CHUNK_FRAMES_WINDOW, step=step):
            # windowed() pads a short final window with None; drop the padding.
            frames = tuple(filter(None, window))
            first, last = frames[0], frames[-1]
            text = "\n\n".join(
                # filter(None, ...) omits a missing speaker so no dangling separator is emitted.
                SPEAKER_SPEECH_TEXT_SEPARATOR.join(filter(None, (frame.speaker, frame.speech)))
                for frame in frames
            )
            chunks.append(Chunk(
                text=text,
                parent_id=self.get_id(),
                chunk_id=f"{first.identifier}-{last.identifier}",
                metadata={
                    "start": self.start_time + first.start if self.start_time else None,
                    "end": self.start_time + last.end if self.start_time else None,
                    "speakers": tuple(frame.speaker for frame in frames if frame.speaker),
                },
            ))
        return tuple(chunks)

    def get_metadata(self: Self) -> Mapping[str, Any]:
        """Return a read-only view of the metadata mapping."""
        return MappingProxyType(self.metadata)

    @classmethod
    def __get_start_time(cls, web_vtt: WebVTT) -> Optional[datetime]:
        """Extract the start time from the first matching header comment.

        Args:
            web_vtt: The parsed WebVTT whose header comments are searched.

        Returns:
            The parsed timestamp from the first header comment matching
            ``ISO_DATE_TIME_PATTERN``, or ``None`` when no comment matches or
            the matched text is not a valid ISO date/time.
        """
        matches = (result[0]
                   for result
                   in map(ISO_DATE_TIME_PATTERN.findall, web_vtt.header_comments)
                   if result)
        # Supply a default so a file without a "Start time:" comment yields
        # None instead of leaking StopIteration to the caller.
        timestamp = next(matches, None)
        if timestamp is None:
            return None
        try:
            return datetime.fromisoformat(timestamp)
        except ValueError:
            return None

    @classmethod
    def from_bytes(cls, id: str, metadata: Mapping[str, Any], buffer: bytes) -> Self:
        """Parse raw WebVTT bytes into content.

        Args:
            id: Identifier to assign to the content.
            metadata: Metadata to attach; it is wrapped in a read-only proxy.
            buffer: The WebVTT file contents.

        Returns:
            Content holding one frame per caption (captions are numbered from 1
            for fallback identifiers) and the start time found in the header
            comments, if any.
        """
        web_vtt = WebVTT.from_buffer(BytesIO(buffer))
        frames = tuple(WebVTTFrame.from_webvtt_caption(caption, index)
                       for index, caption
                       in enumerate(web_vtt.captions, 1))
        # Construct via cls so subclasses get an instance of their own type, as -> Self promises.
        return cls(id=id, metadata=MappingProxyType(metadata), start_time=cls.__get_start_time(web_vtt), frames=frames)