File size: 2,404 Bytes
e1df50c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7ef7d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1df50c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7ef7d3
e1df50c
 
 
 
 
 
 
 
f7ef7d3
e1df50c
 
 
f7ef7d3
e1df50c
 
 
f7ef7d3
 
 
e1df50c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from typing import List, Dict, Optional, Any

from sqlalchemy import and_
from sqlalchemy.orm import Session

from data.models import TTSData, Annotation, AnnotationInterval
from data.repository.annotator_repo import AnnotatorRepo
from utils.logger import Logger

# Module-level logger used by the repository methods in this module.
log = Logger()


class AnnotatorWorkloadRepo:
    """Queries over the TTS items covered by an annotator's intervals.

    An item is included when its ``TTSData.id`` lies inside (bounds inclusive)
    at least one ``AnnotationInterval`` row belonging to the annotator.
    """

    def __init__(self, db: Session) -> None:
        """Bind the repository to a SQLAlchemy session.

        Args:
            db: Session used for every query issued here; the same session is
                handed to the ``AnnotatorRepo`` used for name lookups.
        """
        self.db = db
        self.annotator_repo = AnnotatorRepo(db)

    def get_tts_data_with_annotations(
        self, annotator_name: str
    ) -> List[Dict[str, Optional[Any]]]:
        """Fetch an annotator's TTS items and annotations, looked up by name.

        Kept for compatibility with existing callers; new code should prefer
        ``get_tts_data_with_annotations_for_user_id``.

        Args:
            annotator_name: Name resolved via
                ``AnnotatorRepo.get_annotator_by_name``.

        Returns:
            The same list shape as
            ``get_tts_data_with_annotations_for_user_id``, or an empty list
            (after logging a warning) when no annotator matches the name.
        """
        record = self.annotator_repo.get_annotator_by_name(annotator_name)
        if record is not None:
            return self.get_tts_data_with_annotations_for_user_id(
                record.id, annotator_name
            )

        log.warning(
            f"Annotator '{annotator_name}' not found in get_tts_data_with_annotations. Returning empty list."
        )
        return []

    def get_tts_data_with_annotations_for_user_id(
        self, annotator_id: int, annotator_name_for_log: str = "Unknown"
    ) -> List[Dict[str, Optional[Any]]]:
        """Return every TTS item in the annotator's intervals with its annotation.

        Args:
            annotator_id: Annotator key; it filters both the intervals and the
                annotations that are joined in.
            annotator_name_for_log: Display name used only in the log line.

        Returns:
            A list ordered by ``TTSData.id``::

                [
                    {"tts_data": <TTSData>, "annotation": <Annotation or None>},
                    ...
                ]

            ``"annotation"`` is ``None`` when the outer join finds no
            annotation by this annotator for the item.
        """
        # Inclusive range test tying an item to one of this annotator's
        # intervals. Overlapping intervals would produce the same item more
        # than once; distinct(TTSData.id) below collapses those duplicates.
        in_assigned_interval = and_(
            AnnotationInterval.annotator_id == annotator_id,
            TTSData.id >= AnnotationInterval.start_index,
            TTSData.id <= AnnotationInterval.end_index,
        )
        # Attach only this annotator's annotation (if any) to each item.
        own_annotation = and_(
            Annotation.tts_data_id == TTSData.id,
            Annotation.annotator_id == annotator_id,
        )

        # NOTE(review): distinct() with a column expression renders PostgreSQL
        # DISTINCT ON; other dialects do not support it — confirm the target
        # database. If several annotations can exist for one (item, annotator)
        # pair, which one is kept is not determined by this ORDER BY.
        query = (
            self.db.query(TTSData, Annotation)
            .join(AnnotationInterval, in_assigned_interval)
            .outerjoin(Annotation, own_annotation)
            .order_by(TTSData.id)
            .distinct(TTSData.id)
        )

        field_names = ("tts_data", "annotation")
        entries = [dict(zip(field_names, row)) for row in query.all()]

        log.info(
            f"{len(entries)} TTS rows fetched for annotator ID '{annotator_id}' (Name: {annotator_name_for_log})."
        )
        return entries