File size: 3,727 Bytes
a60b3fc
 
5c7c7e5
a60b3fc
67436c8
 
af0a2bd
a60b3fc
 
 
 
 
 
af0a2bd
67436c8
 
 
af0a2bd
 
a60b3fc
 
 
 
 
 
 
 
 
 
 
 
 
af0a2bd
5c7c7e5
a60b3fc
 
67436c8
 
a60b3fc
67436c8
a60b3fc
 
 
67436c8
 
a60b3fc
67436c8
 
af0a2bd
67436c8
a60b3fc
67436c8
 
a60b3fc
 
67436c8
 
 
 
 
 
 
 
a60b3fc
af0a2bd
 
 
a60b3fc
67436c8
 
a60b3fc
 
67436c8
a60b3fc
67436c8
af0a2bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from datetime import datetime, timedelta
from io import BytesIO
from itertools import starmap
from json import dumps
from more_itertools import windowed
from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr
from re import compile as compile_re
from types import MappingProxyType
from typing import Any, Dict, Literal, Mapping, Optional, Self, Sequence
from webvtt import Caption, WebVTT

from ctp_slack_bot.models.base import Chunk, Content


# Number of frames shared between two consecutive chunk windows; subtracted
# from CHUNK_FRAMES_WINDOW to obtain the window step in WebVTTContent.get_chunks.
CHUNK_FRAMES_OVERLAP = 1
# Maximum number of frames grouped into a single chunk window.
CHUNK_FRAMES_WINDOW = 5
# Caption text is split once on this separator; when it is present, the part
# before it is taken as the speaker and the remainder as the speech.
SPEAKER_SPEECH_TEXT_SEPARATOR = ": "
# Matched against WebVTT header comments. Captures the value following
# "Start time: ": a YYYY-MM-DD date, optionally followed by " HH:MM:SS" and
# then optionally by "Z" or a "+HH:MM" / "-HH:MM" offset.
ISO_DATE_TIME_PATTERN = compile_re(r"Start time: (\d{4}-\d{2}-\d{2}(?: \d{2}:\d{2}:\d{2}(?:Z|[+-]\d{2}:\d{2})?)?)")


class WebVTTFrame(BaseModel):
    """Immutable record of one caption taken from a WebVTT file."""

    # Caption identifier, or the caption's index as a string when it has none.
    identifier: str
    # Start and end of the caption, built from the caption's timestamps.
    start: timedelta
    end: timedelta
    # Text before the first speaker separator; left unset when there is none.
    speaker: Optional[str] = None
    # Caption text with the speaker prefix (if any) removed.
    speech: str

    model_config = ConfigDict(frozen=True)

    @classmethod
    def from_webvtt_caption(cls, caption: Caption, index: int) -> Self:
        """Build a frame from a parsed caption.

        Args:
            caption: The caption to convert.
            index: Fallback identifier, used (as a string) when the caption's
                own identifier is empty or missing.

        Returns:
            A frame whose ``speaker`` is populated only when the caption text
            contains ``SPEAKER_SPEECH_TEXT_SEPARATOR``.
        """
        fields: Dict[str, Any] = {
            "identifier": caption.identifier or str(index),
            # The timestamp's instance attributes are forwarded verbatim as
            # timedelta keyword arguments.
            "start": timedelta(**vars(caption.start_time)),
            "end": timedelta(**vars(caption.end_time)),
        }
        head, separator, tail = caption.text.partition(SPEAKER_SPEECH_TEXT_SEPARATOR)
        if separator:
            fields["speaker"] = head
            fields["speech"] = tail
        else:
            # No separator: the whole text is speech and speaker keeps its default.
            fields["speech"] = head
        return cls(**fields)


class WebVTTContent(Content):
    """Parsed WebVTT content that can be split into overlapping chunks."""

    # Identifier of this content; used as the parent id of every chunk.
    id: str
    # Arbitrary metadata attached to the content.
    metadata: Mapping[str, Any] = Field(default_factory=dict)
    # Start time parsed from the header comments, or None when unavailable.
    start_time: Optional[datetime]
    # Caption frames in file order.
    frames: Sequence[WebVTTFrame]

    def get_id(self: Self) -> str:
        """Return the identifier of this content."""
        return self.id

    def get_chunks(self: Self) -> Sequence[Chunk]:
        """Group the frames into overlapping windows and return one chunk each.

        Windows hold up to ``CHUNK_FRAMES_WINDOW`` frames and advance by
        ``CHUNK_FRAMES_WINDOW - CHUNK_FRAMES_OVERLAP`` frames.

        Returns:
            A tuple of chunks. Each chunk's text joins its frames with blank
            lines, prefixing the speech with the speaker when one is known.
            The ``start``/``end`` metadata are absolute datetimes when
            ``start_time`` is known and ``None`` otherwise.
        """
        step = CHUNK_FRAMES_WINDOW - CHUNK_FRAMES_OVERLAP
        # windowed() pads a short trailing window with None; drop the padding.
        windows = (tuple(frame for frame in window if frame is not None)
                   for window
                   in windowed(self.frames, CHUNK_FRAMES_WINDOW, step=step))
        has_start_time = self.start_time is not None
        return tuple(Chunk(text="\n\n".join(SPEAKER_SPEECH_TEXT_SEPARATOR.join(filter(None, (frame.speaker, frame.speech)))
                                            for frame
                                            in frames),
                           parent_id=self.get_id(),
                           chunk_id=f"{frames[0].identifier}-{frames[-1].identifier}",
                           metadata={
                               "start": self.start_time + frames[0].start if has_start_time else None,
                               "end": self.start_time + frames[-1].end if has_start_time else None,
                               "speakers": tuple(frame.speaker for frame in frames if frame.speaker)
                           })
                     for frames
                     in windows)

    def get_metadata(self: Self) -> Mapping[str, Any]:
        """Return a read-only view of the metadata."""
        return MappingProxyType(self.metadata)

    @classmethod
    def __get_start_time(cls, web_vtt: WebVTT) -> Optional[datetime]:
        """Extract the start time from the header comments.

        Returns:
            The datetime parsed from the first header comment matching
            ``ISO_DATE_TIME_PATTERN``, or ``None`` when no comment matches or
            the matched text is not a valid ISO date/time.
        """
        for comment in web_vtt.header_comments:
            found = ISO_DATE_TIME_PATTERN.search(comment)
            if found is None:
                continue
            try:
                return datetime.fromisoformat(found.group(1))
            except ValueError:
                # The first match decides: a malformed value means "unknown".
                return None
        # No header comment carried a start time. Falling through here avoids
        # the StopIteration that a bare next() on an empty generator raises.
        return None

    @classmethod
    def from_bytes(cls, id: str, metadata: Mapping[str, Any], buffer: bytes) -> Self:
        """Parse raw WebVTT bytes into a content instance.

        Args:
            id: Identifier to assign to the content.
            metadata: Metadata to attach; stored behind a read-only proxy.
            buffer: Raw bytes of the WebVTT file.

        Returns:
            An instance of ``cls`` holding one frame per caption, with frames
            lacking an identifier numbered from 1.
        """
        web_vtt = WebVTT.from_buffer(BytesIO(buffer))
        frames = tuple(WebVTTFrame.from_webvtt_caption(caption, index)
                       for index, caption
                       in enumerate(web_vtt.captions, 1))
        # Construct via cls so subclasses get instances of their own type.
        return cls(id=id, metadata=MappingProxyType(metadata), start_time=cls.__get_start_time(web_vtt), frames=frames)