WN-zh-vi-sim-v0.3的GPTQ Int4量化版本

Bản lượng tử hóa (quant) GPTQ Int4 của WN-zh-vi-sim-v0.3

模型用于对齐中文文本和越南语文本

Mô hình dùng để căn chỉnh (align) văn bản tiếng Trung và tiếng Việt

import os
import json
import torch
import torch.nn.functional as F
from vllm import LLM
from vllm.config import PoolerConfig
from huggingface_hub import hf_hub_download

# --- Configuration ----------------------------------------------------------
model_path = 'CjangCjengh/WN-zh-vi-sim-v0.3-GPTQ-Int4'  # HF repo id of the model
zh_text_path = 'zh.txt'      # Chinese input; each line is one alignment unit
vi_text_path = 'vi.txt'      # Vietnamese input; each line is one alignment unit
output_path = 'output.json'  # alignment results; also read back to resume a run
save_interval = 50           # checkpoint the output every N appended alignments
device = 'cuda'
cpu_offload_gb = 0

# Separate projection head stored in the model repo; get_sim_score multiplies
# the model's final hidden state by it.
lm_head_filename = 'yes_no_lm_head.pt'
lm_head_path = hf_hub_download(repo_id=model_path, filename=lm_head_filename, local_dir='.')

zh_idx = 0  # index of the next unaligned Chinese line
vi_idx = 0  # index of the next unaligned Vietnamese line
max_extra_lines = 5  # block size beyond which the fallback search is used

# Resume from a previous run. Aligned blocks are stored with their lines joined
# by '\n', so the number of input lines already consumed is recovered by
# counting those separators.
# NOTE(review): an entry whose text is '' is counted as 0 lines, but such an
# entry can also come from one empty (stripped) input line; that would make
# the resume position drift. Confirm the inputs contain no blank lines.
align_list = []
if os.path.exists(output_path):
    with open(output_path, 'r', encoding='utf-8') as f:
        align_list = json.load(f)
    zh_idx = sum(i['zh'].count('\n') + 1 for i in align_list if i['zh'])
    vi_idx = sum(i['vi'].count('\n') + 1 for i in align_list if i['vi'])

# The head file comes from the network, so refuse arbitrary pickled objects.
# Tensor.to() is not in-place: the moved tensor must be assigned back,
# otherwise lm_head stays on its load device.
lm_head = torch.load(lm_head_path, weights_only=True).to(device)

# Embedding task with pooling_type='ALL' keeps per-token hidden states;
# get_sim_score reads the last one.
llm = LLM(model=model_path, cpu_offload_gb=cpu_offload_gb, enforce_eager=True, task='embed', override_pooler_config=PoolerConfig(pooling_type='ALL'))

# Read both corpora, stripping surrounding whitespace from every line; the
# `with` blocks make sure the files are closed.
with open(zh_text_path, 'r', encoding='utf-8') as f:
    zh_lines = [line.strip() for line in f]
with open(vi_text_path, 'r', encoding='utf-8') as f:
    vi_lines = [line.strip() for line in f]


def get_sim_score(src_text, tgt_text):
    """Score how well a Chinese text and a Vietnamese text correspond.

    The pair is embedded in a fixed chat prompt. The prompt asks (in Chinese)
    whether the Chinese paragraph and the Vietnamese paragraph correspond
    completely with nothing missing, to be answered Yes or No.

    The prompt is run through ``llm.encode``. The last row of the pooled
    output (presumably the final prompt token's hidden state) is cast to
    ``lm_head``'s dtype, moved to ``device`` and projected through
    ``lm_head``. The projection is then softmax-normalised.

    Returns:
        The first softmax entry as a Python float. This is presumably the
        "Yes" probability of yes_no_lm_head.pt — TODO confirm row order.
    """
    # NOTE(review): unlike the Chinese field, no newline follows '越南语:'.
    # Left untouched because the model may have been trained on this exact
    # template.
    prompt = f'<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n下面的中文段落和越南语段落的内容是否完全对应,不存在缺漏?(回答Yes或No)\n\n中文:\n{src_text}\n\n越南语:{tgt_text}<|im_end|>\n<|im_start|>assistant\n'
    pooled = llm.encode(prompt)[0].outputs.data
    with torch.inference_mode():
        last_hidden = pooled[-1].to(dtype=lm_head.dtype).to(device)
        probs = F.softmax(torch.matmul(lm_head, last_hidden), dim=0)
        return probs.tolist()[0]

def generate_pairs():
    """Yield candidate block sizes as 1-based ``(a, b)`` pairs, forever.

    Pairs are produced shell by shell, where shell ``e`` holds every pair
    whose larger component is ``e``. Within a shell, the order is
    ``(1, e), (2, e), ..., (e-1, e)`` followed by ``(e, 1), ..., (e, e)``.
    The sequence starts with (1, 1), (1, 2), (2, 1), (2, 2), (1, 3), ...

    Only the new shell is generated each round, so no "already seen" set is
    kept. The earlier implementation stored every emitted pair and rescanned
    the whole square each round, which cost growing memory and quadratic
    work for this infinite generator.
    """
    edge = 1
    while True:
        for a in range(1, edge):
            yield (a, edge)
        for b in range(1, edge + 1):
            yield (edge, b)
        edge += 1

# Greedy alignment. From the current position (zh_idx, vi_idx), candidate
# blocks of zh_i Chinese lines and vi_i Vietnamese lines are tried in the
# order produced by generate_pairs() until one is accepted. Each accepted
# block is appended to align_list and both positions advance past it.
while zh_idx < len(zh_lines) and vi_idx < len(vi_lines):
    zh_left = len(zh_lines) - zh_idx
    vi_left = len(vi_lines) - vi_idx
    for zh_i, vi_i in generate_pairs():
        # A block that runs past the end of either file has the same text as a
        # shorter candidate that was already scored. The last-line check below
        # would also index past the end. So for such blocks only the fallback
        # test further down applies.
        if zh_i <= zh_left and vi_i <= vi_left:
            zh_text = ''.join(zh_lines[zh_idx:zh_idx+zh_i])
            vi_text = ' '.join(vi_lines[vi_idx:vi_idx+vi_i])
            score = get_sim_score(zh_text, vi_text)
            if score > 0.5:
                break
            if zh_i == 1 and vi_i == 1:
                continue
            # If the last line of each candidate block matches the other, those
            # two lines start the next alignment. Everything before them
            # (possibly nothing on one side) forms the current block.
            score = get_sim_score(zh_lines[zh_idx+zh_i-1], vi_lines[vi_idx+vi_i-1])
            if score > 0.5:
                zh_i -= 1
                vi_i -= 1
                break
        if zh_i > max_extra_lines or vi_i > max_extra_lines:
            # Fallback: look for where the NEXT alignment starts, instead of
            # matching the current block directly. A start (zh_i-1, vi_i-1)
            # lines ahead is accepted when some short block beginning there
            # (extent within max_extra_lines on each side) is judged a match.
            # The skipped lines then form the current block.
            end_flag = False
            for zh_i, vi_i in generate_pairs():
                if zh_i == 1 and vi_i == 1:
                    continue
                zh_start = zh_idx+zh_i-1
                vi_start = vi_idx+vi_i-1
                for zh_i_offset, vi_i_offset in generate_pairs():
                    # Both sides too long: every later pair would be skipped too.
                    if zh_i+zh_i_offset > max_extra_lines and vi_i+vi_i_offset > max_extra_lines:
                        break
                    if zh_i+zh_i_offset > max_extra_lines or vi_i+vi_i_offset > max_extra_lines:
                        continue
                    # A block past the end of a file would be empty or would
                    # repeat a shorter block; scoring empty text could give a
                    # spurious match.
                    if zh_start+zh_i_offset > len(zh_lines) or vi_start+vi_i_offset > len(vi_lines):
                        continue
                    zh_text = ''.join(zh_lines[zh_start:zh_start+zh_i_offset])
                    vi_text = ' '.join(vi_lines[vi_start:vi_start+vi_i_offset])
                    score = get_sim_score(zh_text, vi_text)
                    if score > 0.5:
                        zh_i -= 1
                        vi_i -= 1
                        end_flag = True
                        break
                if end_flag:
                    break
                if zh_i > max_extra_lines or vi_i > max_extra_lines:
                    # Give up: checkpoint what has been aligned so far, then stop.
                    with open(output_path,'w',encoding='utf-8') as f:
                        json.dump(align_list, f, ensure_ascii=False, indent=0)
                    raise RuntimeError(f'Error! zh line No.{zh_idx+1} vi line No.{vi_idx+1}')
            # Without this break the outer loop would move on to the next
            # candidate, overwrite zh_i/vi_i and discard the fallback's match.
            if end_flag:
                break

    zh_text = '\n'.join(zh_lines[zh_idx:zh_idx+zh_i])
    vi_text = '\n'.join(vi_lines[vi_idx:vi_idx+vi_i])
    align_list.append({'zh':zh_text, 'vi':vi_text})

    # Report the aligned line indices (0-based) for progress tracking.
    new_align = [list(range(zh_idx, zh_idx+zh_i)), list(range(vi_idx, vi_idx+vi_i))]
    print(new_align)

    # Periodic checkpoint so an interrupted run can resume.
    if len(align_list) % save_interval == 0:
        with open(output_path,'w',encoding='utf-8') as f:
            json.dump(align_list, f, ensure_ascii=False, indent=0)

    zh_idx += zh_i
    vi_idx += vi_i

# The loop stops as soon as either file is exhausted; lines left over on the
# other side are not aligned, so say so instead of dropping them silently.
zh_rest = max(0, len(zh_lines) - zh_idx)
vi_rest = max(0, len(vi_lines) - vi_idx)
if zh_rest or vi_rest:
    print(f'Warning: {zh_rest} zh line(s) and {vi_rest} vi line(s) left unaligned')

with open(output_path,'w',encoding='utf-8') as f:
    json.dump(align_list, f, ensure_ascii=False, indent=0)
Downloads last month: 24

Safetensors · Model size: 3.33B params · Tensor type: I32 · FP16
Inference Providers NEW
This model is not currently available via any of the supported Inference Providers.
The model cannot be deployed to the HF Inference API: The model has no library tag.