import logging
import pprint as pp

import torch
from langdetect import detect
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from utils import validate_pytorch2

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
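
# `validate_pytorch2` lives in a local `utils` module that is not shown here.
# Judging by how it gates `torch.compile` below, it presumably just checks for
# torch >= 2.0. A minimal sketch of that idea (an assumption, not the actual
# implementation from `utils`):
def _validate_pytorch2_sketch() -> bool:
    """Illustrative stand-in: True if the installed torch is 2.x or newer."""
    major = torch.__version__.split(".")[0]  # e.g. "2.1.0+cu118" -> "2"
    return int(major) >= 2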

# Available models, keyed by detected language
MODEL_MAP = {
    "zh": "falconsai/text_summarization",  # model for Chinese input
    "en": "facebook/bart-large-cnn",  # model for English input
}
# Cache of already-loaded (model, tokenizer) pairs, keyed by model name
LOADED_MODELS = {}

def load_model_and_tokenizer(model_name: str):
    """Load a seq2seq model and its tokenizer, moving the model to GPU if available."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logging.info(f"Loaded model: {model_name} to {device}")
    if validate_pytorch2():
        try:
            logging.info("Compiling model with Torch 2.0")
            model = torch.compile(model)
        except Exception as e:
            logging.warning(f"Could not compile model: {e}")
    return model, tokenizer

def get_model_by_language(text: str):
    """Pick (and cache) a model/tokenizer pair based on the detected language."""
    # langdetect returns region-tagged codes for Chinese ("zh-cn", "zh-tw"),
    # so normalize to the bare language code before the lookup.
    lang = detect(text).split("-")[0]
    model_name = MODEL_MAP.get(lang, MODEL_MAP["en"])
    if model_name not in LOADED_MODELS:
        LOADED_MODELS[model_name] = load_model_and_tokenizer(model_name)
    return LOADED_MODELS[model_name]
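
# Example of the routing above (illustrative): detect("今天天气很好") typically
# returns "zh-cn", which normalizes to "zh" and selects MODEL_MAP["zh"]; any
# language without a dedicated entry (e.g. "fr") falls back to the English model.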

def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
    """Generate a summary for one token window and return it with its score."""
    # Add a batch dimension and move the inputs to the model's device.
    device = next(model.parameters()).device
    input_ids = ids[None, :].to(device)
    attention_mask = mask[None, :].to(device)
    summary_pred_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        output_scores=True,
        return_dict_in_generate=True,
        **kwargs,
    )
    summary = tokenizer.batch_decode(
        summary_pred_ids.sequences,
        skip_special_tokens=True,
        remove_invalid_values=True,
    )
    # `sequences_scores` is only populated by beam search (num_beams > 1).
    scores = summary_pred_ids.sequences_scores
    score = round(float(scores[0]), 4) if scores is not None else None
    return summary, score
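
# Example call (hypothetical text; assumes beam search so that
# `sequences_scores` is populated and `score` is not None):
#
#     model, tokenizer = load_model_and_tokenizer("facebook/bart-large-cnn")
#     enc = tokenizer("Some long article ...", return_tensors="pt")
#     summary, score = summarize_and_score(
#         enc.input_ids[0], enc.attention_mask[0], model, tokenizer, num_beams=4
#     )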

def summarize_via_tokenbatches(
    input_text: str,
    model,
    tokenizer,
    batch_length=1024,
    batch_stride=20,
    min_batch_length=512,
    **kwargs,
):
    """Split a long text into overlapping token windows and summarize each one."""
    logger = logging.getLogger(__name__)
    if batch_length < min_batch_length:
        logger.warning(
            f"batch_length {batch_length} is below min_batch_length; "
            f"using {min_batch_length}"
        )
        batch_length = min_batch_length
    logger.info(f"input parameters:\n{pp.pformat(kwargs)}")
    logger.info(f"batch_length: {batch_length}, batch_stride: {batch_stride}")
    # Tokenize with overflow so that texts longer than `batch_length` come back
    # as multiple overlapping windows, each sharing `batch_stride` tokens with
    # its neighbor.
    encoded_input = tokenizer(
        input_text,
        padding="max_length",
        truncation=True,
        max_length=batch_length,
        stride=batch_stride,
        return_overflowing_tokens=True,
        add_special_tokens=False,
        return_tensors="pt",
    )
    in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask
    gen_summaries = []

    pbar = tqdm(total=len(in_id_arr), desc="summarizing batches")
    for _id, _mask in zip(in_id_arr, att_arr):
        result, score = summarize_and_score(
            ids=_id,
            mask=_mask,
            model=model,
            tokenizer=tokenizer,
            **kwargs,
        )
        _sum = {
            "input_tokens": _id,
            "summary": result,
            "summary_score": score,
        }
        gen_summaries.append(_sum)
        pbar.update()
    pbar.close()

    return gen_summaries

def summarize_text(text: str) -> str:
    """Summarize `text` end to end: pick a model by language, then batch-summarize."""
    model, tokenizer = get_model_by_language(text)
    summaries = summarize_via_tokenbatches(
        text,
        model=model,
        tokenizer=tokenizer,
        batch_length=1024,
        batch_stride=20,
        max_new_tokens=200,
        num_beams=4,
    )
    # Each entry's "summary" is a one-element list from batch_decode.
    return "\n\n".join(s["summary"][0] for s in summaries)