import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
from peft import PeftModel
import json
import os
# Load tokenizer and base model
base_model = "Qwen/Qwen2-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    device_map="cuda" if torch.cuda.is_available() else "cpu",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True,
)
# Clean up adapter_config.json before loading the adapter: strip config keys
# introduced by newer PEFT releases, which older `peft` versions reject with a
# TypeError when parsing the saved config
adapter_config_path = "./adapter/adapter_config.json"
if os.path.exists(adapter_config_path):
    with open(adapter_config_path, "r") as f:
        adapter_config = json.load(f)
    for key in ["corda_config", "eva_config", "megatron_config"]:
        adapter_config.pop(key, None)
    with open(adapter_config_path, "w") as f:
        json.dump(adapter_config, f)
# Load adapter
model = PeftModel.from_pretrained(model, "./adapter", is_trainable=False)
model.eval()
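# Optional (a sketch, not required): for inference-only use, the LoRA weights
# can be folded into the base model with PEFT's merge_and_unload(), removing
# the adapter indirection for slightly faster forward passes:
# model = model.merge_and_unload()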
# Simple chat function
def chat(prompt):
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        streamer=streamer,
    )
    # Slice off the prompt tokens so only the newly generated reply is returned
    output = tokenizer.decode(generated_ids[0][model_inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
    return output
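# Programmatic use (a hypothetical prompt; the streamer still prints tokens
# to stdout as they are generated):
# reply = chat("Summarize what a LoRA adapter does in one sentence.")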
# Example: interactive chat loop. The TextStreamer inside chat() already
# prints the reply as it is generated, so the return value is not printed
# again here (the original printed the response twice).
if __name__ == "__main__":
    while True:
        prompt = input("User: ")
        if prompt.lower() in ["exit", "quit"]:
            break
        print("AI: ", end="", flush=True)
        chat(prompt)