import re
import os
import datasets
from sklearn.metrics import accuracy_score, mean_squared_error
from collections import defaultdict
from rouge_score import rouge_scorer
# LoRA target-module definitions per base model (includes Llama3 support)
lora_module_dict = {
    'chatglm2': ['query_key_value'],
    'llama2': [
        'q_proj', 'k_proj', 'v_proj',
        'o_proj', 'gate_proj', 'up_proj', 'down_proj',
    ],
    'llama3': [  # target modules for Llama3-8B
        'q_proj', 'k_proj', 'v_proj',
        'o_proj', 'gate_proj', 'up_proj', 'down_proj',
    ],
}
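
# Example (illustrative sketch, not used by this module): wiring these target modules
# into a peft LoraConfig. Assumes the `peft` package; r/alpha/dropout values are placeholders.
#
#     from peft import LoraConfig, get_peft_model
#
#     lora_config = LoraConfig(
#         r=8, lora_alpha=16, lora_dropout=0.05, bias="none",
#         task_type="CAUSAL_LM",
#         target_modules=lora_module_dict['llama3'],
#     )
#     model = get_peft_model(base_model, lora_config)  # base_model loaded elsewhere
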
def tokenize(args, tokenizer, feature):
    """Tokenize a prompt/answer pair into input_ids and labels for causal-LM fine-tuning."""
    prompt_ids = tokenizer.encode(
        feature['prompt'].strip(), padding=False,
        max_length=args.max_length, truncation=True
    )
    target_ids = tokenizer.encode(
        feature['answer'].strip(), padding=False,
        max_length=args.max_length, truncation=True, add_special_tokens=False
    )
    input_ids = prompt_ids + target_ids
    exceed_max_length = len(input_ids) >= args.max_length
    # Append the EOS token unless the sequence already hits the length limit
    if input_ids[-1] != tokenizer.eos_token_id and not exceed_max_length:
        input_ids.append(tokenizer.eos_token_id)
    # Mask the prompt portion of the labels with the pad token id
    label_ids = [tokenizer.pad_token_id] * len(prompt_ids) + input_ids[len(prompt_ids):]
    return {
        "input_ids": input_ids,
        "labels": label_ids,
        "exceed_max_length": exceed_max_length
    }
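
# Example (illustrative sketch): mapping `tokenize` over a prompt/answer dataset.
# Assumes `raw_dataset` is a DatasetDict with 'prompt'/'answer' columns (e.g. from
# load_dataset below); `args` only needs a `max_length` attribute here.
#
#     from functools import partial
#     from types import SimpleNamespace
#     from transformers import AutoTokenizer
#
#     tok = AutoTokenizer.from_pretrained('THUDM/chatglm2-6b', trust_remote_code=True)
#     args = SimpleNamespace(max_length=512)
#     tokenized = raw_dataset['train'].map(
#         partial(tokenize, args, tok),
#         remove_columns=raw_dataset['train'].column_names,
#     )
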
def parse_model_name(name, from_remote=False):
    """Map a short model alias to its Hugging Face repo id (or local path for chatglm2)."""
    if name == 'chatglm2':
        return 'THUDM/chatglm2-6b' if from_remote else 'base_models/chatglm2-6b'
    elif name == 'llama2':
        return 'meta-llama/Llama-2-7b-chat-hf'
    elif name == 'llama3':
        return 'meta-llama/Meta-Llama-3-8B'  # Llama3-8B base model
    else:
        raise ValueError(f"Undefined base model {name}")
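
# Example (illustrative sketch): resolving an alias and loading the model with the
# standard transformers APIs; gated meta-llama repos additionally require an access token.
#
#     from transformers import AutoTokenizer, AutoModelForCausalLM
#
#     model_name = parse_model_name('llama3', from_remote=True)
#     tokenizer = AutoTokenizer.from_pretrained(model_name)
#     model = AutoModelForCausalLM.from_pretrained(model_name)
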
def load_dataset(names, from_hf_hub=False):
    """
    Load one or more datasets, either locally or from the Hugging Face Hub.
    names: dataset name(s); multiple datasets are separated by commas
    from_hf_hub: whether to load the dataset(s) from the Hugging Face Hub
    """
    dataset_names = [d for d in names.split(',')]
    dataset_list = []
    for name in dataset_names:
        rep = 1
        if from_hf_hub:
            # Load the dataset from the Hugging Face Hub
            tmp_dataset = datasets.load_dataset(name)
        else:
            # Load the dataset from disk (saved in Arrow format)
            if os.path.exists(name):
                tmp_dataset = datasets.load_from_disk(name)  # local load
            else:
                raise FileNotFoundError(f"Dataset {name} not found in the specified path.")
        # If the dataset has no 'test' split, create an 80/20 train/test split
        if 'test' not in tmp_dataset:
            tmp_dataset = tmp_dataset.train_test_split(0.2, shuffle=True, seed=42)
        dataset_list.extend([tmp_dataset] * rep)
    return dataset_list
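
# Example (illustrative sketch): loading two local datasets and concatenating their
# train splits; the paths are placeholders.
#
#     train_sets = load_dataset('data/fingpt-forecaster,data/extra-news')
#     train_data = datasets.concatenate_datasets([d['train'] for d in train_sets])
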
def parse_answer(answer):
    """Parse a structured model answer into its sections and a numeric prediction."""
    match_res = re.match(
        r"^\s*\[Positive Developments\]:\s*(.*)\s*\[Potential Concerns\]:\s*(.*)\s*\[Prediction (&|and) Analysis\]:\s*(.*)\s*$",
        answer, flags=re.DOTALL)
    if not match_res:
        return None
    pros, cons, pna = match_res.group(1), match_res.group(2), match_res.group(4)
    match_res = re.match(r'^Prediction:\s*(.*)\s*Analysis:\s*(.*)\s*$', pna, flags=re.DOTALL)
    if not match_res:
        return None
    pred, anal = match_res.group(1), match_res.group(2)
    # Direction of the prediction: +1 up, -1 down, 0 unclear
    if re.search(r'up|increase', pred.lower()):
        pred_bin = 1
    elif re.search(r'down|decrease|decline', pred.lower()):
        pred_bin = -1
    else:
        pred_bin = 0
    # Magnitude: prefer an "x-y%" range, otherwise fall back to a single "(more than) x%" figure
    match_res = re.search(r'(\d+)-(\d+)%', pred)
    if not match_res:
        match_res = re.search(r'(?:more than )?(\d+)%', pred)
    pred_margin = pred_bin * (int(match_res.group(1)) + 0.5) if match_res else 0.
    return {
        "positive developments": pros,
        "potential concerns": cons,
        "prediction": pred_margin,
        "prediction_binary": pred_bin,
        "analysis": anal
    }
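
# Example (illustrative sketch): the answer format parse_answer expects and what it returns.
#
#     example = (
#         "[Positive Developments]: Strong earnings.\n"
#         "[Potential Concerns]: Rising costs.\n"
#         "[Prediction & Analysis]: Prediction: Up by 2-3%\n"
#         "Analysis: Momentum should continue."
#     )
#     parsed = parse_answer(example)
#     # parsed['prediction_binary'] == 1, parsed['prediction'] == 2.5
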
def calc_rouge_score(references, answers):
    """Average ROUGE-1/2/L F1 scores over paired reference and answer texts."""
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    scores_per_pair = [scorer.score(ref, ans) for ref, ans in zip(references, answers)]
    rouge1 = sum(score['rouge1'].fmeasure for score in scores_per_pair) / len(scores_per_pair)
    rouge2 = sum(score['rouge2'].fmeasure for score in scores_per_pair) / len(scores_per_pair)
    rougeL = sum(score['rougeL'].fmeasure for score in scores_per_pair) / len(scores_per_pair)
    return {'rouge1': rouge1, 'rouge2': rouge2, 'rougeL': rougeL}
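
# Example (illustrative sketch):
#
#     calc_rouge_score(
#         ["the company reported strong revenue growth"],
#         ["revenue growth was strong"],
#     )
#     # -> {'rouge1': ..., 'rouge2': ..., 'rougeL': ...}  (averaged F1 scores)
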
def calc_metrics(answers, gts):
    """Compute binary accuracy, MSE and ROUGE scores over parseable answer/ground-truth pairs."""
    answers_dict = defaultdict(list)
    gts_dict = defaultdict(list)
    # Keep only pairs where both the answer and the ground truth parse successfully
    for answer, gt in zip(answers, gts):
        answer_dict = parse_answer(answer)
        gt_dict = parse_answer(gt)
        if answer_dict and gt_dict:
            for k in answer_dict.keys():
                answers_dict[k].append(answer_dict[k])
                gts_dict[k].append(gt_dict[k])
    if not answers_dict['prediction']:
        return {}
    bin_acc = accuracy_score(gts_dict['prediction_binary'], answers_dict['prediction_binary'])
    mse = mean_squared_error(gts_dict['prediction'], answers_dict['prediction'])
    pros_rouge_scores = calc_rouge_score(gts_dict['positive developments'], answers_dict['positive developments'])
    cons_rouge_scores = calc_rouge_score(gts_dict['potential concerns'], answers_dict['potential concerns'])
    anal_rouge_scores = calc_rouge_score(gts_dict['analysis'], answers_dict['analysis'])
    print(f"\nBinary Accuracy: {bin_acc:.2f} | Mean Square Error: {mse:.2f}")
    print(f"\nRouge Score of Positive Developments: {pros_rouge_scores}")
    print(f"\nRouge Score of Potential Concerns: {cons_rouge_scores}")
    print(f"\nRouge Score of Summary Analysis: {anal_rouge_scores}")
    return {
        "valid_count": len(answers_dict['prediction']),
        "bin_acc": bin_acc,
        "mse": mse,
        "pros_rouge_scores": pros_rouge_scores,
        "cons_rouge_scores": cons_rouge_scores,
        "anal_rouge_scores": anal_rouge_scores
    }
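
# Example (illustrative sketch): evaluating generated answers against references, where
# both sides follow the format documented above for parse_answer.
#
#     generated = [model_output_1, model_output_2]   # placeholder generated strings
#     references = [reference_1, reference_2]        # placeholder ground-truth strings
#     metrics = calc_metrics(generated, references)
#     # metrics is {} if none of the pairs could be parsed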