# DocSummarizer_Jimmy / summarize.py
import logging
import pprint as pp

import torch
from langdetect import detect
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from utils import validate_pytorch2

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")

# Available models, keyed by detected language code
MODEL_MAP = {
    "zh": "falconsai/text_summarization",  # Chinese model
    "en": "facebook/bart-large-cnn",  # English model
}
LOADED_MODELS = {}  # cache of (model, tokenizer) pairs, keyed by model name


def load_model_and_tokenizer(model_name: str):
    """Load a seq2seq model and its tokenizer, moving the model to GPU if available."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logging.info(f"Loaded model: {model_name} to {device}")
    if validate_pytorch2():
        try:
            logging.info("Compiling model with Torch 2.0")
            model = torch.compile(model)
        except Exception as e:
            logging.warning(f"Could not compile model: {e}")
    return model, tokenizer


def get_model_by_language(text: str):
    """Return a cached (model, tokenizer) pair for the detected language of `text`."""
    lang = detect(text)
    # langdetect reports Chinese as "zh-cn" / "zh-tw"; normalize so the lookup hits
    if lang.startswith("zh"):
        lang = "zh"
    model_name = MODEL_MAP.get(lang, MODEL_MAP["en"])  # fall back to English
    if model_name not in LOADED_MODELS:
        LOADED_MODELS[model_name] = load_model_and_tokenizer(model_name)
    return LOADED_MODELS[model_name]
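
# Illustrative calls (a sketch, assuming the langdetect profiles are installed):
#   get_model_by_language("This is an English sentence.")  # -> bart-large-cnn
#   get_model_by_language("這是一段中文文字。")              # -> falconsai/text_summarization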


def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
    """Generate a summary for one token batch and return it with its beam score."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # add a batch dimension and move the tensors to the model's device
    input_ids = ids[None, :].to(device)
    attention_mask = mask[None, :].to(device)
    summary_pred_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        output_scores=True,
        return_dict_in_generate=True,
        **kwargs,
    )
    summary = tokenizer.batch_decode(
        summary_pred_ids.sequences,
        skip_special_tokens=True,
    )
    # sequences_scores is only populated by beam search (num_beams > 1)
    score = round(summary_pred_ids.sequences_scores.cpu().numpy()[0], 4)
    return summary, score


def summarize_via_tokenbatches(
    input_text: str,
    model,
    tokenizer,
    batch_length=1024,
    batch_stride=20,
    min_batch_length=512,
    **kwargs,
):
    """Split `input_text` into overlapping token batches and summarize each one."""
    logger = logging.getLogger(__name__)
    if batch_length < min_batch_length:
        batch_length = min_batch_length
    logger.info(f"input parameters:\n{pp.pformat(kwargs)}")
    logger.info(f"batch_length: {batch_length}, batch_stride: {batch_stride}")
    # tokenize with overflow so a long input becomes a sequence of
    # batch_length-token chunks, each overlapping its neighbor by batch_stride
    encoded_input = tokenizer(
        input_text,
        padding="max_length",
        truncation=True,
        max_length=batch_length,
        stride=batch_stride,
        return_overflowing_tokens=True,
        add_special_tokens=False,
        return_tensors="pt",
    )
    in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask
    gen_summaries = []
    pbar = tqdm(total=len(in_id_arr))
    for _id, _mask in zip(in_id_arr, att_arr):
        result, score = summarize_and_score(
            ids=_id,
            mask=_mask,
            model=model,
            tokenizer=tokenizer,
            **kwargs,
        )
        _sum = {
            "input_tokens": _id,
            "summary": result,
            "summary_score": score,
        }
        gen_summaries.append(_sum)
        pbar.update()
    pbar.close()
    return gen_summaries
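
# A minimal sketch of the chunking behavior above, assuming the
# facebook/bart-large-cnn tokenizer can be downloaded:
#   tok = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
#   enc = tok(
#       "word " * 5000,
#       padding="max_length",
#       truncation=True,
#       max_length=1024,
#       stride=20,
#       return_overflowing_tokens=True,
#       add_special_tokens=False,
#       return_tensors="pt",
#   )
#   enc.input_ids.shape  # (n_chunks, 1024); consecutive chunks share 20 tokens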


def summarize_text(text: str) -> str:
    """End-to-end entry point: pick a model by language, then batch-summarize."""
    model, tokenizer = get_model_by_language(text)
    summaries = summarize_via_tokenbatches(
        text,
        model=model,
        tokenizer=tokenizer,
        batch_length=1024,
        batch_stride=20,
        max_new_tokens=200,
        num_beams=4,
    )
    # each batch_decode result is a one-element list; take its first string
    return "\n\n".join([s["summary"][0] for s in summaries])
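

# A minimal usage sketch, assuming `utils.validate_pytorch2` is importable
# and the Hugging Face models can be fetched; the sample text is illustrative.
if __name__ == "__main__":
    sample = (
        "The James Webb Space Telescope is a space telescope designed to "
        "conduct infrared astronomy. Its high resolution and sensitivity "
        "allow it to view objects too old, distant, or faint for earlier "
        "instruments."
    )
    print(summarize_text(sample))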