File size: 922 Bytes
370f602
252f973
2a56ef6
252f973
81f9daa
 
252f973
81f9daa
 
 
 
 
2a56ef6
81f9daa
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21

from transformers import pipeline

# Quick path: the high-level text-generation pipeline.
# torch is imported here because the module-level `import torch` appears
# further down the file, after this section runs.
import torch

# Load in bfloat16 with automatic device placement (same settings as the
# manual AutoModelForCausalLM example) instead of the float32 default,
# which roughly halves the memory needed for an 8B-parameter model.
pipe = pipeline(
    "text-generation",
    model="THUDM/LongWriter-llama3.1-8b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Wrap the query in the same "[INST] ... [/INST]" template the manual example
# uses. Pass an explicit generation budget: without max_new_tokens the
# pipeline falls back to the model's generation config, which may stop far
# short of 10,000 words (NOTE(review): confirm the checkpoint's default).
# return_full_text=False keeps the echoed prompt out of the output.
result = pipe(
    "[INST] Write a 10000-word China travel guide [/INST]",
    max_new_tokens=32768,
    do_sample=True,
    temperature=0.5,
    return_full_text=False,
)
# The pipeline returns a list with one dict per generated sequence.
print(result[0]["generated_text"])
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Manual path: load the tokenizer and model explicitly for full control over
# generation. bfloat16 + device_map="auto" lets the weights be placed across
# whatever devices are available.
tokenizer = AutoTokenizer.from_pretrained("THUDM/LongWriter-llama3.1-8b", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "THUDM/LongWriter-llama3.1-8b",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)
model = model.eval()  # put all modules in eval mode for inference

query = "Write a 10000-word China travel guide"
prompt = f"[INST] {query} [/INST]"
# `device` was never defined in this script (NameError). With
# device_map="auto" there is no single fixed device, so send the input ids to
# model.device, the device reported for the model's first parameters.
# `inputs` (not `input`) avoids shadowing the builtin.
inputs = tokenizer(prompt, truncation=False, return_tensors="pt").to(model.device)
context_length = inputs.input_ids.shape[-1]
output = model.generate(
    **inputs,
    max_new_tokens=32768,
    num_beams=1,
    do_sample=True,
    temperature=0.5,
)[0]
# generate() returns prompt + continuation; slice off the prompt tokens so
# only the newly generated text is decoded.
response = tokenizer.decode(output[context_length:], skip_special_tokens=True)
print(response)