File size: 2,806 Bytes
a60b3fc
 
67436c8
bb5dde5
a60b3fc
6853a4c
a60b3fc
 
bb5dde5
a1a6d79
a60b3fc
af0a2bd
a60b3fc
 
 
bb7c9a3
 
6853a4c
 
a60b3fc
 
 
 
 
 
 
af0a2bd
5c7c7e5
a60b3fc
 
6853a4c
67436c8
a60b3fc
67436c8
a60b3fc
 
 
67436c8
 
a60b3fc
6853a4c
 
 
67436c8
bb7c9a3
af0a2bd
bb5dde5
a60b3fc
67436c8
 
a60b3fc
bb5dde5
67436c8
 
6853a4c
67436c8
 
 
 
 
a60b3fc
af0a2bd
 
bb5dde5
a60b3fc
67436c8
 
a60b3fc
 
67436c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from datetime import datetime, timedelta
from io import BytesIO
from more_itertools import windowed
from pydantic import BaseModel, ConfigDict, Field, field_validator
from types import MappingProxyType
from typing import Any, ClassVar, Literal, Mapping, Optional, Self
from webvtt import Caption, WebVTT

from ctp_slack_bot.utils import to_deep_immutable
from .base import Chunk, Content


class WebVTTFrame(BaseModel):
    """A single immutable WebVTT caption: timing, optional speaker, and speech."""

    model_config = ConfigDict(frozen=True)

    # Text that separates the speaker name from the speech in a caption's text.
    _SPEAKER_SPEECH_TEXT_SEPARATOR: ClassVar[str] = ": "

    identifier: str
    start: timedelta
    end: timedelta
    speaker: Optional[str] = None
    speech: str

    @classmethod
    def from_webvtt_caption(cls, caption: Caption, index: int) -> Self:
        """Build a frame from a parsed caption.

        Args:
            caption: The caption to convert.
            index: Used (as a string) for the identifier when the caption
                has no identifier of its own.

        Returns:
            A frame whose ``speaker`` is the text before the first
            speaker/speech separator; when the caption text contains no
            separator, ``speaker`` is left unset and the whole text becomes
            ``speech``.
        """
        fields = {
            "identifier": caption.identifier or str(index),
            # The timestamp objects' instance attributes are passed straight
            # through as timedelta keyword arguments.
            "start": timedelta(**caption.start_time.__dict__),
            "end": timedelta(**caption.end_time.__dict__),
        }
        before, separator, after = caption.text.partition(cls._SPEAKER_SPEECH_TEXT_SEPARATOR)
        if separator:
            fields["speaker"] = before
            fields["speech"] = after
        else:
            # No separator found: partition() returns the full text first.
            fields["speech"] = before
        return cls(**fields)


class WebVTTContent(Content):
    """Parsed WebVTT content that can be split into overlapping chunks of frames.

    Chunks are built from sliding windows of up to ``CHUNK_FRAMES_WINDOW``
    frames; consecutive windows share ``CHUNK_FRAMES_OVERLAP`` frames.
    """

    # Number of frames shared by two consecutive chunk windows.
    CHUNK_FRAMES_OVERLAP: ClassVar[int] = 1
    # Maximum number of frames in one chunk window.
    CHUNK_FRAMES_WINDOW: ClassVar[int] = 5

    id: str
    metadata: Mapping[str, Any] = Field(default_factory=lambda: MappingProxyType({}))
    # When set, chunk "start"/"end" metadata is this value plus the frame offset.
    start_time: Optional[datetime]
    frames: tuple[WebVTTFrame, ...]

    def get_id(self: Self) -> str:
        """Return the identifier of this content."""
        return self.id

    def get_chunks(self: Self) -> tuple[Chunk, ...]:
        """Split the frames into overlapping chunks.

        Returns:
            One chunk per window of frames, in order. Each chunk's text joins
            its frames with blank lines, prefixing the speech with
            ``"<speaker>: "`` when a speaker is present. Chunk metadata holds:

            * ``"start"`` / ``"end"``: ``start_time`` plus the first frame's
              start / last frame's end, or ``None`` when ``start_time`` is
              not set.
            * ``"speakers"``: a tuple with the speaker of every frame that
              has one, in frame order (repeats are kept).
        """
        step = self.CHUNK_FRAMES_WINDOW - self.CHUNK_FRAMES_OVERLAP
        chunks = []
        for window in windowed(self.frames, self.CHUNK_FRAMES_WINDOW, step=step):
            # windowed() pads a short final window with None; drop the padding.
            frames = tuple(frame for frame in window if frame is not None)
            if not frames:
                # Defensive: a window without frames cannot form a chunk, and
                # indexing frames[0] below would raise.
                continue
            first, last = frames[0], frames[-1]
            has_start_time = self.start_time is not None
            chunks.append(Chunk(
                text="\n\n".join(
                    ": ".join(filter(None, (frame.speaker, frame.speech)))
                    for frame in frames
                ),
                parent_id=self.get_id(),
                chunk_id=f"{first.identifier}-{last.identifier}",
                metadata={
                    "start": self.start_time + first.start if has_start_time else None,
                    "end": self.start_time + last.end if has_start_time else None,
                    # Materialized: a generator here could only be read once
                    # and would be empty on any later access.
                    "speakers": tuple(frame.speaker for frame in frames if frame.speaker),
                },
            ))
        return tuple(chunks)

    def get_metadata(self: Self) -> Mapping[str, Any]:
        """Return a read-only view of the content metadata."""
        return MappingProxyType(self.metadata)