LiKenun's picture
Bug fixes; remove references to uncommitted work-in-progress
af0a2bd
raw
history blame
3.73 kB
from datetime import datetime, timedelta
from io import BytesIO
from itertools import starmap
from json import dumps
from more_itertools import windowed
from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr
from re import compile as compile_re
from types import MappingProxyType
from typing import Any, Dict, Literal, Mapping, Optional, Self, Sequence
from webvtt import Caption, WebVTT
from ctp_slack_bot.models.base import Chunk, Content
# Chunking parameters: every chunk spans CHUNK_FRAMES_WINDOW caption frames, and
# consecutive chunks share CHUNK_FRAMES_OVERLAP frames, so the window advances by
# CHUNK_FRAMES_WINDOW - CHUNK_FRAMES_OVERLAP frames at a time.
CHUNK_FRAMES_WINDOW = 5
CHUNK_FRAMES_OVERLAP = 1

# Text between a speaker's name and their speech in a caption, e.g. "Alice: Hello".
SPEAKER_SPEECH_TEXT_SEPARATOR = ": "

# Finds "Start time: <timestamp>" in a header comment and captures the timestamp:
# a YYYY-MM-DD date, optionally followed by " HH:MM:SS" and then optionally by
# "Z" or a "+HH:MM"/"-HH:MM" offset.
ISO_DATE_TIME_PATTERN = compile_re(r"Start time: (\d{4}-\d{2}-\d{2}(?: \d{2}:\d{2}:\d{2}(?:Z|[+-]\d{2}:\d{2})?)?)")
class WebVTTFrame(BaseModel):
    """Represents a WebVTT frame: one caption cue with its timing and optional speaker.

    Instances are immutable (``frozen=True``).
    """

    model_config = ConfigDict(frozen=True)

    identifier: str                # cue identifier, or str(index) when the cue has none
    start: timedelta               # cue start, converted from the caption's start timestamp
    end: timedelta                 # cue end, converted from the caption's end timestamp
    speaker: Optional[str] = None  # caption text before the first SPEAKER_SPEECH_TEXT_SEPARATOR
    speech: str                    # caption text after that separator, or the whole text

    @classmethod
    def from_webvtt_caption(cls, caption: Caption, index: int) -> Self:
        """Build a frame from a ``webvtt`` caption.

        Args:
            caption: the parsed caption cue.
            index: fallback used (as text) when the caption has no identifier.

        Returns:
            A frame whose ``speaker`` is the text before the first
            ``SPEAKER_SPEECH_TEXT_SEPARATOR`` and whose ``speech`` is the rest.
            When the separator is absent, ``speaker`` is left unset (``None``)
            and ``speech`` is the whole caption text.
        """
        common = {
            "identifier": caption.identifier or str(index),
            # The caption timestamps' instance attributes are passed straight
            # through as timedelta keyword arguments.
            "start": timedelta(**caption.start_time.__dict__),
            "end": timedelta(**caption.end_time.__dict__),
        }
        before, separator, after = caption.text.partition(SPEAKER_SPEECH_TEXT_SEPARATOR)
        if not separator:
            # No separator: do not pass `speaker` at all, so it stays at its default.
            return cls(**common, speech=caption.text)
        return cls(**common, speaker=before, speech=after)
class WebVTTContent(Content):
    """Represents parsed WebVTT content.

    Holds the ordered caption frames of one WebVTT document, plus an optional
    absolute ``start_time`` read from a ``Start time: ...`` header comment.
    ``get_chunks`` groups the frames into overlapping windows of
    ``CHUNK_FRAMES_WINDOW`` frames, with consecutive windows sharing
    ``CHUNK_FRAMES_OVERLAP`` frames.
    """

    id: str                                                    # returned by get_id(); used as each chunk's parent_id
    metadata: Mapping[str, Any] = Field(default_factory=dict)  # exposed read-only via get_metadata()
    start_time: Optional[datetime]                             # absolute time that frame offsets are added to, if known
    frames: Sequence[WebVTTFrame]                              # caption frames in document order

    def get_id(self: Self) -> str:
        """Return this content's identifier."""
        return self.id

    def get_chunks(self: Self) -> Sequence[Chunk]:
        """Split the frames into overlapping chunks.

        Each chunk's text joins its frames with blank lines. A frame is rendered
        as ``speaker`` + ``SPEAKER_SPEECH_TEXT_SEPARATOR`` + ``speech``, or as
        only the non-empty part when one of them is missing. Chunk ids have the
        form ``"<first frame id>-<last frame id>"``. The chunk metadata holds:
        - absolute ``start``/``end`` datetimes (``start_time`` plus the first
          frame's start and the last frame's end), or ``None`` when
          ``start_time`` is unknown;
        - the ``speakers`` of the chunk's frames.

        Returns:
            A tuple of chunks; empty when there are no frames.
        """
        step = CHUNK_FRAMES_WINDOW - CHUNK_FRAMES_OVERLAP
        chunks = []
        for window in windowed(self.frames, CHUNK_FRAMES_WINDOW, step=step):
            # windowed() pads a short window with None; drop the padding.
            frames = tuple(frame for frame in window if frame is not None)
            if not frames:
                # An all-padding window has no frames to index below; skip it
                # rather than raise IndexError.
                continue
            chunks.append(Chunk(
                text="\n\n".join(
                    SPEAKER_SPEECH_TEXT_SEPARATOR.join(filter(None, (frame.speaker, frame.speech)))
                    for frame in frames),
                parent_id=self.get_id(),
                chunk_id=f"{frames[0].identifier}-{frames[-1].identifier}",
                metadata={
                    "start": self.start_time + frames[0].start if self.start_time else None,
                    "end": self.start_time + frames[-1].end if self.start_time else None,
                    "speakers": tuple(frame.speaker for frame in frames if frame.speaker),
                }))
        return tuple(chunks)

    def get_metadata(self: Self) -> Mapping[str, Any]:
        """Return a read-only view of this content's metadata."""
        return MappingProxyType(self.metadata)

    @classmethod
    def __get_start_time(cls, web_vtt: WebVTT) -> Optional[datetime]:
        """Return the start time declared in the document's header comments.

        Only the first header comment that matches ``ISO_DATE_TIME_PATTERN`` is
        considered, and only its first match is parsed.

        Returns:
            The parsed datetime, or ``None`` when no header comment matches or
            when ``datetime.fromisoformat`` rejects the matched text.
        """
        candidates = (found[0]
                      for found in map(ISO_DATE_TIME_PATTERN.findall, web_vtt.header_comments)
                      if found)
        # Use a default so a document without a "Start time:" comment yields None
        # instead of letting StopIteration escape from from_bytes().
        first = next(candidates, None)
        if first is None:
            return None
        try:
            return datetime.fromisoformat(first)
        except ValueError:
            return None

    @classmethod
    def from_bytes(cls, id: str, metadata: Mapping[str, Any], buffer: bytes) -> Self:
        """Parse a WebVTT document from raw bytes.

        Args:
            id: identifier for the new content.
            metadata: caller metadata; wrapped in a read-only ``MappingProxyType``
                before being passed to the model.
            buffer: the WebVTT document's bytes.

        Returns:
            An instance of ``cls``. Each caption becomes a frame; a caption without
            an identifier is identified by its 1-based position in the document.
        """
        web_vtt = WebVTT.from_buffer(BytesIO(buffer))
        frames = tuple(WebVTTFrame.from_webvtt_caption(caption, index)
                       for index, caption in enumerate(web_vtt.captions, 1))
        # Build via cls rather than the hard-coded class name, so a subclass gets
        # an instance of itself back, as the Self return type promises.
        return cls(id=id,
                   metadata=MappingProxyType(metadata),
                   start_time=cls.__get_start_time(web_vtt),
                   frames=frames)