File size: 2,231 Bytes
40e38d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import json
import re
import heapq
from collections import defaultdict
import tempfile
from typing import Dict, Tuple, List, Literal
import gradio as gr
from datatrove.utils.stats import MetricStatsDict

# Ranking modes accepted by the `direction` parameter of
# prepare_for_group_plotting; the strings are compared verbatim there.
PARTITION_OPTIONS = Literal["Top", "Bottom", "Most frequent (n_docs)"]

def prepare_for_non_grouped_plotting(metric: Dict[str, MetricStatsDict], normalization: bool, rounding: int) -> Dict[float, float]:
    """Bucket a metric by rounded key and optionally normalize the buckets.

    Args:
        metric: Mapping from a numeric key (given as a string) to a stats
            entry; only the entry's ``total`` attribute is read.
        normalization: When true, divide every bucket by the sum of all
            buckets so that the returned values add up to 1.
        rounding: Number of decimal places used to round each key; keys that
            collapse to the same rounded value have their totals summed.

    Returns:
        Mapping from rounded key to the accumulated (optionally normalized)
        total. If normalization is requested but the buckets sum to zero
        (for example when ``metric`` is empty), the un-normalized buckets
        are returned instead of raising.
    """
    buckets = defaultdict(int)
    for key, value in metric.items():
        buckets[round(float(key), rounding)] += value.total

    if not normalization:
        return buckets

    normalizer = sum(buckets.values())
    if normalizer == 0:
        # Nothing to scale by: dividing would raise ZeroDivisionError, and an
        # empty input can never sum to 1, so hand back the raw buckets.
        return buckets
    return {bucket: total / normalizer for bucket, total in buckets.items()}

def prepare_for_group_plotting(metric: Dict[str, MetricStatsDict], top_k: int, direction: PARTITION_OPTIONS, regex: str | None, rounding: int) -> Tuple[List[str], List[float], List[float]]:
    """Select up to ``top_k`` groups of a metric and return their plot data.

    Args:
        metric: Mapping from group name to a stats entry; the entry's
            ``mean``, ``n`` and ``standard_deviation`` attributes are read.
        top_k: Maximum number of groups to return.
        direction: ``"Top"`` keeps the groups with the largest mean,
            ``"Most frequent (n_docs)"`` keeps the groups with the largest
            ``n``, and any other value keeps the groups with the smallest mean.
        regex: Optional pattern; when given, only groups whose name matches
            it from the start of the name (``re.match``) are considered.
        rounding: Number of decimal places applied to the returned means.

    Returns:
        A ``(keys, means, stds)`` tuple of parallel lists in ranked order.
        Means are rounded; standard deviations are returned as stored.
    """
    pattern = re.compile(regex) if regex else None
    selected = {
        name: stats
        for name, stats in metric.items()
        if pattern is None or pattern.match(name)
    }

    # Rank on the raw means: ranking on rounded values creates artificial ties
    # (e.g. 0.6 and 1.4 both become 1.0 with rounding=0) and can pick the
    # wrong groups. Rounding is applied only to the values handed back.
    raw_means = {name: float(stats.mean) for name, stats in selected.items()}

    if direction == "Top":
        keys = heapq.nlargest(top_k, raw_means, key=raw_means.get)
    elif direction == "Most frequent (n_docs)":
        doc_counts = {name: int(stats.n) for name, stats in selected.items()}
        keys = heapq.nlargest(top_k, doc_counts, key=doc_counts.get)
    else:
        keys = heapq.nsmallest(top_k, raw_means, key=raw_means.get)

    means = [round(raw_means[name], rounding) for name in keys]
    stds = [selected[name].standard_deviation for name in keys]
    return keys, means, stds

def export_data(exported_data: Dict[str, MetricStatsDict], metric_name: str):
    """Write the given metric data to a temporary JSON file for download.

    Args:
        exported_data: Mapping from a name to an object exposing ``to_dict()``;
            each ``to_dict()`` item is expected to map a key to a dict of
            fields, which is flattened into ``{"value": key, **fields}``.
        metric_name: Used as the filename prefix of the temporary file.

    Returns:
        ``None`` when ``exported_data`` is empty; otherwise the result of
        ``gr.update(visible=True, value=<path of the written file>)``.
    """
    if not exported_data:
        return None

    payload = {
        name: sorted(
            [{"value": key, **fields} for key, fields in stats.to_dict().items()],
            key=lambda row: row["value"],
        )
        for name, stats in exported_data.items()
    }
    # Serialize before creating the file: it is opened with delete=False, so a
    # failure while building or encoding the payload would otherwise leave an
    # orphaned (possibly half-written) temporary file on disk.
    serialized = json.dumps(payload, indent=2)

    with tempfile.NamedTemporaryFile(mode="w", delete=False, prefix=metric_name, suffix=".json") as temp:
        temp.write(serialized)
        temp_path = temp.name
    return gr.update(visible=True, value=temp_path)