Spaces:
Runtime error
Runtime error
import json | |
import math | |
import os | |
from typing import List, Optional | |
from transformers.trainer import TRAINER_STATE_NAME | |
from .logging import get_logger | |
from .packages import is_matplotlib_available | |
# matplotlib is optional: ``plt`` is only bound at module level when
# is_matplotlib_available() reports it installed, so code that uses ``plt``
# fails if it is missing.
if is_matplotlib_available():
    import matplotlib.pyplot as plt

# Module-level logger obtained from the package's own logging helper.
logger = get_logger(__name__)
def smooth(scalars: List[float]) -> List[float]:
    r"""
    Smooth a series with an exponential moving average (EMA), TensorBoard-style.

    The smoothing weight depends on the series length through a scaled sigmoid,
    ``weight = 1.8 * (sigmoid(0.05 * len(scalars)) - 0.5)``. It is close to 0 for
    very short series and approaches 0.9 for long ones, so longer curves are
    smoothed more heavily.

    Args:
        scalars: Raw values to smooth, in step order.

    Returns:
        A new list with the same length as ``scalars``. The EMA is seeded with the
        first value. An empty input yields an empty list.
    """
    if not scalars:
        # Guard: seeding from scalars[0] below would raise IndexError on an empty list.
        return []

    last = scalars[0]
    smoothed = []
    weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5)  # scaled sigmoid of the length
    for next_val in scalars:
        smoothed_val = last * weight + (1 - weight) * next_val
        smoothed.append(smoothed_val)
        last = smoothed_val
    return smoothed
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = None) -> None:
    r"""
    Plot raw and EMA-smoothed curves of logged metrics and save them as PNG files.

    Reads the trainer state file (``TRAINER_STATE_NAME``) from ``save_dictionary``
    and collects, for each key, the ``step`` and value of every ``log_history``
    entry that contains that key. The resulting figure is written to
    ``training_{key}.png`` in the same directory. Keys with no logged values are
    skipped with a warning.

    Args:
        save_dictionary: Directory holding the trainer state file. Figures are
            saved here too.
        keys: Names of logged metrics to plot. Defaults to ``["loss"]`` when
            omitted or ``None``.

    Raises:
        ImportError: If matplotlib is not installed.
    """
    if not is_matplotlib_available():
        # ``plt`` is only imported at module level when matplotlib is present;
        # fail with a clear message instead of a NameError further down.
        raise ImportError("matplotlib is required to plot training curves; install it with `pip install matplotlib`.")

    if keys is None:
        # A None default avoids a mutable list shared across calls.
        keys = ["loss"]

    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
        data = json.load(f)

    for key in keys:
        steps, metrics = [], []
        for entry in data["log_history"]:
            if key in entry:
                steps.append(entry["step"])
                metrics.append(entry[key])

        if not metrics:
            logger.warning(f"No metric {key} to plot.")
            continue

        figure_path = os.path.join(save_dictionary, "training_{}.png".format(key))
        fig = plt.figure()
        try:
            plt.plot(steps, metrics, alpha=0.4, label="original")
            plt.plot(steps, smooth(metrics), label="smoothed")
            plt.title("training {} of {}".format(key, save_dictionary))
            plt.xlabel("step")
            plt.ylabel(key)
            plt.legend()
            plt.savefig(figure_path, format="png", dpi=100)
        finally:
            # pyplot keeps every figure alive until closed. Without this, each
            # call leaks one figure per plotted key.
            plt.close(fig)
        print("Figure saved:", figure_path)