Sansa commited on
Commit
fb9f596
·
verified ·
1 Parent(s): fd7c2dc

Update README.md

Browse files

add inference example

Files changed (1) hide show
  1. README.md +44 -0
README.md CHANGED
@@ -24,5 +24,49 @@ Training recipe:
24
 
25
  - GitHub: https://github.com/apple/ml-diffucoder
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  #### Acknowledgement
28
  To power this HuggingFace model release, we reuse [Dream](https://huggingface.co/Dream-org/Dream-v0-Base-7B)'s modeling architecture and generation utils.
 
24
 
25
  - GitHub: https://github.com/apple/ml-diffucoder
26
 
27
+ ```
28
+ import torch
29
+ from transformers import AutoModel, AutoTokenizer
30
+
31
+ model_path = "apple/DiffuCoder-7B-cpGRPO"
32
+ model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True)
33
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
34
+ model = model.to("cuda").eval()
35
+
36
+ query = "Write a function to find the shared elements from the given two lists."
37
+ prompt = f"""<|im_start|>system
38
+ You are a helpful assistant.<|im_end|>
39
+ <|im_start|>user
40
+ {query.strip()}
41
+ <|im_end|>
42
+ <|im_start|>assistant
43
+ """ ## following the template of qwen; you can also use apply_chat_template function
44
+
45
+ TOKEN_PER_STEP = 1 # diffusion timesteps * TOKEN_PER_STEP = total new tokens
46
+
47
+ inputs = tokenizer(prompt, return_tensors="pt")
48
+ input_ids = inputs.input_ids.to(device="cuda")
49
+ attention_mask = inputs.attention_mask.to(device="cuda")
50
+
51
+ output = model.diffusion_generate(
52
+ input_ids,
53
+ attention_mask=attention_mask,
54
+ max_new_tokens=256,
55
+ output_history=True,
56
+ return_dict_in_generate=True,
57
+ steps=256//TOKEN_PER_STEP,
58
+ temperature=0.4,
59
+ top_p=0.95,
60
+ alg="entropy",
61
+ alg_temp=0.,
62
+ )
63
+ generations = [
64
+ tokenizer.decode(g[len(p) :].tolist())
65
+ for p, g in zip(input_ids, output.sequences)
66
+ ]
67
+
68
+ print(generations[0].split('<|dlm_pad|>')[0])
69
+ ```
70
+
71
  #### Acknowledgement
72
  To power this HuggingFace model release, we reuse [Dream](https://huggingface.co/Dream-org/Dream-v0-Base-7B)'s modeling architecture and generation utils.