File size: 1,818 Bytes
dcc5cd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import pandas as pd
import argparse
import perplexity
from tqdm import tqdm
from joblib import Parallel, delayed
import os


def calculate_doc_perplexity(row_id, doc, doctype, lang):
    """Compute the perplexities of one document and package them as a row.

    Args:
        row_id: Index label of the document's row; accepted but not used here.
        doc: Mapping-like record providing an ``"id"`` entry and a
            ``"paragraphs"`` entry; the latter is iterated and each item is
            read through its ``"text"`` entry.
        doctype: Document-type label, copied into the result unchanged.
        lang: Value forwarded to ``perplexity.source_perplexities`` as its
            second argument and copied into the result.

    Returns:
        A four-item list ``[id, doctype, lang, perplexities]``, where
        ``perplexities`` is whatever ``perplexity.source_perplexities``
        returned for the document text.
    """
    doc_id = doc["id"]
    texts = [paragraph["text"] for paragraph in doc["paragraphs"]]
    # All paragraphs are scored together as one newline-separated text.
    scores = perplexity.source_perplexities(
        "\n".join(texts), lang, include_harmful=False
    )
    return [doc_id, doctype, lang, scores]


def main(args):
    """Compute perplexities for every document of a JSON Lines file.

    The document type and language are taken from the input file's base name:
    everything before the last underscore is the document type, and the part
    after it, up to the first dot, is the language. The results are written to
    ``<output_path>/<doctype>_<lang>.jsonl`` with one record per document.

    Args:
        args: Parsed command-line namespace providing ``file`` (input path),
            ``output_path`` (output directory) and ``jobs`` (value passed to
            joblib's ``Parallel`` as ``n_jobs``).
    """
    file = args.file
    # rpartition splits on the LAST underscore, so a multi-word document type
    # such as "foo_bar" in "foo_bar_en.jsonl" is kept intact.
    doctype, _, lang_part = os.path.basename(file).rpartition('_')
    lang = lang_part.split('.')[0]

    rows = []
    # With chunksize set, read_json returns a reader that keeps the input file
    # open; the with-block closes it even if processing a chunk raises.
    with pd.read_json(file, lines=True, chunksize=1000) as chunks:
        for chunk in tqdm(chunks, desc=f"Processing chunks of {args.file}"):
            rows.extend(Parallel(n_jobs=args.jobs)(
                delayed(calculate_doc_perplexity)(row_id, doc, doctype, lang)
                for row_id, doc in chunk.iterrows()
                # Documents with a falsy "paragraphs" value are skipped.
                if doc["paragraphs"]
            ))
    df = pd.DataFrame(rows, columns=["id", "doc_type", "lang", "perplexities"])

    # Ensure the output directory exists
    os.makedirs(args.output_path, exist_ok=True)

    # Save the DataFrame with the ".jsonl" extension
    output_file = os.path.join(args.output_path, f"{doctype}_{lang}.jsonl")
    df.to_json(output_file, lines=True, orient="records")


if __name__ == "__main__":
    # Command-line parsing is done here, so main() only receives the parsed
    # namespace.
    parser = argparse.ArgumentParser(
        description="Process documents to calculate perplexity.",
    )
    parser.add_argument(
        "file",
        type=str,
        help="Input file path",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="tmp/",
        help="Output file path",
    )
    parser.add_argument(
        "--jobs",
        type=int,
        default=10,
        help="Number of jobs to use for parallel processing",
    )
    args = parser.parse_args()

    main(args)