metadata
frameworks:
- Pytorch
license: other
tasks:
- text-generation
Model Card for NL2SQL-StarCoder-15B
Model Introduction
NL2SQL-StarCoder-15B is a 15B Code-LLM fine-tuned from the base model StarCoder with QLoRA on the natural-language-to-SQL (NL2SQL) task.
Requirements
- python>=3.8
- pytorch>=2.0.0
- transformers==4.32.0
- CUDA 11.4
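Inference Data Format
Inference input is a single string concatenated in the same format as the training data; the prompt at inference time is assembled the same way: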
"""
<|user|>
/* Given the following database schema: */
CREATE TABLE "table_name" (
"col1" int,
...
...
)
/* Write a sql to answer the following question: {question} */
<|assistant|>
```sql
{output SQL}
```<|end|>
"""
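As a rough illustration of the template above, the following sketch renders one example string from a schema, a question, and a SQL answer. The names `build_user_turn` and `build_example` are hypothetical helpers, not part of the model's repository, and the exact whitespace between segments is an assumption read off the template. The example prompt in the inference code uses slightly different separators, so check against that string before relying on this layout.

```python
FENCE = "`" * 3  # Markdown code fence, built at runtime to avoid nesting fences here


def build_user_turn(schema_ddl: str, question: str) -> str:
    """Render the <|user|> segment: schema DDL followed by the question."""
    return (
        "<|user|>\n"
        "/* Given the following database schema: */\n"
        f"{schema_ddl.strip()}\n"
        f"/* Write a sql to answer the following question: {question.strip()} */\n"
    )


def build_example(schema_ddl: str, question: str, sql: str) -> str:
    """Render a full example: user turn plus the fenced SQL answer (assumed layout)."""
    return (
        build_user_turn(schema_ddl, question)
        + "<|assistant|>\n"
        + f"{FENCE}sql\n{sql.strip()}\n{FENCE}<|end|>"
    )


# Illustrative inputs only; the schema and question here are made up.
schema = 'CREATE TABLE "singer" (\n"Singer_ID" int,\n"Name" text\n)'
example = build_example(schema, "How many singers are there?", "SELECT count(*) FROM singer")
print(example)
```

Keeping the user turn in its own function makes it possible to reuse the same rendering for training data and for inference prompts.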
Quick Start
import torch
from modelscope import snapshot_download, AutoModelForCausalLM, AutoTokenizer
model_dir = snapshot_download("iic/NL2SQL-StarCoder-15B")
# The tokenizer holds no weights, so device_map and torch_dtype are not needed here.
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids("<fim_pad>")
tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
tokenizer.pad_token = "<fim_pad>"
tokenizer.eos_token = "<|endoftext|>"
model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto",
                                             trust_remote_code=True, torch_dtype=torch.float16)
model.eval()
text = '<|user|>\n/* Given the following database schema: */\nCREATE TABLE "singer" (\n"Singer_ID" int,\n"Name" text,\n"Country" text,\n"Song_Name" text,\n"Song_release_year" text,\n"Age" int,\n"Is_male" bool,\nPRIMARY KEY ("Singer_ID")\n)\n\n/* Write a sql to answer the following question: Show countries where a singer above age 40 and a singer below 30 are from. */<|end|>\n'
# Move the inputs to the device that holds the model's first parameters.
inputs = tokenizer(text, return_tensors='pt', padding=True, add_special_tokens=False).to(model.device)
outputs = model.generate(
    inputs=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=512,
    # Greedy decoding: with do_sample=False, top_p and temperature have no effect, so they are omitted.
    do_sample=False,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id
)
gen_text = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(gen_text)
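The decoded text may still carry the Markdown fence shown in the data format. A minimal post-processing sketch follows, assuming the model wraps its answer in a `sql` fence as in the template; `extract_sql` is a hypothetical helper, not part of the model's repository, and it falls back to the raw text when no fence is found.

```python
import re

# Matches an opening fence (optionally tagged "sql"), then captures everything
# up to the closing fence or, if generation stopped early, the end of the text.
_SQL_BLOCK = re.compile(r"```(?:sql)?\s*(.*?)(?:```|$)", re.DOTALL)


def extract_sql(generated: str) -> str:
    """Return the SQL inside the first fenced block, or the stripped text if unfenced."""
    match = _SQL_BLOCK.search(generated)
    return (match.group(1) if match else generated).strip()
```

With the code above, `extract_sql(gen_text[0])` would give the query for the single prompt in the batch.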