from collections import Counter, defaultdict
from dataclasses import make_dataclass, field
from typing import Union
import ctypes
import os
import platform
import re

import numpy as np
import torch
import ujson
from datasketch import MinHash, MinHashLSH
# from nltk import ngrams
from nltk.translate.bleu_score import sentence_bleu
from transformers import T5Config, TrainingArguments, TrainerCallback
from transformers.trainer_callback import TrainerControl, TrainerState

from config import T5ModelConfig

# sentence-ending punctuation marks
END_PUN = set(".。!!))》}】??\"”")

class MyTrainerCallback(TrainerCallback):
    log_cnt = 0

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        '''
        Clear the CUDA cache every couple of logging steps; useful on low-VRAM devices to help prevent OOM.
        '''
        self.log_cnt += 1
        if self.log_cnt % 2 == 0:
            torch.cuda.empty_cache()

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        '''
        Save the model once at the end of every epoch.
        In TrainingArguments, the 'epoch' and 'steps' values of save_strategy are mutually exclusive. To also save
        a checkpoint every save_steps steps while keeping disk usage bounded, only the latest N checkpoints are kept.
        '''
        # setting should_save=True and returning the control object is all that is needed
        control.should_save = True
        return control
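
# Example usage (illustrative sketch; `model`, `my_train_args` and `train_dataset` are
# placeholders that this module does not define):
#
#   from transformers import Trainer
#   trainer = Trainer(
#       model=model,
#       args=my_train_args,
#       train_dataset=train_dataset,
#       callbacks=[MyTrainerCallback()],
#   )
#   trainer.train()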

# keep Chinese characters, English letters, digits and underscores; drop punctuation
NON_CHAR = re.compile(r"[^\u4E00-\u9FA5A-Za-z0-9_]")

def _get_doc_mini_hash(doc: list[str] | str, num_perm: int) -> MinHash:
    '''
    Compute the MinHash signature of a document (iterates over its tokens or characters).
    '''
    mini_hash = MinHash(num_perm=num_perm)
    for s in doc:
        mini_hash.update(s.encode('utf-8'))
    return mini_hash

class DropDatasetDuplicate:

    def __init__(self, threshold: float=0.85, num_perm: int=256) -> None:
        '''
        Collect the indexes of all duplicated documents in a dataset (similarity above `threshold`).
        The input is a list[str]; each str element is one document (doc).
        E.g. for the input [a, b, c, d, c, d, e] the result is {4, 5} (the indexes of the second c and d).
        '''
        self.similar_index_cluster = defaultdict(set)
        self.data_lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
        self.num_perm = num_perm

    def add_doc(self, index: object, doc: str) -> None:
        '''
        Add a document.
        index: index of the document
        doc: the document text itself
        '''
        # keep only Chinese characters, English letters, digits and underscores; drop punctuation
        doc = ''.join(NON_CHAR.split(doc))
        # doc = [''.join(t) for t in list(ngrams(doc, 3))]

        doc_hash = _get_doc_mini_hash(doc, self.num_perm)
        close_duplicates = self.data_lsh.query(doc_hash)

        self.data_lsh.insert(index, doc_hash)

        # All similar docs share the earliest index as their key in similar_index_cluster.
        # E.g. if indexes 2, 7, 8, 9, 10, 12 are similar, similar_index_cluster holds {2: {7, 8, 9, 10, 12}}
        if len(close_duplicates) > 0:
            min_idx = min(close_duplicates)
            self.similar_index_cluster[min_idx].add(index)

    def get_duplicate_indexs(self):
        '''
        Return the indexes of all duplicated documents.
        '''
        need_to_remove_idx = set()
        for similar_indexes in self.similar_index_cluster.values():
            need_to_remove_idx |= similar_indexes

        return need_to_remove_idx
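
# Example usage (illustrative sketch; `raw_docs` is a placeholder list of documents):
#
#   raw_docs = ['今天天气不错', '今天天气不错!', 'hello world']
#   dropper = DropDatasetDuplicate(threshold=0.85, num_perm=256)
#   for i, doc in enumerate(raw_docs):
#       dropper.add_doc(index=i, doc=doc)
#   duplicate_ids = dropper.get_duplicate_indexs()
#   kept_docs = [doc for i, doc in enumerate(raw_docs) if i not in duplicate_ids]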

def get_T5_config(config: T5ModelConfig, vocab_size: int, decoder_start_token_id: int=0, eos_token_id: int=1) -> T5Config:
    '''
    Convert the user configuration into a T5Config.
    '''
    t5_config = T5Config()
    # t5_config.model_type = 'TextToTextModel'

    # initialize the fields from the user config
    t5_config.d_ff = config.d_ff
    t5_config.d_kv = config.d_kv
    t5_config.d_model = config.d_model
    t5_config.num_decoder_layers = config.num_decoder_layers
    t5_config.num_heads = config.num_heads
    t5_config.num_layers = config.num_layers
    t5_config.vocab_size = vocab_size
    t5_config.decoder_start_token_id = decoder_start_token_id
    t5_config.eos_token_id = eos_token_id

    return t5_config
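
# Example usage (illustrative sketch; assumes config.T5ModelConfig can be instantiated with its
# default values, and the vocab size below is a placeholder):
#
#   model_config = T5ModelConfig()
#   t5_config = get_T5_config(model_config, vocab_size=29298, decoder_start_token_id=0, eos_token_id=1)
#   # t5_config can then be passed to the model constructor, e.g. a T5ForConditionalGeneration-style model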

def f1_p_r_compute(spo_list_pred: list, spo_list_true: list, repair: bool=False):
    '''
    spo_list: [ [(s,p,o)...], [(s,p,o)]], each row [(s,p,o)...] holds the spo triples of one sentence.
    Compute the f1 score, precision and recall over spo triples.
    '''
    assert len(spo_list_pred) == len(spo_list_true)

    def repair_song_album(spo_list: list, song: list, album: list):
        '''
        Repair the 'song'/'album' spo triples of one text. For the relations 歌手 (singer), 作词 (lyricist)
        and 作曲 (composer) of a song x (the subject), x must appear in both `song` and `album`.
        '''
        if len(song) == 0 and len(album) == 0:
            return spo_list

        ps = ['歌手', '作词', '作曲']
        new_spo_list = []
        for spo in spo_list:
            s, p = spo[0], spo[1]
            if p in ps and s in album and s not in song:
                continue
            new_spo_list.append(spo)

        return new_spo_list

    def repair_song_album_list(spo_list: list):
        '''
        Apply repair_song_album to the spo triples of every sentence.
        '''
        new_spo_list = []
        for spos in spo_list:
            song, album = [], []
            for spo in spos:
                s, p, o = spo
                if p == '所属专辑':
                    song.append(s)
                    album.append(o)
            new_spo_list.append(repair_song_album(spos, song, album))

        return new_spo_list

    if repair:
        spo_list_pred = repair_song_album_list(spo_list_pred)
        spo_list_true = repair_song_album_list(spo_list_true)

    TP = 1e-10      # true positives, A
    # TN = 1e-10    # true negatives
    TP_FP = 1e-10   # everything retrieved, A + B
    TP_FN = 1e-10   # everything relevant, A + C
    # FP = 1e-10    # false positives
    # FN = 1e-10    # false negatives

    # p = a / (a + b)
    # r = a / (a + c)
    # f1 = 2pr / (p + r)

    for i in range(len(spo_list_true)):
        pred_set = set(spo_list_pred[i])
        true_set = set(spo_list_true[i])

        pred_true_set = pred_set & true_set     # intersection of prediction and ground truth
        TP += len(pred_true_set)    # retrieved and relevant, A
        TP_FP += len(pred_set)      # everything retrieved, relevant or not, A + B
        TP_FN += len(true_set)      # everything relevant, retrieved or not, A + C

    p = TP / TP_FP
    r = TP / TP_FN
    f1 = (2 * p * r) / (p + r)

    return f1, p, r
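
# Example usage (illustrative sketch with made-up spo triples):
#
#   pred = [[('七里香', '歌手', '周杰伦')], [('晴天', '作曲', '周杰伦')]]
#   true = [[('七里香', '歌手', '周杰伦')], [('夜曲', '作曲', '周杰伦')]]
#   f1, precision, recall = f1_p_r_compute(pred, true, repair=False)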

def fixed_response(item: str) -> str:
    '''
    Repair a truncated answer: search backwards from the end for the first sentence-ending
    punctuation mark and cut the text off there.
    '''
    if len(item) <= 1: return item
    if item[-1] in END_PUN: return item

    n = len(item)
    i = n - 1
    while i > 0 and item[i] not in END_PUN:
        i -= 1

    return item[0: i + 1]
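
# Example (illustrative):
#
#   fixed_response('今天天气很好。我们一起去')   # -> '今天天气很好。'
#   fixed_response('今天天气很好。')             # -> unchanged, already ends with punctuation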

def fixed_space(sentence: str) -> str:
    '''
    Remove single spaces; collapse two consecutive spaces into one.
    '''
    n = len(sentence)
    new_sentence = []
    i = 0
    while i < n:
        word = sentence[i]
        if word != ' ':
            new_sentence.append(word)
        elif i + 1 < n and sentence[i + 1] == ' ':
            new_sentence.append(word)
            i += 1  # two spaces keep one, so skip the next space
        i += 1
    return ''.join(new_sentence)
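
# Example (illustrative):
#
#   fixed_space('hello world')    # single space dropped   -> 'helloworld'
#   fixed_space('hello  world')   # double space collapsed -> 'hello world'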

def get_free_space_of_disk(folder: str='./') -> float:
    '''
    Get the free space (in GB) of the disk that `folder` resides on.
    '''
res_val = 0.0
if platform.system() == 'Windows':
free_bytes = ctypes.c_ulonglong(0)
ctypes.windll.kernel32.GetDiskFreeSpaceExW(ctypes.c_wchar_p(folder), None, None, ctypes.pointer(free_bytes))
res_val = free_bytes.value
else:
st = os.statvfs(folder)
res_val = st.f_bavail * st.f_frsize
return res_val / (1024 ** 3)
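
# Example (illustrative):
#
#   if get_free_space_of_disk('./') < 10.0:
#       print('warning: less than 10 GB of free disk space left')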

def my_average(arry_list: list[float]) -> float:
    '''
    Custom mean: returns 0.0 for an empty list.
    '''
if len(arry_list) == 0: return 0.0
return np.average(arry_list)

def json_to_dataclass(json_file: str, class_name: str='Config') -> type:
    '''
    Convert a json config file into a dataclass type.
    >>> example:
    >>> data_class = json_to_dataclass('my_config.json', 'Config')
    >>> my_config = data_class()
    >>> assert my_config.name == 'Alice'
    >>> my_config.name = 'Bob'
    '''
    json_dict = {}
    with open(json_file, 'r', encoding='utf-8') as f:
        json_dict = ujson.load(f)

    # convert the dict into an iterable of (field name, field type, default value)
    fields_list = []
    for k, v in json_dict.items():
        fields_list.append( (k, type(v), field(default=v)) )

    data_class = make_dataclass(cls_name=class_name, fields=fields_list)

    return data_class

def get_path_of_suffix_files(root: str, suffix: str, with_create_time: bool=False) -> list:
    '''
    Get the full paths of all files with the given suffix under the directory `root`.
    '''
    suffix_files = []
    for dirpath, _, files in os.walk(root):
        for file in files:
            if file.endswith(suffix):
                full_path = '{}/{}'.format(dirpath, file)
                if with_create_time:
                    suffix_files.append( (full_path, os.path.getctime(full_path)) )
                else:
                    suffix_files.append(full_path)

    return suffix_files
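
# Example usage (illustrative sketch; './model_save' and the '.bin' suffix are placeholders):
#
#   ckpt_files = get_path_of_suffix_files('./model_save', suffix='.bin', with_create_time=True)
#   ckpt_files.sort(key=lambda x: x[1], reverse=True)   # newest first, e.g. to keep only the latest N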

def get_bleu4_score(reference: Union[str, list[str]], outputs: Union[str, list[str]], n_gram: int=4) -> float:
    '''
    Compute the BLEU-4 score.
    '''
    weights = np.ones(n_gram) * (1.0 / n_gram)
    outputs_len, reference_len = len(outputs), len(reference)

    if not isinstance(reference, list):
        reference = list(reference)
    if not isinstance(outputs, list):
        outputs = list(outputs)

    outputs_counter = extract_Ngram(outputs, n_gram=n_gram)
    reference_counter = extract_Ngram(reference, n_gram=n_gram)

    ngram_counter_clip = outputs_counter & reference_counter

    clip_counter = np.zeros(n_gram)
    output_ngram_counter = np.zeros(n_gram)

    for (key, ngram), cnt in ngram_counter_clip.items():
        clip_counter[ngram - 1] += cnt

    for (key, ngram), cnt in outputs_counter.items():
        output_ngram_counter[ngram - 1] += cnt

    # print(clip_counter, output_ngram_counter)
    if np.min(clip_counter) == 0.0:
        return 0.0

    precision_scores = clip_counter / output_ngram_counter

    # BLEU: weighted geometric mean of the n-gram precisions, scaled by the brevity penalty
    log_precision_scores = weights * np.log(precision_scores)
    geometric_mean = np.exp(np.sum(log_precision_scores))

    # note: unlike standard BLEU, the brevity penalty here is not clipped to 1.0 when outputs are longer than the reference
    brevity_penalty = np.exp(1 - (reference_len / outputs_len))

    # brevity_penalty = 1.0, bleu = sentence_bleu([reference], outputs)
    # brevity_penalty = 1.0
    bleu = brevity_penalty * geometric_mean

    return bleu

def extract_Ngram(words_list: list[str], n_gram: int) -> Counter:
    '''
    Extract all 1- to n_gram-grams of a sentence.
    return:
        ngram_counter: key = ('w1 w2 ... wn', n), value: count of that n-gram
    '''
    n = len(words_list)
    ngram_counter = Counter()
    for i in range(1, n_gram + 1):
        for j in range(n - i + 1):
            key = ' '.join(words_list[j: j + i])
            ngram_counter[(key, i)] += 1
    return ngram_counter
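
# Example (illustrative):
#
#   extract_Ngram(['a', 'b', 'a'], n_gram=2)
#   # -> Counter({('a', 1): 2, ('b', 1): 1, ('a b', 2): 1, ('b a', 2): 1})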

def save_model_config(config_dict: dict, file: str) -> None:
    '''
    Write the model configuration to a json file; `file` is the save directory plus the file name.
    '''
    # file = file.replace('\\', '/')
    # file = '{}/model_config.json'.format('/'.join(file.split('/')[0: -1]))

    with open(file, 'w', encoding='utf-8') as f:
        ujson.dump(config_dict, f, indent=4, ensure_ascii=False)
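
# Example usage (illustrative sketch; the path and field values are placeholders):
#
#   save_model_config({'d_model': 768, 'num_layers': 10}, './model_save/model_config.json')
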
if __name__ == '__main__':
ref = '抱歉,我不知道ABB代表什么意思'
out = '我不明白ABB是什么意思'
b1 = sentence_bleu([list(out)], list(ref), weights=(0.25, 0.25, 0.25, 0.25))
print(b1)
b2 = get_bleu4_score(out, ref)
print(b2)
candidate_corpus = ['i', 'have', 'a', 'pen', 'on', 'my', 'desk', 'a', 'b', 'c', 'd','f','f']
reference_corpus = ['there', 'is', 'a', 'pen', 'on', 'my', 'desk', 'a', 'b', 'd', 'd', 'fd']
print('----')
print(sentence_bleu([reference_corpus], candidate_corpus, weights=(0.25, 0.25, 0.25, 0.25)))
print(get_bleu4_score(reference_corpus, candidate_corpus))