File size: 2,829 Bytes
e1df50c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from typing import List, Optional
from sqlalchemy.orm import Session

from data.models import AnnotationInterval
from utils.logger import Logger

# Module-level logger instance; the repository below uses it to record
# successful interval assignments.
log = Logger()


class AnnotationIntervalRepo:
    """
    Repository for rows of the `annotation_intervals` table.

    All queries run through the SQLAlchemy session supplied at construction
    time. None of the methods defined here commit the transaction.
    """

    def __init__(self, db: Session) -> None:
        # The session is borrowed from the caller and used by every method.
        self.db = db

    # ------------------------------------------------------------------ #
    # CREATE
    # ------------------------------------------------------------------ #
    def assign_interval_to_annotator(
        self,
        annotator_id: int,
        start_idx: int,
        end_idx: int,
        allow_overlap: bool = False,
    ) -> AnnotationInterval:
        """
        Persist a new interval [start_idx, end_idx] owned by `annotator_id`.

        Unless `allow_overlap` is true, the annotator's existing intervals
        are checked first. Bounds are compared inclusively, so an interval
        that merely touches an existing endpoint counts as overlapping.

        The new row is added, flushed and refreshed on the session; this
        method does not commit.

        Args:
            annotator_id: Annotator that will own the interval.
            start_idx: First index of the interval.
            end_idx: Last index of the interval; must be greater than
                `start_idx`.
            allow_overlap: Skip the overlap check when true.

        Returns:
            The newly created `AnnotationInterval`, refreshed from the
            session.

        Raises:
            ValueError:  • start_idx >= end_idx
                         • overlap detected (when allow_overlap=False)
        """
        # Guard clause: empty or reversed ranges are rejected up front.
        if end_idx <= start_idx:
            raise ValueError("start_idx must be < end_idx")

        if not allow_overlap:
            # An existing row collides when it belongs to the same annotator,
            # starts no later than the new end, and ends no earlier than the
            # new start.
            same_annotator = AnnotationInterval.annotator_id == annotator_id
            starts_before_new_end = AnnotationInterval.start_index <= end_idx
            ends_after_new_start = AnnotationInterval.end_index >= start_idx

            clash = (
                self.db.query(AnnotationInterval)
                .filter(same_annotator)
                .filter(starts_before_new_end, ends_after_new_start)
                .first()
            )
            if clash:
                message = (
                    f"Overlap with existing interval "
                    f"[{clash.start_index}, {clash.end_index}]"
                )
                raise ValueError(message)

        session = self.db
        new_interval = AnnotationInterval(
            annotator_id=annotator_id,
            start_index=start_idx,
            end_index=end_idx,
        )
        session.add(new_interval)
        # Flush, then refresh, so the returned object reflects the stored row.
        session.flush()
        session.refresh(new_interval)

        log.info(
            f"Interval [{start_idx}, {end_idx}] assigned to annotator_id={annotator_id}"
        )
        return new_interval

    # ------------------------------------------------------------------ #
    # READ
    # ------------------------------------------------------------------ #
    def get_intervals_by_annotator(self, annotator_id: int) -> List[AnnotationInterval]:
        """
        Return every interval whose `annotator_id` matches the argument.

        No ordering is applied to the query.
        """
        owned_by_annotator = AnnotationInterval.annotator_id == annotator_id
        query = self.db.query(AnnotationInterval).filter(owned_by_annotator)
        return query.all()

    def get_interval(
        self, annotator_id: int, start_idx: int, end_idx: int
    ) -> Optional[AnnotationInterval]:
        """
        Look up one interval by annotator and exact start/end bounds.

        Returns the first matching row, or None when nothing matches.
        """
        criteria = (
            AnnotationInterval.annotator_id == annotator_id,
            AnnotationInterval.start_index == start_idx,
            AnnotationInterval.end_index == end_idx,
        )
        return self.db.query(AnnotationInterval).filter(*criteria).first()