iamrahu committed
Commit 81f9daa · verified · 1 Parent(s): 252f973

Update app.py

Files changed (1)
  app.py  +15 -5
app.py CHANGED
@@ -1,10 +1,20 @@
-# Use a pipeline as a high-level helper
+# pip install transformers
 from transformers import pipeline
 
 pipe = pipeline("text-generation", model="THUDM/LongWriter-llama3.1-8b")
-
-# Load model directly
+result = pipe("Write a 10000-word China travel guide")
+print(result)
 from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+
+tokenizer = AutoTokenizer.from_pretrained("THUDM/LongWriter-llama3.1-8b", trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained("THUDM/LongWriter-llama3.1-8b", torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
+model = model.eval()
 
-tokenizer = AutoTokenizer.from_pretrained("THUDM/LongWriter-llama3.1-8b")
-model = AutoModelForCausalLM.from_pretrained("THUDM/LongWriter-llama3.1-8b")
+query = "Write a 10000-word China travel guide"
+prompt = f"[INST] {query} [/INST]"
+input = tokenizer(prompt, truncation=False, return_tensors="pt").to(model.device)
+context_length = input.input_ids.shape[-1]
+output = model.generate(**input, max_new_tokens=32768, num_beams=1, do_sample=True, temperature=0.5)[0]
+response = tokenizer.decode(output[context_length:], skip_special_tokens=True)
+print(response)
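
Note: the updated app.py loads the THUDM/LongWriter-llama3.1-8b checkpoint twice, once through pipeline() and once through AutoModelForCausalLM. A minimal sketch (not part of this commit) of letting the pipeline reuse the already loaded objects, assuming model and tokenizer are the variables created above:

from transformers import pipeline

# Reuse the already loaded model and tokenizer so the 8B checkpoint is not
# loaded a second time; generation kwargs are forwarded to model.generate().
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

query = "Write a 10000-word China travel guide"
result = pipe(f"[INST] {query} [/INST]",
              max_new_tokens=32768, do_sample=True, temperature=0.5,
              return_full_text=False)
print(result[0]["generated_text"])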