---
license: cc-by-nc-sa-4.0
language:
- zh
- vi
---

GPTQ Int4 quantized version of [WN-zh-vi-sim-v0.3](https://huggingface.co/CjangCjengh/WN-zh-vi-sim-v0.3)

The model is used to align Chinese text with Vietnamese text

```python
import os
import json
import torch
import torch.nn.functional as F
from vllm import LLM
from vllm.config import PoolerConfig
from huggingface_hub import hf_hub_download

model_path = 'CjangCjengh/WN-zh-vi-sim-v0.3-GPTQ-Int4'
zh_text_path = 'zh.txt'
vi_text_path = 'vi.txt'
output_path = 'output.json'
save_interval = 50
device = 'cuda'
cpu_offload_gb = 0
lm_head_filename = 'yes_no_lm_head.pt'
lm_head_path = hf_hub_download(repo_id=model_path, filename=lm_head_filename, local_dir='.')

zh_idx = 0
vi_idx = 0
max_extra_lines = 5
align_list = []

# Resume from an existing output file: recover the line counters from the
# number of lines already consumed by previous alignments
if os.path.exists(output_path):
    with open(output_path, 'r', encoding='utf-8') as f:
        align_list = json.load(f)
    zh_idx = sum([i['zh'].count('\n')+1 for i in align_list if i['zh']])
    vi_idx = sum([i['vi'].count('\n')+1 for i in align_list if i['vi']])

# Projection from the last hidden state onto the "Yes"/"No" logits
lm_head = torch.load(lm_head_path)
lm_head = lm_head.to(device)  # Tensor.to() is not in-place, so reassign

llm = LLM(
    model=model_path,
    cpu_offload_gb=cpu_offload_gb,
    enforce_eager=True,
    task='embed',
    override_pooler_config=PoolerConfig(pooling_type='ALL')
)

zh_lines = open(zh_text_path, 'r', encoding='utf-8').readlines()
vi_lines = open(vi_text_path, 'r', encoding='utf-8').readlines()
zh_lines = [l.strip() for l in zh_lines]
vi_lines = [l.strip() for l in vi_lines]

def get_sim_score(src_text, tgt_text):
    # The Chinese prompt asks: "Do the following Chinese and Vietnamese passages
    # correspond exactly, with nothing missing? (Answer Yes or No)"
    text = f'<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n下面的中文段落和越南语段落的内容是否完全对应,不存在缺漏?(回答Yes或No)\n\n中文:\n{src_text}\n\n越南语:{tgt_text}<|im_end|>\n<|im_start|>assistant\n'
    outputs = llm.encode(text)
    with torch.inference_mode():
        hidden_states = outputs[0].outputs.data[-1].to(dtype=lm_head.dtype).to(device)
        logits = torch.matmul(lm_head, hidden_states)
        result = F.softmax(logits, dim=0).tolist()
    # Probability of "Yes"
    return result[0]

def generate_pairs():
    # Enumerate window sizes (1,1), (1,2), (2,1), (2,2), (1,3), ... in growing square shells
    visited = set()
    size = 1
    while True:
        for x in range(size + 1):
            for y in range(size + 1):
                if (x, y) not in visited:
                    visited.add((x, y))
                    yield (x+1, y+1)
        size += 1

while zh_idx < len(zh_lines) and vi_idx < len(vi_lines):
    for zh_i, vi_i in generate_pairs():
        # Try a window of zh_i Chinese lines against vi_i Vietnamese lines
        zh_text = ''.join(zh_lines[zh_idx:zh_idx+zh_i])
        vi_text = ' '.join(vi_lines[vi_idx:vi_idx+vi_i])
        score = get_sim_score(zh_text, vi_text)
        if score > 0.5:
            break
        if zh_i == 1 and vi_i == 1:
            continue
        # If the last line of each window aligns with the other, the boundary has
        # been found: accept the windows without those last lines
        score = get_sim_score(zh_lines[zh_idx+zh_i-1], vi_lines[vi_idx+vi_i-1])
        if score > 0.5:
            zh_i -= 1
            vi_i -= 1
            break
        if zh_i > max_extra_lines or vi_i > max_extra_lines:
            # No match within the window limit: look for a later pair of positions
            # where the two texts line up again, and align everything before it
            end_flag = False
            for zh_i, vi_i in generate_pairs():
                if zh_i == 1 and vi_i == 1:
                    continue
                for zh_i_offset, vi_i_offset in generate_pairs():
                    zh_start = zh_idx+zh_i-1
                    vi_start = vi_idx+vi_i-1
                    if zh_i+zh_i_offset > max_extra_lines and vi_i+vi_i_offset > max_extra_lines:
                        break
                    if zh_i+zh_i_offset > max_extra_lines or vi_i+vi_i_offset > max_extra_lines:
                        continue
                    zh_text = ''.join(zh_lines[zh_start:zh_start+zh_i_offset])
                    vi_text = ' '.join(vi_lines[vi_start:vi_start+vi_i_offset])
                    score = get_sim_score(zh_text, vi_text)
                    if score > 0.5:
                        zh_i -= 1
                        vi_i -= 1
                        end_flag = True
                        break
                if end_flag:
                    break
                if zh_i > max_extra_lines or vi_i > max_extra_lines:
                    # Give up: save progress and report where alignment failed
                    with open(output_path, 'w', encoding='utf-8') as f:
                        json.dump(align_list, f, ensure_ascii=False, indent=0)
                    raise Exception(f'Error! zh line No.{zh_idx+1} vi line No.{vi_idx+1}')
            # The chunk before the re-alignment point is accepted as-is
            break

    zh_text = '\n'.join(zh_lines[zh_idx:zh_idx+zh_i])
    vi_text = '\n'.join(vi_lines[vi_idx:vi_idx+vi_i])
    align_list.append({'zh': zh_text, 'vi': vi_text})
    new_align = [list(range(zh_idx, zh_idx+zh_i)), list(range(vi_idx, vi_idx+vi_i))]
    print(new_align)
    if len(align_list) % save_interval == 0:
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(align_list, f, ensure_ascii=False, indent=0)
    zh_idx += zh_i
    vi_idx += vi_i

with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(align_list, f, ensure_ascii=False, indent=0)
```
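A minimal sketch of how the resulting alignments can be read back, assuming the default `output.json` written by the script above (the inputs `zh.txt` and `vi.txt` are expected to contain one pre-segmented line per text unit). Each entry pairs a block of Chinese lines with a block of Vietnamese lines, joined with `'\n'`:

```python
import json

# Read back the alignments produced by the script above.
# Each entry is {'zh': ..., 'vi': ...}; lines inside a block are joined with '\n'.
with open('output.json', 'r', encoding='utf-8') as f:
    align_list = json.load(f)

for pair in align_list[:5]:
    zh_block = pair['zh'].split('\n')
    vi_block = pair['vi'].split('\n')
    print(f'{len(zh_block)} zh line(s) <-> {len(vi_block)} vi line(s)')
    print(pair['zh'])
    print(pair['vi'])
    print('---')
```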