WN-zh-vi-sim-v0.3的GPTQ Int4量化版本
Bản quant GPTQ Int4 của WN-zh-vi-sim-v0.3
模型用于对齐中文文本和越南语文本
Mô hình dùng để align văn bản tiếng Trung và tiếng Việt
import os
import json
import torch
import torch.nn.functional as F
from vllm import LLM
from vllm.config import PoolerConfig
from huggingface_hub import hf_hub_download
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
model_path = 'CjangCjengh/WN-zh-vi-sim-v0.3-GPTQ-Int4'  # Hugging Face repo id
zh_text_path = 'zh.txt'        # Chinese input, one segment per line
vi_text_path = 'vi.txt'        # Vietnamese input, one segment per line
output_path = 'output.json'    # alignment result; also used to resume a run
save_interval = 50             # write output_path every N aligned entries
device = 'cuda'                # device used for the yes/no head computation
cpu_offload_gb = 0             # forwarded unchanged to vllm.LLM

# The yes/no classification head is stored as a separate file in the repo.
lm_head_filename = 'yes_no_lm_head.pt'
lm_head_path = hf_hub_download(repo_id=model_path, filename=lm_head_filename, local_dir='.')

# ---------------------------------------------------------------------------
# Alignment state (restored from output_path when it already exists)
# ---------------------------------------------------------------------------
zh_idx = 0            # index of the next unaligned Chinese line
vi_idx = 0            # index of the next unaligned Vietnamese line
max_extra_lines = 5   # search bound used by the alignment loop
align_list = []       # list of {'zh': ..., 'vi': ...} entries

if os.path.exists(output_path):
    with open(output_path, 'r', encoding='utf-8') as f:
        align_list = json.load(f)
    # Each entry stores its lines joined with '\n', so the number of consumed
    # lines is the newline count plus one. An empty string is counted as zero
    # lines (an entry may have nothing on one side).
    zh_idx = sum(entry['zh'].count('\n') + 1 for entry in align_list if entry['zh'])
    vi_idx = sum(entry['vi'].count('\n') + 1 for entry in align_list if entry['vi'])

# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
# NOTE: torch.load unpickles the file; only load heads from a trusted repo.
# map_location places the head on `device` directly. The previous
# `lm_head.to(device)` discarded its return value, which leaves a tensor on
# whatever device it was saved from.
lm_head = torch.load(lm_head_path, map_location=device)
llm = LLM(model=model_path, cpu_offload_gb=cpu_offload_gb, enforce_eager=True, task='embed', override_pooler_config=PoolerConfig(pooling_type='ALL'))

# ---------------------------------------------------------------------------
# Input text
# ---------------------------------------------------------------------------
with open(zh_text_path, 'r', encoding='utf-8') as f:
    zh_lines = [line.strip() for line in f]
with open(vi_text_path, 'r', encoding='utf-8') as f:
    vi_lines = [line.strip() for line in f]
def get_sim_score(src_text, tgt_text):
    """Score how well a Chinese passage and a Vietnamese passage correspond.

    The two passages are inserted into a fixed chat prompt that asks whether
    they match completely (answer Yes or No). The prompt is encoded with
    ``llm.encode`` and the final element of the returned per-prompt data is
    multiplied by ``lm_head``; the result is normalised with a softmax.

    Args:
        src_text: Chinese passage.
        tgt_text: Vietnamese passage.

    Returns:
        The first softmax probability as a float. Callers treat values above
        0.5 as a match, so index 0 is presumably the "Yes" row of
        ``lm_head`` -- confirm against how the head file was exported.
    """
    # The prompt text is part of the model's input contract; keep it verbatim.
    text = f'<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n下面的中文段落和越南语段落的内容是否完全对应,不存在缺漏?(回答Yes或No)\n\n中文:\n{src_text}\n\n越南语:{tgt_text}<|im_end|>\n<|im_start|>assistant\n'
    outputs = llm.encode(text)
    with torch.inference_mode():
        # Follow lm_head's own device and dtype rather than the global
        # `device`, so the matmul operands always live on the same device.
        hidden_states = outputs[0].outputs.data[-1].to(
            device=lm_head.device, dtype=lm_head.dtype
        )
        logits = torch.matmul(lm_head, hidden_states)
        result = F.softmax(logits, dim=0).tolist()
    return result[0]
def generate_pairs():
    """Yield ``(a, b)`` pairs of positive integers forever, small pairs first.

    The sequence starts with ``(1, 1)`` and then walks outward ring by ring,
    where ring ``k`` holds every pair whose larger component equals ``k``:

        (1, 1),
        (1, 2), (2, 1), (2, 2),
        (1, 3), (2, 3), (3, 1), (3, 2), (3, 3),
        ...

    Within a ring the pairs ``(1, k) .. (k-1, k)`` come first, followed by
    ``(k, 1) .. (k, k)``. Every pair is produced exactly once.

    Each ring is enumerated directly, so no record of already-produced pairs
    is needed and memory use stays constant however far the caller iterates.

    Yields:
        tuple[int, int]: the next pair in the order described above.
    """
    yield (1, 1)
    ring = 2
    while True:
        # Pairs whose second component is the new maximum.
        for first in range(1, ring):
            yield (first, ring)
        # Pairs whose first component is the new maximum (includes the corner).
        for second in range(1, ring + 1):
            yield (ring, second)
        ring += 1
def _save_alignments():
    """Write every alignment collected so far to ``output_path``."""
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(align_list, f, ensure_ascii=False, indent=0)


def _find_resync_skip(zh_pos, vi_pos):
    """Search a bounded window ahead of the current position for a match.

    For every pair of start offsets (except "no skip on either side"), every
    block size that keeps offset + size within ``max_extra_lines`` is scored.
    On the first block scoring above 0.5, the number of lines in front of
    that block is returned so the caller can emit them as one entry.

    Args:
        zh_pos: index of the next unaligned Chinese line.
        vi_pos: index of the next unaligned Vietnamese line.

    Returns:
        ``(zh_count, vi_count)``: lines to consume before the matching block.

    Raises:
        Exception: if no block inside the window scores above 0.5.
    """
    for zh_i, vi_i in generate_pairs():
        if zh_i == 1 and vi_i == 1:
            continue
        zh_start = zh_pos + zh_i - 1
        vi_start = vi_pos + vi_i - 1
        for zh_i_offset, vi_i_offset in generate_pairs():
            zh_over = zh_i + zh_i_offset > max_extra_lines
            vi_over = vi_i + vi_i_offset > max_extra_lines
            if zh_over and vi_over:
                # Every later block size is at least as large on both sides.
                break
            if zh_over or vi_over:
                continue
            zh_text = ''.join(zh_lines[zh_start:zh_start+zh_i_offset])
            vi_text = ' '.join(vi_lines[vi_start:vi_start+vi_i_offset])
            if get_sim_score(zh_text, vi_text) > 0.5:
                return zh_i - 1, vi_i - 1
        if zh_i > max_extra_lines or vi_i > max_extra_lines:
            raise Exception(f'Error! zh line No.{zh_pos+1} vi line No.{vi_pos+1}')


def _find_next_alignment(zh_pos, vi_pos):
    """Decide how many lines on each side form the next aligned entry.

    Candidate sizes come from ``generate_pairs``. For each candidate:

    1. If the first ``zh_i`` Chinese lines and the first ``vi_i`` Vietnamese
       lines score above 0.5 as a whole, that is the entry.
    2. Otherwise, if the last line of each candidate scores above 0.5 on its
       own, the lines before that pair form the entry (one side may be
       empty) and the matching pair is left for the next round.
    3. Once a candidate exceeds ``max_extra_lines`` on either side, the
       bounded window search in ``_find_resync_skip`` takes over.

    Returning from here, instead of breaking out of nested loops, makes sure
    the result of step 3 is actually used and not replaced by the next
    candidate.

    Args:
        zh_pos: index of the next unaligned Chinese line.
        vi_pos: index of the next unaligned Vietnamese line.

    Returns:
        ``(zh_count, vi_count)``: number of lines to consume from each side.

    Raises:
        Exception: propagated from ``_find_resync_skip`` when nothing matches.
    """
    for zh_i, vi_i in generate_pairs():
        zh_text = ''.join(zh_lines[zh_pos:zh_pos+zh_i])
        vi_text = ' '.join(vi_lines[vi_pos:vi_pos+vi_i])
        if get_sim_score(zh_text, vi_text) > 0.5:
            return zh_i, vi_i
        if zh_i == 1 and vi_i == 1:
            continue
        zh_last = zh_pos + zh_i - 1
        vi_last = vi_pos + vi_i - 1
        # Near the end of either file the candidate can reach past the last
        # line; skip the single-line comparison then instead of raising
        # IndexError.
        if zh_last < len(zh_lines) and vi_last < len(vi_lines):
            if get_sim_score(zh_lines[zh_last], vi_lines[vi_last]) > 0.5:
                return zh_i - 1, vi_i - 1
        if zh_i > max_extra_lines or vi_i > max_extra_lines:
            return _find_resync_skip(zh_pos, vi_pos)


# Main alignment loop. It stops as soon as either input is exhausted; lines
# remaining on the other side are not written to the output.
try:
    while zh_idx < len(zh_lines) and vi_idx < len(vi_lines):
        zh_i, vi_i = _find_next_alignment(zh_idx, vi_idx)
        zh_text = '\n'.join(zh_lines[zh_idx:zh_idx+zh_i])
        vi_text = '\n'.join(vi_lines[vi_idx:vi_idx+vi_i])
        align_list.append({'zh': zh_text, 'vi': vi_text})
        new_align = [list(range(zh_idx, zh_idx+zh_i)), list(range(vi_idx, vi_idx+vi_i))]
        print(new_align)
        if len(align_list) % save_interval == 0:
            _save_alignments()
        zh_idx += zh_i
        vi_idx += vi_i
finally:
    # Always persist progress: on normal completion, on the alignment error
    # raised above, and on any unexpected failure or interrupt, so the run
    # can be resumed from output_path.
    _save_alignments()
- Downloads last month
- 24
Inference Providers
NEW
This model is not currently available via any of the supported Inference Providers.
The model cannot be deployed to the HF Inference API:
The model has no library tag.