LiKenun's picture
MIME type handler proof-of-concept
d9e81c2
raw
history blame
2.75 kB
from datetime import datetime, timedelta
from io import BytesIO
from itertools import starmap
from json import dumps
from more_itertools import windowed
from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr
from types import MappingProxyType
from typing import Any, Dict, Literal, Mapping, Optional, Self, Sequence
from webvtt import Caption, WebVTT
from ctp_slack_bot.models.base import Chunk, Content
# Chunking parameters used by WebVTTContent.get_chunks: each chunk is built
# from a window of CHUNK_FRAMES_WINDOW frames, and the window advances by
# (CHUNK_FRAMES_WINDOW - CHUNK_FRAMES_OVERLAP) frames, so consecutive full
# windows share CHUNK_FRAMES_OVERLAP frame(s).
CHUNK_FRAMES_OVERLAP = 1
CHUNK_FRAMES_WINDOW = 5
# Caption text is split at the first occurrence of this separator; the part
# before it is taken as the speaker and the rest as the speech.
SPEAKER_SPEECH_TEXT_SEPARATOR = ": "
class WebVTTFrame(BaseModel):
    """One WebVTT caption converted into an immutable record."""

    # Frozen: field values cannot be reassigned after construction.
    model_config = ConfigDict(frozen=True)

    # The caption's own identifier, or a stringified index when it has none.
    identifier: str
    # Durations built from the caption's start and end timestamps.
    start: timedelta
    end: timedelta
    # Text preceding the first speaker separator; left unset when the
    # caption text contains no separator.
    speaker: Optional[str] = None
    # Caption text with the speaker prefix (if any) removed.
    speech: str

    @classmethod
    def from_webvtt_caption(cls, caption: Caption, index: int) -> Self:
        """Build a frame from a parsed caption.

        Args:
            caption: The caption to convert.
            index: Fallback whose string form becomes the identifier when
                the caption's own identifier is empty or missing.

        Returns:
            A frame whose ``speaker`` is populated only when the caption
            text contains ``SPEAKER_SPEECH_TEXT_SEPARATOR``.
        """
        common = {
            "identifier": caption.identifier or str(index),
            # NOTE(review): assumes every instance attribute of the caption's
            # timestamp objects is a valid ``timedelta`` keyword (presumably
            # hours/minutes/seconds/milliseconds) -- confirm against the
            # webvtt version in use.
            "start": timedelta(**vars(caption.start_time)),
            "end": timedelta(**vars(caption.end_time)),
        }

        # Only the first separator counts; later ones stay in the speech.
        head, separator, tail = caption.text.partition(SPEAKER_SPEECH_TEXT_SEPARATOR)
        if separator:
            return cls(**common, speaker=head, speech=tail)

        # No separator found: ``head`` is the whole caption text.
        return cls(**common, speech=head)
class WebVTTContent(Content):
    """Parsed WebVTT content that can be split into overlapping chunks."""

    # Identifier returned by get_id() and used as each chunk's parent_id.
    id: str
    # Arbitrary metadata; exposed read-only through get_metadata().
    metadata: Mapping[str, Any] = Field(default_factory=dict)
    # When set, frame offsets are added to it to produce the chunk
    # "start"/"end" metadata; when None those values are None.
    start_time: Optional[datetime]
    # Frames in the order they should be chunked.
    frames: Sequence[WebVTTFrame]

    def get_id(self: Self) -> str:
        """Return this content's identifier."""
        return self.id

    def get_chunks(self: Self) -> Sequence[Chunk]:
        """Split the frames into overlapping chunks.

        Frames are grouped into windows of ``CHUNK_FRAMES_WINDOW`` frames,
        advancing by ``CHUNK_FRAMES_WINDOW - CHUNK_FRAMES_OVERLAP`` frames
        each time. Each window becomes one chunk whose text is the frames
        rendered as "speaker: speech" (or just the speech when there is no
        speaker), separated by blank lines.

        Returns:
            A tuple of chunks, one per non-empty window; empty when there
            are no frames.
        """
        step = CHUNK_FRAMES_WINDOW - CHUNK_FRAMES_OVERLAP
        has_start_time = self.start_time is not None
        chunks = []
        for window in windowed(self.frames, CHUNK_FRAMES_WINDOW, step=step):
            # ``windowed`` pads a short window with None; drop the padding.
            frames = tuple(filter(None, window))
            if not frames:
                # Depending on the more_itertools version, an empty frame
                # sequence may still yield one all-padding window. Skip it
                # instead of indexing into an empty tuple below.
                continue

            first, last = frames[0], frames[-1]
            # Use the same separator the frames were parsed with; missing or
            # empty parts are omitted so no dangling separator is produced.
            text = "\n\n".join(
                SPEAKER_SPEECH_TEXT_SEPARATOR.join(filter(None, (frame.speaker, frame.speech)))
                for frame in frames
            )
            chunks.append(
                Chunk(
                    text=text,
                    parent_id=self.get_id(),
                    chunk_id=f"{first.identifier}-{last.identifier}",
                    metadata={
                        "start": self.start_time + first.start if has_start_time else None,
                        "end": self.start_time + last.end if has_start_time else None,
                        # One entry per frame that has a speaker, in frame
                        # order (repeats are kept).
                        "speakers": tuple(frame.speaker for frame in frames if frame.speaker),
                    },
                )
            )
        return tuple(chunks)

    def get_metadata(self: Self) -> Mapping[str, Any]:
        """Return a read-only view of the metadata mapping."""
        return MappingProxyType(self.metadata)