import logging
import pprint as pp

import torch
from langdetect import detect
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from utils import validate_pytorch2

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")

# Available models, keyed by language code
MODEL_MAP = {
    "zh": "falconsai/text_summarization",  # Chinese model
    "en": "facebook/bart-large-cnn",  # English model
}

# Cache of already-loaded (model, tokenizer) pairs, keyed by model name
LOADED_MODELS = {}


def load_model_and_tokenizer(model_name: str):
    """Load a seq2seq model and tokenizer, compiling the model when PyTorch 2.x is available."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logging.info(f"Loaded model {model_name} on {device}")
    if validate_pytorch2():
        try:
            logging.info("Compiling model with torch.compile (PyTorch 2.x)")
            model = torch.compile(model)
        except Exception as e:
            logging.warning(f"Could not compile model: {e}")
    return model, tokenizer


def get_model_by_language(text: str):
    """Pick a model based on the detected language of ``text``, loading it on first use."""
    # langdetect returns codes like "zh-cn" / "zh-tw"; keep only the primary
    # subtag so it can match the keys in MODEL_MAP.
    lang = detect(text).split("-")[0]
    model_name = MODEL_MAP.get(lang, MODEL_MAP["en"])
    if model_name not in LOADED_MODELS:
        LOADED_MODELS[model_name] = load_model_and_tokenizer(model_name)
    return LOADED_MODELS[model_name]


def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
    """Summarize a single token batch and return (decoded summaries, beam score)."""
    # Add a batch dimension and move the tensors to the model's device.
    ids = ids[None, :]
    mask = mask[None, :]
    input_ids = ids.to("cuda") if torch.cuda.is_available() else ids
    attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask
    summary_pred_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        output_scores=True,
        return_dict_in_generate=True,
        remove_invalid_values=True,  # generation flag: replace nan/inf logits
        **kwargs,
    )
    summary = tokenizer.batch_decode(
        summary_pred_ids.sequences,
        skip_special_tokens=True,
    )
    # sequences_scores is only populated by beam search (num_beams > 1).
    score = round(float(summary_pred_ids.sequences_scores.cpu().numpy()[0]), 4)
    return summary, score


def summarize_via_tokenbatches(
    input_text: str,
    model,
    tokenizer,
    batch_length=1024,
    batch_stride=20,
    min_batch_length=512,
    **kwargs,
):
    """Split ``input_text`` into overlapping token batches and summarize each one."""
    logger = logging.getLogger(__name__)
    if batch_length < min_batch_length:
        batch_length = min_batch_length
    logger.info(f"input parameters:\n{pp.pformat(kwargs)}")
    logger.info(f"batch_length: {batch_length}, batch_stride: {batch_stride}")

    # return_overflowing_tokens splits long inputs into windows of
    # batch_length tokens that overlap by batch_stride tokens.
    encoded_input = tokenizer(
        input_text,
        padding="max_length",
        truncation=True,
        max_length=batch_length,
        stride=batch_stride,
        return_overflowing_tokens=True,
        add_special_tokens=False,
        return_tensors="pt",
    )
    in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask

    gen_summaries = []
    pbar = tqdm(total=len(in_id_arr))
    for _id, _mask in zip(in_id_arr, att_arr):
        result, score = summarize_and_score(
            ids=_id,
            mask=_mask,
            model=model,
            tokenizer=tokenizer,
            **kwargs,
        )
        gen_summaries.append(
            {
                "input_tokens": _id,
                "summary": result,
                "summary_score": score,
            }
        )
        pbar.update()
    pbar.close()
    return gen_summaries


def summarize_text(text: str) -> str:
    """Summarize ``text`` end to end: detect language, pick a model, batch, summarize, join."""
    model, tokenizer = get_model_by_language(text)
    summaries = summarize_via_tokenbatches(
        text,
        model=model,
        tokenizer=tokenizer,
        batch_length=1024,
        batch_stride=20,
        max_new_tokens=200,
        num_beams=4,  # beam search is required for sequences_scores above
    )
    return "\n\n".join(s["summary"][0] for s in summaries)
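

# Minimal usage sketch, shown as a __main__ guard so importing this module has
# no side effects. Assumptions not in the original: the module is run directly
# as a script, and the sample text below is purely illustrative.
if __name__ == "__main__":
    sample_text = (
        "The James Webb Space Telescope is a space telescope designed "
        "primarily to conduct infrared astronomy. Its high resolution and "
        "sensitivity allow it to view objects too old, distant, or faint "
        "for the Hubble Space Telescope."
    )
    # Detects the language, loads (and caches) the matching model, then
    # prints the joined per-batch summaries.
    print(summarize_text(sample_text))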