File size: 826 Bytes
bec1ee5
6e6a688
4175aca
bec1ee5
 
 
4175aca
 
 
 
 
bec1ee5
4175aca
 
6e6a688
 
 
 
 
 
 
 
 
 
 
 
4175aca
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, FFMpegWriter

# Apply seaborn's default theme globally so every matplotlib figure created
# by this module (and the rest of the process) uses seaborn styling.
sns.set()

# NOTE(review): this string sits after the imports, so Python does not treat
# it as the module docstring; it only serves as a descriptive note.
"""
Save plots to the plots folder for when
we would like to show results on our little
flask application
"""


def loss_history_plt(loss_history: list) -> FuncAnimation:
    """Build an animation that draws the training-loss curve epoch by epoch.

    Frame ``i`` plots the first ``i + 1`` loss values. The first frame
    therefore shows a single point, and the last frame shows the complete
    history.

    Args:
        loss_history: Training loss values in epoch order; the x-axis uses
            each value's position (starting at 0) as its epoch number.

    Returns:
        A ``FuncAnimation`` on a newly created figure, with one frame per
        entry in ``loss_history`` and 100 ms between frames.
    """
    fig, ax = plt.subplots()

    def animate(frame):
        # Frame indices are 0-based. Including the current index means the
        # final frame draws the last loss value instead of stopping one
        # short, and frame 0 is not an empty plot.
        shown = frame + 1
        # Redraw from scratch each frame rather than appending to the line.
        ax.clear()
        sns.lineplot(
            x=range(shown),
            y=loss_history[:shown],
            ax=ax,
        )
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Training Loss")

    return FuncAnimation(fig, animate, frames=len(loss_history), interval=100)


def save_plt(plot, filename: str, animated: bool, fps=10):
    """Write a plot to ``filename`` as either a static image or a video.

    Args:
        plot: The object to save. When ``animated`` is false it must provide
            ``savefig(filename)``. When ``animated`` is true it must provide
            ``save(filename, writer=...)``; presumably a matplotlib
            animation such as the one returned by ``loss_history_plt``.
        filename: Destination path passed straight through to the saver.
        animated: Selects the animation path (encoded with ffmpeg) instead
            of the static ``savefig`` path.
        fps: Frames per second for the ffmpeg writer; ignored for static
            plots.
    """
    if animated:
        # Encode the animation through matplotlib's ffmpeg-backed writer.
        ffmpeg = FFMpegWriter(fps=fps)
        plot.save(filename, writer=ffmpeg)
    else:
        plot.savefig(filename)