Spaces:
Runtime error
Runtime error
File size: 3,727 Bytes
a60b3fc 5c7c7e5 a60b3fc 67436c8 af0a2bd a60b3fc af0a2bd 67436c8 af0a2bd a60b3fc af0a2bd 5c7c7e5 a60b3fc 67436c8 a60b3fc 67436c8 a60b3fc 67436c8 a60b3fc 67436c8 af0a2bd 67436c8 a60b3fc 67436c8 a60b3fc 67436c8 a60b3fc af0a2bd a60b3fc 67436c8 a60b3fc 67436c8 a60b3fc 67436c8 af0a2bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
from datetime import datetime, timedelta
from io import BytesIO
from itertools import starmap
from json import dumps
from more_itertools import windowed
from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr
from re import compile as compile_re
from types import MappingProxyType
from typing import Any, Dict, Literal, Mapping, Optional, Self, Sequence
from webvtt import Caption, WebVTT
from ctp_slack_bot.models.base import Chunk, Content
# Number of frames that consecutive chunk windows share; the window step is
# CHUNK_FRAMES_WINDOW - CHUNK_FRAMES_OVERLAP.
CHUNK_FRAMES_OVERLAP = 1
# Number of frames grouped into one chunk window.
CHUNK_FRAMES_WINDOW = 5
# Separator used to split a caption's text into "speaker" and "speech" parts.
SPEAKER_SPEECH_TEXT_SEPARATOR = ": "
# Matches "Start time: YYYY-MM-DD", optionally followed by " HH:MM:SS" and an
# optional "Z" or "+HH:MM"/"-HH:MM" offset. Group 1 captures the date/time text.
ISO_DATE_TIME_PATTERN = compile_re(r"Start time: (\d{4}-\d{2}-\d{2}(?: \d{2}:\d{2}:\d{2}(?:Z|[+-]\d{2}:\d{2})?)?)")
class WebVTTFrame(BaseModel):
    """A single caption of a WebVTT file, stored as an immutable model."""

    # Caption identifier, or the caption's index as text when it has none.
    identifier: str
    # Offsets built from the caption's start and end timestamps.
    start: timedelta
    end: timedelta
    # Text before the first ": " in the caption, when the caption contains one.
    speaker: Optional[str] = None
    # Caption text with the speaker prefix (if any) removed.
    speech: str

    model_config = ConfigDict(frozen=True)

    @classmethod
    def from_webvtt_caption(cls, caption: Caption, index: int) -> Self:
        """Build a frame from a parsed caption.

        Args:
            caption: The caption to convert.
            index: Fallback used (as a string) when the caption has no
                identifier.

        Returns:
            A frame whose ``speaker`` is set only when the caption text
            contains ``SPEAKER_SPEECH_TEXT_SEPARATOR``; the text is split at
            the first occurrence.
        """
        frame_fields = {
            "identifier": caption.identifier or str(index),
            # The timestamp's instance attributes are passed straight through
            # as timedelta keyword arguments.
            # NOTE(review): assumes those attributes are all valid timedelta
            # keywords (e.g. hours/minutes/seconds/milliseconds) - confirm
            # against the installed webvtt-py version.
            "start": timedelta(**caption.start_time.__dict__),
            "end": timedelta(**caption.end_time.__dict__),
        }
        head, separator, tail = caption.text.partition(SPEAKER_SPEECH_TEXT_SEPARATOR)
        if separator:
            frame_fields["speaker"] = head
            frame_fields["speech"] = tail
        else:
            # No separator: the whole text is speech and speaker keeps its default.
            frame_fields["speech"] = head
        return cls(**frame_fields)
class WebVTTContent(Content):
    """Represents parsed WebVTT content.

    Holds the frames parsed from a WebVTT file together with caller-supplied
    metadata and, when the file's header comments contain a "Start time: ..."
    line, the absolute start time of the recording.
    """

    id: str
    metadata: Mapping[str, Any] = Field(default_factory=dict)
    # Absolute start of the recording; None when it could not be determined.
    start_time: Optional[datetime]
    frames: Sequence[WebVTTFrame]

    def get_id(self: Self) -> str:
        """Return this content's identifier."""
        return self.id

    def get_chunks(self: Self) -> Sequence[Chunk]:
        """Group the frames into overlapping chunks.

        Frames are taken in windows of ``CHUNK_FRAMES_WINDOW`` that advance by
        ``CHUNK_FRAMES_WINDOW - CHUNK_FRAMES_OVERLAP`` frames.

        Returns:
            A tuple with one chunk per window. Each chunk's text is the
            window's frames joined by blank lines, each rendered as
            "speaker: speech" (or just the speech when there is no speaker).
            Its metadata holds the absolute "start" and "end" (None when
            ``start_time`` is unknown) and the "speakers" of the window in
            frame order.
        """
        step = CHUNK_FRAMES_WINDOW - CHUNK_FRAMES_OVERLAP
        chunks = []
        for window in windowed(self.frames, CHUNK_FRAMES_WINDOW, step=step):
            # Drop falsy entries, such as the None fill values that windowed()
            # pads a short window with.
            frames = tuple(filter(None, window))
            first, last = frames[0], frames[-1]
            text = "\n\n".join(
                SPEAKER_SPEECH_TEXT_SEPARATOR.join(filter(None, (frame.speaker, frame.speech)))
                for frame in frames
            )
            chunks.append(Chunk(
                text=text,
                parent_id=self.get_id(),
                chunk_id=f"{first.identifier}-{last.identifier}",
                metadata={
                    "start": self.start_time + first.start if self.start_time else None,
                    "end": self.start_time + last.end if self.start_time else None,
                    "speakers": tuple(frame.speaker for frame in frames if frame.speaker),
                },
            ))
        return tuple(chunks)

    def get_metadata(self: Self) -> Mapping[str, Any]:
        """Return a read-only view of the metadata."""
        return MappingProxyType(self.metadata)

    @classmethod
    def __get_start_time(cls, web_vtt: WebVTT) -> Optional[datetime]:
        """Extract the recording start time from the header comments.

        Uses the first header comment that matches ``ISO_DATE_TIME_PATTERN``.

        Returns:
            The parsed datetime, or None when no header comment matches or
            the matched text is not a valid ISO date/time.
        """
        matches = map(ISO_DATE_TIME_PATTERN.findall, web_vtt.header_comments)
        # The default keeps next() from raising StopIteration when no header
        # comment carries a start time.
        start_text = next((found[0] for found in matches if found), None)
        if start_text is None:
            return None
        try:
            return datetime.fromisoformat(start_text)
        except ValueError:
            return None

    @classmethod
    def from_bytes(cls, id: str, metadata: Mapping[str, Any], buffer: bytes) -> Self:
        """Parse raw WebVTT bytes into a content object.

        Args:
            id: Identifier to assign to the content.
            metadata: Metadata to attach; stored behind a read-only proxy.
            buffer: The raw bytes of the WebVTT file.

        Returns:
            An instance of ``cls`` with one frame per caption (captions are
            numbered from 1 for the fallback identifier) and the start time
            found in the header comments, if any.
        """
        web_vtt = WebVTT.from_buffer(BytesIO(buffer))
        frames = tuple(
            WebVTTFrame.from_webvtt_caption(caption, index)
            for index, caption in enumerate(web_vtt.captions, 1)
        )
        # Build through cls so that subclasses get an instance of their own type.
        return cls(
            id=id,
            metadata=MappingProxyType(metadata),
            start_time=cls.__get_start_time(web_vtt),
            frames=frames,
        )
|