qlora-sqlcoder / model_loader.py
Miguel0918's picture
Create model_loader.py
5a17665 verified
raw
history blame
1.34 kB
# model_loader.py
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
def load_model(base_model="defog/sqlcoder-7b-2", adapter_path="./"):
    """Load the 4-bit quantized base model with the LoRA adapters applied.

    Args:
        base_model: Hub id (or local path) of the base causal-LM checkpoint.
            Defaults to ``"defog/sqlcoder-7b-2"``.
        adapter_path: Location of the LoRA adapter files and of the tokenizer
            files. Defaults to ``"./"``, i.e. the adapter files are assumed to
            live in the root directory of the current repository.

    Returns:
        A ``(model, tokenizer)`` tuple, where ``model`` is the base model
        wrapped by ``PeftModel.from_pretrained`` with the adapters loaded.
    """
    # The tokenizer is read from the adapter directory, not from base_model.
    tokenizer = AutoTokenizer.from_pretrained(adapter_path)
    # Reuse EOS as the padding token so padding is always defined.
    # NOTE(review): presumably the tokenizer ships without a pad token — confirm.
    tokenizer.pad_token = tokenizer.eos_token

    # Request 4-bit quantization through BitsAndBytesConfig; passing
    # load_in_4bit straight to from_pretrained is deprecated in transformers.
    # NOTE(review): bnb_4bit_compute_dtype is left at the library default, as
    # in the original call; set it explicitly if fp16 compute is intended.
    quantization_config = BitsAndBytesConfig(load_in_4bit=True)
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        device_map="auto",
        quantization_config=quantization_config,
        torch_dtype=torch.float16,
    )
    # Keep the model config in sync with the tokenizer's padding choice.
    model.config.pad_token_id = tokenizer.pad_token_id

    # Wrap the base model with the LoRA adapters stored at adapter_path.
    model = PeftModel.from_pretrained(model, adapter_path)
    return model, tokenizer
if __name__ == "__main__":
    # Smoke test: load the adapted model and generate SQL for one prompt.
    model, tokenizer = load_model()
    prompt = "portfolio_transaction_headers(...) JOIN portfolio_transaction_details(...): Find transactions for portfolio 72 involving LTC"
    # device_map="auto" decides where the model lives, so send the inputs to
    # the device the model reports instead of hard-coding "cuda".
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=128)
    # For a causal LM, outputs[0] typically holds the prompt tokens followed
    # by the newly generated tokens.
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))