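"""Language-aware long-text summarization with overlapping token batches.

Detects the input language, routes it to a per-language seq2seq checkpoint,
splits long inputs into overlapping token windows, and summarizes each
window with beam search.
"""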
import logging
import os
import pprint as pp
from langdetect import detect

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from tqdm.auto import tqdm

from utils import validate_pytorch2

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")

# Map detected language codes to summarization checkpoints
MODEL_MAP = {
    "zh": "falconsai/text_summarization",        # checkpoint served for Chinese input
    "en": "facebook/bart-large-cnn",             # checkpoint served for English input
}

LOADED_MODELS = {}  # cache of (model, tokenizer) pairs keyed by checkpoint name

def load_model_and_tokenizer(model_name: str):
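    """Load a seq2seq checkpoint and its tokenizer.

    The model is moved to GPU when one is available, and ``torch.compile``
    is attempted on PyTorch 2.x installs, falling back to the eager model
    if compilation fails.
    """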
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    logging.info(f"Loaded model: {model_name} to {device}")
    if validate_pytorch2():
        try:
            logging.info("Compiling model with Torch 2.0")
            model = torch.compile(model)
        except Exception as e:
            logging.warning(f"Could not compile model: {e}")
    return model, tokenizer

def get_model_by_language(text: str):
    """Detect the language of ``text`` and return a cached (model, tokenizer) pair."""
    # langdetect reports region variants such as "zh-cn" or "zh-tw"; keep only
    # the base code so the lookup can actually match the "zh" key in MODEL_MAP.
    lang = detect(text).split("-")[0]
    model_name = MODEL_MAP.get(lang, MODEL_MAP["en"])
    if model_name not in LOADED_MODELS:
        LOADED_MODELS[model_name] = load_model_and_tokenizer(model_name)
    return LOADED_MODELS[model_name]

def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
    """Summarize one pre-tokenized chunk and return (summary, beam score).

    ``sequences_scores`` is only populated by beam search, so callers must
    pass ``num_beams > 1`` in ``kwargs``.
    """
    ids = ids[None, :]
    mask = mask[None, :]

    # Move the inputs to whichever device the model actually lives on,
    # rather than assuming CUDA.
    device = next(model.parameters()).device
    input_ids = ids.to(device)
    attention_mask = mask.to(device)

    summary_pred_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        output_scores=True,
        return_dict_in_generate=True,
        **kwargs,
    )

    summary = tokenizer.batch_decode(
        summary_pred_ids.sequences,
        skip_special_tokens=True,
    )
    score = round(summary_pred_ids.sequences_scores.cpu().numpy()[0], 4)
    return summary, score

def summarize_via_tokenbatches(
    input_text: str,
    model,
    tokenizer,
    batch_length=1024,
    batch_stride=20,
    min_batch_length=512,
    **kwargs,
):
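    """Split ``input_text`` into overlapping token windows and summarize each.

    Returns a list of dicts holding each window's input tokens, its decoded
    summary, and its beam-search score.
    """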
    logger = logging.getLogger(__name__)
    if batch_length < min_batch_length:
        batch_length = min_batch_length

    logger.info(f"generation parameters:\n{pp.pformat(kwargs)}")
    logger.info(f"batch_length: {batch_length}, batch_stride: {batch_stride}")

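    # return_overflowing_tokens with a stride makes the tokenizer emit
    # overlapping batch_length-token windows covering the whole input,
    # instead of truncating everything past a single window.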
    encoded_input = tokenizer(
        input_text,
        padding="max_length",
        truncation=True,
        max_length=batch_length,
        stride=batch_stride,
        return_overflowing_tokens=True,
        add_special_tokens=False,
        return_tensors="pt",
    )

    in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask
    gen_summaries = []
    pbar = tqdm(total=len(in_id_arr))

    for _id, _mask in zip(in_id_arr, att_arr):
        result, score = summarize_and_score(
            ids=_id,
            mask=_mask,
            model=model,
            tokenizer=tokenizer,
            **kwargs,
        )
        _sum = {
            "input_tokens": _id,
            "summary": result,
            "summary_score": score,
        }
        gen_summaries.append(_sum)
        pbar.update()

    pbar.close()
    return gen_summaries

def summarize_text(text: str) -> str:
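    """End-to-end entry point: detect language, chunk, summarize, and join."""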
    model, tokenizer = get_model_by_language(text)
    summaries = summarize_via_tokenbatches(
        text,
        model=model,
        tokenizer=tokenizer,
        batch_length=1024,
        batch_stride=20,
        max_new_tokens=200,
        num_beams=4,
    )
    return "\n\n".join([s["summary"][0] for s in summaries])
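
if __name__ == "__main__":
    # Minimal smoke test of the pipeline. The sample text below is an
    # illustrative placeholder, not project data; expect a checkpoint
    # download from the Hugging Face Hub on first run.
    sample = (
        "The James Webb Space Telescope is the largest optical telescope in "
        "space. Its high resolution and sensitivity allow it to view objects "
        "too old, distant, or faint for the Hubble Space Telescope."
    )
    print(summarize_text(sample))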