import argparse
import os
from collections import defaultdict
from io import StringIO

import pandas as pd
from tqdm import tqdm

from perplexity import get_model_for
from subsampler import PerplexitySubsampler
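
# Input layout (inferred from the parsing below, not a strict contract): the directory
# is expected to contain JSON Lines files with a nested "perplexities" record per row,
# named like "<doc_type>_..._<lang>.<ext>", e.g. "hplt_shard01_en.jsonl" -> doc_type
# "hplt", lang "en" (only the first two characters of the last part are used).
#
# Example invocation (script name and paths are illustrative):
#   python quantiles.py perplexity_dir --group_by_prefix_lang \
#       --sampling_ratio 0.5 --output_file quantiles.csv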


def process_files(
    directory,
    reject_level,
    model_override,
    output_file,
    group_by_prefix_lang,
    prefix_lang_mapping=None,
    ratio=None,
    ratio_per_lang=None,
    pa=None,
    pb=None,
    include=None,
):
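    """Compute perplexity quantiles for each file (or file group) in `directory`
    and emit them as CSV rows.

    When `ratio` or `ratio_per_lang` is given, subsampling statistics
    (norm, mean, std) are appended to each row. Rows are written to
    `output_file` if provided, otherwise printed to standard output.
    """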
    if ratio or ratio_per_lang:
        rows = ["doc_type,model,language,reject,bad,medium,good,norm,mean,std"]
    else:
        rows = ["doc_type,model,language,reject,bad,medium,good"]
    files = os.listdir(directory)
    grouped_files = defaultdict(list)
    if prefix_lang_mapping is None:
        prefix_lang_mapping = {}

    # Group files by prefix and language if the option is enabled
    description = "Processing files"
    if group_by_prefix_lang:
        description = "Processing files in groups"
        for file in files:
            parts = file.split('_')
            prefix = parts[0]
            if include and prefix not in include:
                continue
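            # Language code: first two characters of the last underscore-separated
            # part of the filename, with the extension stripped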
            lang = parts[-1].split(".")[0][:2]
            group_key = prefix_lang_mapping.get(f"{prefix}_{lang}", f"{prefix}_{lang}")
            grouped_files[group_key].append(file)
        file_groups = grouped_files.values()
    else:
        file_groups = []
        for file in files:  # Each file is its own group
            if include and not any(file.startswith(prefix) for prefix in include):
                continue
            file_groups.append([file])

    if output_file:
        progress = tqdm(file_groups, desc=description)
    else:
        progress = file_groups
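        # Streaming to stdout: print the CSV header up front, rows follow as computed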
        print(rows[0])
    # Process each group of files
    for group in progress:
        combined_perplexities = pd.DataFrame()
        doc_type, lang = None, None

        for file in group:
            if not doc_type or not lang:  # Set doc_type and lang based on the first file
                parts = file.split('_')
                doc_type = parts[0]
                lang = parts[-1].split(".")[0][:2]
                doc_type, lang = prefix_lang_mapping.get(f"{doc_type}_{lang}", f"{doc_type}_{lang}").rsplit("_", 1)
            perp = pd.read_json(os.path.join(directory, file), lines=True)
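            # The "perplexities" column holds nested records; round-trip them through
            # JSON Lines so each field becomes its own DataFrame column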
            perplexities = pd.read_json(StringIO(perp["perplexities"].to_json(lines=True, orient="records")), lines=True)
            combined_perplexities = pd.concat([combined_perplexities, perplexities], ignore_index=True)

        if model_override:
            model = model_override
        else:
            model, _ = get_model_for(doc_type)
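        # Perplexity values are stored in a column named "<model>_pp"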
        model_with_suffix = f"{model}_pp"

        # Calculate quantiles for the combined perplexities of the group
        reject = round(combined_perplexities[model_with_suffix].quantile(q=reject_level), 2)
        bad = round(combined_perplexities[model_with_suffix].quantile(q=0.75), 2)
        medium = round(combined_perplexities[model_with_suffix].quantile(q=0.50), 2)
        good = round(combined_perplexities[model_with_suffix].quantile(q=0.25), 2)

        if ratio or ratio_per_lang:
            # A global ratio takes precedence; otherwise fall back to the per-language
            # ratio, keeping everything (1.0) for languages without an explicit entry
            group_ratio = ratio if ratio else ratio_per_lang.get(lang, 1.0)
            subsampler = PerplexitySubsampler(combined_perplexities[model_with_suffix].values)
            subsampler.set(ratio=group_ratio, pa=pa, pb=pb)
            norm, mean, std = subsampler.norm, subsampler.mean, subsampler.sdev
            sampling_stats = f",{norm},{mean},{std}"
        else:
            sampling_stats = ""

        row = f"{doc_type},{model},{lang},{reject},{bad},{medium},{good}{sampling_stats}"
        if output_file:
            rows.append(row)
        else:
            print(row)


    if output_file:
        with open(output_file, "w") as f:
            for row in rows:
                f.write(f"{row}\n")


def main():
    """"
    Each doc_type prefix needs to have an "no" lang, even of there's no real data.
    These rows are crucial for the rest of the process.
    """
    parser = argparse.ArgumentParser(description="Process files and compute perplexity metrics.")
    parser.add_argument('directory', type=str, help='Directory containing the files to process')
    parser.add_argument('--reject_level', type=float, default=0.95, help='Rejection quantile level (default: 0.95)')
    parser.add_argument('--model_override', type=str, help='Override the model used')
    parser.add_argument('--output_file', type=str, help='Output file in CSV format. If not given, prints to standard output.')
    parser.add_argument('--group_by_prefix_lang', action='store_true', help='Group and calculate quantiles for files with the same prefix and language')
    parser.add_argument('--overwrite_prefix_lang', type=str, help='Remap doc_type prefix/language pairs, e.g., "starcoder_en:starcoder_code,hplt_en:hplt_no"')
    parser.add_argument('--sampling_ratio', type=float, help='Ratio of documents to keep when sampling. If passed, generates the distribution statistics (norm, mean, std) needed for sampling')
    parser.add_argument('--sampling_ratio_per_lang', type=str, help='Ratio of documents per lang, e.g., "en:0.25,sv:0.34"')
    parser.add_argument('--sampling_q1_prob', type=float, default=0.20, help='Probability of keeping documents in the Q1 range')
    parser.add_argument('--sampling_q3_prob', type=float, default=0.05, help='Probability of keeping documents in the Q3 range')
    parser.add_argument('--include', type=str, help='Comma-separated list of doc_type prefixes to include')

    args = parser.parse_args()

    if args.sampling_ratio_per_lang:
        # Turns "en: 0.25, sv : 0.34" into {'en': 0.25, 'sv': 0.34}
        ratio_per_lang = {
            k.strip(): float(v.strip())
            for k, v in (item.split(":") for item in args.sampling_ratio_per_lang.split(","))
        }
    else:
        ratio_per_lang = None
    if args.overwrite_prefix_lang:
        # Turns "starcoder_en:starcoder_code,hplt_en:hplt_no" into {'starcoder_en': 'starcoder_code', 'hplt_en': 'hplt_no'}
        prefix_lang_mapping = {
            k.strip(): v.strip()
            for k, v in (item.split(":") for item in args.overwrite_prefix_lang.split(","))
        }
    else:
        prefix_lang_mapping = {}

    process_files(
        args.directory,
        args.reject_level,
        args.model_override,
        args.output_file,
        group_by_prefix_lang=args.group_by_prefix_lang,
        prefix_lang_mapping=prefix_lang_mapping,
        pa=args.sampling_q1_prob,
        pb=args.sampling_q3_prob,
        ratio=args.sampling_ratio,
        ratio_per_lang=ratio_per_lang,
        include=args.include.split(",") if args.include else None
    )

if __name__ == "__main__":
    main()