from collections import Counter, defaultdict
from dataclasses import make_dataclass, field
from typing import Union
import ctypes
import os
import platform
import re

import numpy as np
import torch
import ujson
from datasketch import MinHash, MinHashLSH
from nltk.translate.bleu_score import sentence_bleu
from transformers import T5Config, TrainingArguments, TrainerCallback
from transformers.trainer_callback import TrainerControl, TrainerState
# from nltk import ngrams

from config import T5ModelConfig

# End-of-sentence punctuation marks
END_PUN = set(".。!!))》}】??\"”")


class MyTrainerCallback(TrainerCallback):
    log_cnt = 0

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        '''
        Clear the CUDA cache every n log steps. Useful on low-VRAM devices to prevent OOM.
        '''
        self.log_cnt += 1
        if self.log_cnt % 2 == 0:
            torch.cuda.empty_cache()

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        '''
        Save the model once at the end of every epoch.
        In TrainingArguments, save_strategy cannot be 'epoch' and 'steps' at the same time.
        To also checkpoint every save_steps steps while respecting disk space, keep at most
        the latest N checkpoints.
        '''
        # Setting should_save=True and returning the control object is enough
        control.should_save = True
        return control


# Keep Chinese characters, ASCII letters, digits and underscores; drop punctuation.
# (The original pattern "[^[\u4E00-\u9FA5|A-Za-z_0-9]" accidentally kept '[' and '|' as well.)
NON_CHAR = re.compile("[^\u4E00-\u9FA5A-Za-z_0-9]")


def _get_doc_mini_hash(doc: list[str] | str, num_perm: int) -> MinHash:
    '''
    Compute the MinHash of a piece of text.
    '''
    mini_hash = MinHash(num_perm=num_perm)
    for s in doc:
        mini_hash.update(s.encode('utf-8'))
    return mini_hash


class DropDatasetDuplicate:

    def __init__(self, threshold: float = 0.85, num_perm: int = 256) -> None:
        '''
        Collect the indexes of all duplicated documents in a dataset (similarity above
        `threshold`). The input is a list[str]; each str element is one document (doc).
        E.g. for the input [a, b, c, d, c, d, e] it returns {4, 5}
        (the indexes of the second occurrences of c and d).
        '''
        self.similar_index_cluster = defaultdict(set)
        self.data_lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
        self.num_perm = num_perm

    def add_doc(self, index: object, doc: str) -> None:
        '''
        Add a document.
        index: index of the document
        doc: the document itself
        '''
        # Keep only Chinese characters, ASCII letters, digits and underscores
        doc = ''.join(NON_CHAR.split(doc))
        # doc = [''.join(t) for t in list(ngrams(doc, 3))]

        doc_hash = _get_doc_mini_hash(doc, self.num_perm)
        close_duplicates = self.data_lsh.query(doc_hash)

        self.data_lsh.insert(index, doc_hash)

        # Every group of similar docs is keyed in similar_index_cluster by the earliest index.
        # E.g. if the docs at indexes 2, 7, 8, 9, 10, 12 are similar, similar_index_cluster
        # ends up as {2: {7, 8, 9, 10, 12}}.
        if len(close_duplicates) > 0:
            min_idx = min(close_duplicates)
            self.similar_index_cluster[min_idx].add(index)

    def get_duplicate_indexs(self):
        '''
        Return the indexes of all duplicated documents.
        '''
        similar_index_cluster = self.similar_index_cluster
        need_to_remove_idx = set()

        for key_idx in similar_index_cluster.keys():
            need_to_remove_idx |= similar_index_cluster[key_idx]

        return need_to_remove_idx
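

# --- Added usage sketch (illustrative, not part of the original module) ------------------
# Minimal example of deduplicating a small in-memory corpus with DropDatasetDuplicate.
# The helper name `_demo_drop_duplicates` and the sample texts are made up for the demo.
def _demo_drop_duplicates() -> list[str]:
    docs = [
        '如何快速学习Python编程',
        '如何快速学习Python编程呢',       # a slight variation of the first doc
        '今天天气很好,适合去公园散步',
    ]
    deduper = DropDatasetDuplicate(threshold=0.85, num_perm=256)
    for idx, doc in enumerate(docs):
        deduper.add_doc(index=idx, doc=doc)

    duplicate_ids = deduper.get_duplicate_indexs()
    # Keep only the first occurrence of each cluster of near-duplicates
    return [doc for idx, doc in enumerate(docs) if idx not in duplicate_ids]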


def get_T5_config(config: T5ModelConfig, vocab_size: int,
                  decoder_start_token_id: int = 0, eos_token_id: int = 1) -> T5Config:
    '''
    Convert the user configuration into a T5Config.
    '''
    t5_config = T5Config()
    # t5_config.model_type = 'TextToTextModel'

    # Initialize from the user config
    t5_config.d_ff = config.d_ff
    t5_config.d_kv = config.d_kv
    t5_config.d_model = config.d_model
    t5_config.num_decoder_layers = config.num_decoder_layers
    t5_config.num_heads = config.num_heads
    t5_config.num_layers = config.num_layers
    t5_config.vocab_size = vocab_size
    t5_config.decoder_start_token_id = decoder_start_token_id
    t5_config.eos_token_id = eos_token_id

    return t5_config


def f1_p_r_compute(spo_list_pred: list, spo_list_true: list, repair: bool = False):
    '''
    spo_list: [[(s, p, o), ...], [(s, p, o), ...]], each row [(s, p, o), ...] holds the
    spo triples of one sentence.
    Compute the F1 score, precision and recall over spo triples.
    '''
    assert len(spo_list_pred) == len(spo_list_true)

    def repair_song_album(spo_list: list, song: list, album: list):
        '''
        Repair the song/album spo triples of one text. For a song x (subject) with one of
        the predicates 歌手/作词/作曲, x must appear in both `song` and `album`.
        '''
        if len(song) == 0 and len(album) == 0:
            return spo_list

        ps = ['歌手', '作词', '作曲']
        new_spo_list = []
        for spo in spo_list:
            s, p = spo[0], spo[1]
            if p in ps and s in album and s not in song:
                continue
            new_spo_list.append(spo)

        return new_spo_list

    def repair_song_album_list(spo_list: list):
        '''
        Apply repair_song_album to every sentence: subjects of '所属专辑' triples are
        collected as songs and their objects as albums.
        '''
        new_spo_list = []
        for spos in spo_list:
            song, album = [], []
            for spo in spos:
                s, p, o = spo
                if p == '所属专辑':
                    song.append(s)
                    album.append(o)
            new_spo_list.append(repair_song_album(spos, song, album))

        return new_spo_list

    if repair:
        spo_list_pred = repair_song_album_list(spo_list_pred)
        spo_list_true = repair_song_album_list(spo_list_true)

    TP = 1e-10      # true positives, A
    # TN = 1e-10    # true negatives
    TP_FP = 1e-10   # everything retrieved, A + B
    TP_FN = 1e-10   # everything actually wanted, A + C
    # FP = 1e-10    # false positives
    # FN = 1e-10    # false negatives

    # p = a / (a + b)
    # r = a / (a + c)
    # f1 = 2pr / (p + r)

    for i in range(len(spo_list_true)):
        pred_set = set(spo_list_pred[i])
        true_set = set(spo_list_true[i])

        pred_true_set = pred_set & true_set     # intersection of predictions and ground truth

        TP += len(pred_true_set)    # retrieved and wanted, A
        TP_FP += len(pred_set)      # everything retrieved, wanted or not, A + B
        TP_FN += len(true_set)      # everything wanted, retrieved or not, A + C

    p = TP / TP_FP
    r = TP / TP_FN
    f1 = (2 * p * r) / (p + r)

    return f1, p, r
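

# --- Added usage sketch (illustrative, not part of the original module) ------------------
# A tiny example of f1_p_r_compute on two sentences of made-up (subject, predicate, object)
# triples. With 2 correct predictions out of 3 predicted triples and 3 gold triples,
# precision, recall and F1 all come out to roughly 2/3.
def _demo_f1_p_r() -> tuple:
    spo_pred = [
        [('七里香', '歌手', '周杰伦'), ('七里香', '作曲', '周杰伦')],
        [('北京', '首都', '中国')],
    ]
    spo_true = [
        [('七里香', '歌手', '周杰伦')],
        [('北京', '首都', '中国'), ('北京', '人口', '2184万')],
    ]
    return f1_p_r_compute(spo_pred, spo_true, repair=False)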


def fixed_response(item: str) -> str:
    '''
    Repair a truncated answer: walk back from the end to the first end-of-sentence
    punctuation and cut there.
    '''
    if len(item) <= 1:
        return item
    if item[-1] in END_PUN:
        return item

    n = len(item)
    i = n - 1
    while i > 0 and item[i] not in END_PUN:
        i -= 1

    return item[0: i + 1]


def fixed_space(sentence: str) -> str:
    '''
    Delete single spaces; keep one space out of two consecutive spaces.
    '''
    n = len(sentence)
    new_sentence = []
    i = 0
    while i < n:
        word = sentence[i]
        if word != ' ':
            new_sentence.append(word)
        elif i + 1 < n and sentence[i + 1] == ' ':
            new_sentence.append(word)
            i += 1  # keep one of the two spaces, so skip the next one
        i += 1

    return ''.join(new_sentence)


def get_free_space_of_disk(folder: str = './') -> float:
    '''
    Get the free space of the disk that `folder` lives on, in GB.
    '''
    res_val = 0.0
    if platform.system() == 'Windows':
        free_bytes = ctypes.c_ulonglong(0)
        ctypes.windll.kernel32.GetDiskFreeSpaceExW(ctypes.c_wchar_p(folder), None, None, ctypes.pointer(free_bytes))
        res_val = free_bytes.value
    else:
        st = os.statvfs(folder)
        res_val = st.f_bavail * st.f_frsize

    return res_val / (1024 ** 3)


def my_average(arry_list: list[float]) -> float:
    '''
    Custom mean: return 0.0 for an empty array.
    '''
    if len(arry_list) == 0:
        return 0.0

    return np.average(arry_list)


def json_to_dataclass(json_file: str, class_name: str = 'Config') -> type:
    '''
    Convert a json config file into a dataclass.
    >>> example:
    >>> data_class = json_to_dataclass('my_config.json', 'Config')
    >>> my_config = data_class()
    >>> assert my_config.name == 'Alice'
    >>> my_config.name = 'Bob'
    '''
    json_dict = {}
    with open(json_file, 'r', encoding='utf-8') as f:
        json_dict = ujson.load(f)

    # Turn the dict into an iterable of (attribute name, attribute type, default value)
    fields_list = []
    for k, v in json_dict.items():
        fields_list.append((k, type(v), field(default=v)))

    data_class = make_dataclass(cls_name=class_name, fields=fields_list)
    return data_class


def get_path_of_suffix_files(root: str, suffix: str, with_create_time: bool = False) -> list:
    '''
    Get the absolute paths of all files with the given suffix under the given directory.
    '''
    suffix_files = []
    for dirpath, _, files in os.walk(root):
        for file in files:
            if file.endswith(suffix):
                full_path = '{}/{}'.format(dirpath, file)
                if with_create_time:
                    suffix_files.append((full_path, os.path.getctime(full_path)))
                else:
                    suffix_files.append(full_path)

    return suffix_files


def get_bleu4_score(reference: Union[str, list[str]], outputs: Union[str, list[str]], n_gram: int = 4) -> float:
    '''
    Compute the BLEU-4 score.
    '''
    weights = np.ones(n_gram) * (1.0 / n_gram)
    outputs_len, reference_len = len(outputs), len(reference)

    if not isinstance(reference, list):
        reference = list(reference)
    if not isinstance(outputs, list):
        outputs = list(outputs)

    outputs_counter = extract_Ngram(outputs, n_gram=n_gram)
    reference_counter = extract_Ngram(reference, n_gram=n_gram)

    ngram_counter_clip = outputs_counter & reference_counter

    clip_counter = np.zeros(n_gram)
    output_ngram_counter = np.zeros(n_gram)

    for (key, ngram), cnt in ngram_counter_clip.items():
        clip_counter[ngram - 1] += cnt

    for (key, ngram), cnt in outputs_counter.items():
        output_ngram_counter[ngram - 1] += cnt

    # print(clip_counter, output_ngram_counter)
    if np.min(clip_counter) == 0.0:
        # no overlap at some n-gram order
        return 0.0

    precision_scores = clip_counter / output_ngram_counter

    # bleu
    log_precision_scores = weights * np.log(precision_scores)

    # Weighted geometric mean of the n-gram precisions.
    # Note: the brevity penalty is applied unconditionally here; standard BLEU clips it at 1.0.
    geometric_mean = np.exp(np.sum(log_precision_scores))
    brevity_penalty = np.exp(1 - (reference_len / outputs_len))

    # brevity_penalty = 1.0,  bleu = sentence_bleu([reference], outputs)
    # brevity_penalty = 1.0

    bleu = brevity_penalty * geometric_mean

    return bleu


def extract_Ngram(words_list: list[str], n_gram: int) -> Counter:
    '''
    Extract the n-grams of a sentence.
    return:
        ngram_counter: key = ('w1 w2 ... wn', n), value: count of key
    '''
    n = len(words_list)
    ngram_counter = Counter()

    for i in range(1, n_gram + 1):
        for j in range(n - i + 1):
            key = ' '.join(words_list[j: j + i])
            ngram_counter[(key, i)] += 1

    return ngram_counter


def save_model_config(config_dict: dict, file: str) -> None:
    '''
    Write the model config to a json file; `file` is the save directory plus file name.
    '''
    # file = file.replace('\\', '/')
    # file = '{}/model_config.json'.format('/'.join(file.split('/')[0: -1]))

    with open(file, 'w', encoding='utf-8') as f:
        ujson.dump(config_dict, f, indent=4, ensure_ascii=False)


if __name__ == '__main__':
    ref = '抱歉,我不知道ABB代表什么意思'
    out = '我不明白ABB是什么意思'
    b1 = sentence_bleu([list(out)], list(ref), weights=(0.25, 0.25, 0.25, 0.25))
    print(b1)
    b2 = get_bleu4_score(out, ref)
    print(b2)

    candidate_corpus = ['i', 'have', 'a', 'pen', 'on', 'my', 'desk', 'a', 'b', 'c', 'd', 'f', 'f']
    reference_corpus = ['there', 'is', 'a', 'pen', 'on', 'my', 'desk', 'a', 'b', 'd', 'd', 'fd']

    print('----')
    print(sentence_bleu([reference_corpus], candidate_corpus, weights=(0.25, 0.25, 0.25, 0.25)))
    print(get_bleu4_score(reference_corpus, candidate_corpus))
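
    # --- Added demo (illustrative, not part of the original script) ---
    # Round-trip a config dict through save_model_config and json_to_dataclass using a
    # temporary directory, so nothing on disk is touched permanently. The field names
    # below ('d_model', 'num_layers') are made up for the demo.
    import tempfile
    with tempfile.TemporaryDirectory() as tmp_dir:
        cfg_path = '{}/model_config.json'.format(tmp_dir)
        save_model_config({'d_model': 512, 'num_layers': 6}, cfg_path)
        DemoConfig = json_to_dataclass(cfg_path, 'DemoConfig')
        cfg = DemoConfig()
        print(cfg.d_model, cfg.num_layers)  # expected: 512 6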