File size: 3,118 Bytes
f16bb9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eac013f
 
f16bb9f
 
 
 
 
 
eac013f
 
f16bb9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
File: plots.py
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov
Description: Plotting functions.
License: MIT License
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2


def plot_audio(time_axis, waveform, frame_indices, fps, figsize=(10, 4)) -> plt.Figure:
    """Plot ``waveform[0]`` against ``time_axis`` with x ticks at selected frames.

    Args:
        time_axis: X coordinates of the plotted samples; ``ax.plot`` requires it
            to have the same length as ``waveform[0]``.
        waveform: Indexable whose first element is plotted. Presumably a
            (channels, samples) array -- TODO confirm against callers.
        frame_indices: Zero-based frame indices at which x ticks are placed.
        fps: Frames per second, used to convert frame indices to tick positions
            on ``time_axis``.
        figsize: Matplotlib figure size in inches.

    Returns:
        The created ``matplotlib.figure.Figure``.
    """
    frame_indices = np.asarray(frame_indices)
    frame_times = frame_indices / fps

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(time_axis, waveform[0])
    ax.set_xlabel("Time (frames)")
    ax.set_ylabel("Amplitude")
    ax.grid(True)

    ax.set_xticks(frame_times)
    # Label each tick with its 1-based frame number taken from the original
    # indices. Recovering the index as int(frame_time * fps) is lossy: e.g.
    # 57 / 100 * 100 == 56.99999999999999, which int() truncates to 56.
    ax.set_xticklabels([f"{int(idx) + 1}" for idx in frame_indices])

    fig.tight_layout()

    return fig


def plot_images(image_paths, figsize=(12, 2)):
    """Show the given images side by side in a single row with axes hidden.

    Args:
        image_paths: Sequence of images, each passed unchanged to
            ``Axes.imshow``. Despite the parameter name, items must be image
            data that ``imshow`` accepts (e.g. arrays), not filesystem paths.
            NOTE(review): confirm against callers.
        figsize: Figure size in inches; defaults to the previously hard-coded
            ``(12, 2)``.

    Returns:
        The created ``matplotlib.figure.Figure``.
    """
    # squeeze=False keeps ``axes`` a 2-D array even for a single image. With the
    # default squeezing, one column yields a bare Axes, which zip() cannot iterate.
    fig, axes = plt.subplots(1, len(image_paths), figsize=figsize, squeeze=False)

    for ax, image in zip(axes[0], image_paths):
        ax.imshow(image)
        ax.axis("off")

    fig.tight_layout()
    return fig


def get_evenly_spaced_frame_indices(total_frames, num_frames=10):
    """Select up to ``num_frames`` frame indices spread evenly over ``range(total_frames)``.

    Args:
        total_frames: Number of available frames.
        num_frames: Number of indices requested.

    Returns:
        - An empty list when ``num_frames`` is not positive.
        - ``list(range(total_frames))`` when there are no more frames than
          requested.
        - Otherwise ``num_frames`` ints ``round(i * total_frames / num_frames)``
          for ``i`` in ``range(num_frames)``, rounded with ``np.round``
          (round-half-to-even).
    """
    # Requesting zero (or a negative number of) frames yields nothing. Without
    # this guard, num_frames == 0 with total_frames > 0 divided by zero below.
    if num_frames <= 0:
        return []

    if total_frames <= num_frames:
        return list(range(total_frames))

    step = total_frames / num_frames
    return [int(np.round(i * step)) for i in range(num_frames)]


def plot_predictions(
    df: pd.DataFrame,
    column: str,
    title: str,
    y_labels: list[str],
    figsize: tuple[int, int],
    x_ticks: list[int],
    line_width: float = 2.0,
) -> "plt.Figure":
    """Plot one DataFrame column as a dotted step-like line with labelled y ticks.

    Values are shifted up by one before plotting. The y axis gets
    ``len(y_labels) + 2`` ticks: tick 0 and the last tick are blank padding,
    and ticks ``1..len(y_labels)`` carry ``y_labels``. A value ``v`` in
    ``column`` is therefore drawn at the row labelled ``y_labels[v]``
    (presumably the column holds 0-based indices into ``y_labels`` -- confirm
    with callers).

    Args:
        df: Source data; ``df.index`` is used for the x axis. Not modified.
        column: Name of the column to plot.
        title: Used as both the plot title and the y-axis label.
        y_labels: Labels for y ticks ``1..len(y_labels)``.
        figsize: Matplotlib figure size in inches.
        x_ticks: Explicit x tick positions.
        line_width: Width of the plotted line.

    Returns:
        The created ``matplotlib.figure.Figure``.
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Shift on a new Series. The previous in-place ``df[column] += 1`` mutated
    # the caller's DataFrame, so repeated calls kept shifting the data upward.
    shifted = df[column] + 1

    ax.plot(df.index, shifted, linestyle="dotted", linewidth=line_width)
    ax.set_title(title)
    ax.set_xlabel("Frames")
    ax.set_ylabel(title)

    ax.set_xticks(x_ticks)
    # One blank tick below and above the labelled rows gives the line headroom.
    ax.set_yticks(range(len(y_labels) + 2))
    ax.set_yticklabels([" "] + y_labels + [" "])

    ax.grid(True)
    fig.tight_layout()
    return fig


def display_frame_info(img, text, margin=1.0, box_scale=1.0, scale=1.5):
    """Return a copy of ``img`` with ``text`` drawn on a translucent white box at the top right.

    Args:
        img: Image array unpacked as ``(height, width, channels)``. It is
            copied, not modified.
        text: String to draw.
        margin: Gap between the box and the top/right image edges, as a
            multiple of the text height.
        box_scale: Padding factor. The box is ``int(2 * text_height * box_scale)``
            pixels larger than the text in each dimension, and the text is
            centred inside it.
        scale: Divisor applied to the stroke thickness to obtain the font scale.

    Returns:
        A new image with the box and text rendered on it.
    """
    img_copy = img.copy()
    img_h, img_w, _ = img_copy.shape
    # Stroke thickness grows with image size (0.1% of the shorter side, / 3),
    # never less than 1 px.
    line_width = int(min(img_h, img_w) * 0.001)
    thickness = max(int(line_width / 3), 1)

    font_face = cv2.FONT_HERSHEY_SIMPLEX
    font_color = (0, 0, 0)
    font_scale = thickness / scale

    # Measure with the same thickness putText uses below. Measuring with None
    # (treated as 0) sized the box for thinner text than what is drawn.
    t_w, t_h = cv2.getTextSize(text, font_face, font_scale, thickness)[0]

    margin_n = int(t_h * margin)
    pad = int(2 * t_h * box_scale)

    # Box bounds in pixels. The left edge is clamped so that text wider than the
    # image does not produce a negative index, which would wrap to the right side.
    y0 = margin_n
    y1 = margin_n + t_h + pad
    x0 = max(img_w - t_w - margin_n - pad, 0)
    x1 = img_w - margin_n

    sub_img = img_copy[y0:y1, x0:x1]
    white_rect = np.ones(sub_img.shape, dtype=np.uint8) * 255

    # 50/50 blend with white to lighten the background behind the text
    # (the gamma of 1.0 adds a one-level offset to every pixel).
    img_copy[y0:y1, x0:x1] = cv2.addWeighted(sub_img, 0.5, white_rect, 0.5, 1.0)

    cv2.putText(
        img=img_copy,
        text=text,
        # Bottom-left corner of the text, inset by half the padding on each
        # axis so the text sits centred inside the box.
        org=(
            img_w - t_w - margin_n - pad // 2,
            margin_n + t_h + pad // 2,
        ),
        fontFace=font_face,
        fontScale=font_scale,
        color=font_color,
        thickness=thickness,
        lineType=cv2.LINE_AA,
        bottomLeftOrigin=False,
    )

    return img_copy