# DocSummarizer_Jimmy / summarize.py
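"""Language-aware document summarization.

Detects the language of the input text, loads a matching seq2seq model from
MODEL_MAP, splits long inputs into overlapping token batches, summarizes each
batch, and joins the per-batch summaries into one result.
"""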
import logging
import pprint as pp
from langdetect import detect
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from tqdm.auto import tqdm
from utils import validate_pytorch2
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
# Registry of available models, keyed by detected language code
MODEL_MAP = {
    "zh": "falconsai/text_summarization",  # model used for Chinese input
    "en": "facebook/bart-large-cnn",  # model used for English input
}
# Cache of already-loaded (model, tokenizer) pairs, keyed by model name
LOADED_MODELS = {}
def load_model_and_tokenizer(model_name: str):
    """Load a seq2seq model and tokenizer, compiling the model on PyTorch 2."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logging.info(f"Loaded model: {model_name} to {device}")
    if validate_pytorch2():
        try:
            logging.info("Compiling model with Torch 2.0")
            model = torch.compile(model)
        except Exception as e:
            logging.warning(f"Could not compile model: {e}")
    return model, tokenizer
def get_model_by_language(text: str):
    """Pick and cache a (model, tokenizer) pair based on the detected language."""
    lang = detect(text)
    # langdetect reports Chinese as "zh-cn" / "zh-tw", so match on the prefix
    if lang.startswith("zh"):
        lang = "zh"
    model_name = MODEL_MAP.get(lang, MODEL_MAP["en"])
    if model_name not in LOADED_MODELS:
        LOADED_MODELS[model_name] = load_model_and_tokenizer(model_name)
    return LOADED_MODELS[model_name]
def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
    """Summarize a single token batch and return (summary, sequence score)."""
    # Add a batch dimension and move the tensors to the model's device
    input_ids = ids[None, :].to(model.device)
    attention_mask = mask[None, :].to(model.device)
    summary_pred_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        output_scores=True,
        return_dict_in_generate=True,
        **kwargs,
    )
    summary = tokenizer.batch_decode(
        summary_pred_ids.sequences,
        skip_special_tokens=True,
    )
    # Sequence-level scores exist only for beam search; fall back to None otherwise
    seq_scores = getattr(summary_pred_ids, "sequences_scores", None)
    score = round(float(seq_scores[0]), 4) if seq_scores is not None else None
    return summary, score
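# Note: `sequences_scores` is only populated when generation uses beam search
# (num_beams > 1), as summarize_text below requests; with greedy decoding the
# getattr guard above yields a score of None.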
def summarize_via_tokenbatches(
    input_text: str,
    model,
    tokenizer,
    batch_length=1024,
    batch_stride=20,
    min_batch_length=512,
    **kwargs,
):
    """Split input_text into overlapping token batches and summarize each one."""
    logger = logging.getLogger(__name__)
    if batch_length < min_batch_length:
        batch_length = min_batch_length
    logger.info(f"input parameters:\n{pp.pformat(kwargs)}")
    logger.info(f"batch_length: {batch_length}, batch_stride: {batch_stride}")
    encoded_input = tokenizer(
        input_text,
        padding="max_length",
        truncation=True,
        max_length=batch_length,
        stride=batch_stride,
        return_overflowing_tokens=True,
        add_special_tokens=False,
        return_tensors="pt",
    )
    in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask
    gen_summaries = []
    for _id, _mask in tqdm(zip(in_id_arr, att_arr), total=len(in_id_arr)):
        result, score = summarize_and_score(
            ids=_id,
            mask=_mask,
            model=model,
            tokenizer=tokenizer,
            **kwargs,
        )
        _sum = {
            "input_tokens": _id,
            "summary": result,
            "summary_score": score,
        }
        gen_summaries.append(_sum)
    return gen_summaries
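# Chunking sketch (assuming a fast tokenizer): with batch_length=1024 and
# batch_stride=20, consecutive chunks overlap by about 20 tokens, roughly
# tokens [0, 1024), [1004, 2028), [2008, 3032), ... so text cut at a chunk
# boundary still appears whole in one of the neighboring chunks.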
def summarize_text(text: str) -> str:
    """Summarize text end to end: detect language, pick a model, batch, summarize."""
    model, tokenizer = get_model_by_language(text)
    summaries = summarize_via_tokenbatches(
        text,
        model=model,
        tokenizer=tokenizer,
        batch_length=1024,
        batch_stride=20,
        max_new_tokens=200,
        num_beams=4,
    )
    # batch_decode returns a list per batch, so take the first (only) string
    return "\n\n".join([s["summary"][0] for s in summaries])