schuler committed (verified)
Commit 9110504 · Parent(s): c6c83ca

Update README.md

Files changed (1)
  1. README.md +7 -3
README.md CHANGED
@@ -36,6 +36,10 @@ The following table shows LaMini training results with the baseline and the opti
 
 ## Usage:
 ```
+!pip install -q -U transformers
+!pip install -q -U accelerate
+!pip install -q -U flash-attn --no-build-isolation
+
 from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, pipeline
 from transformers import LlamaTokenizer
 import torch
@@ -45,8 +49,8 @@ REPO_NAME = 'schuler/experimental-JP47D55C'
 def load_model(local_repo_name):
   tokenizer = LlamaTokenizer.from_pretrained(local_repo_name, trust_remote_code=True)
   generator_conf = GenerationConfig.from_pretrained(local_repo_name)
-  model = AutoModelForCausalLM.from_pretrained(local_repo_name, trust_remote_code=True, torch_dtype=torch.bfloat16, attn_implementation="eager")
-  # model.to('cuda')
+  model = AutoModelForCausalLM.from_pretrained(local_repo_name, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch.float16)
+  model.to('cuda')
   return tokenizer, generator_conf, model
 
 tokenizer, generator_conf, model = load_model(REPO_NAME)
@@ -57,7 +61,7 @@ except Exception as e:
   global_error = f"Failed to load model: {str(e)}"
 
 def PrintTest(str):
-  print(generator(str, max_new_tokens=256, do_sample=True, top_p=0.25, repetition_penalty=1.2))
+  print(generator(str, max_new_tokens=256, do_sample=True, top_p=0.5, repetition_penalty=1.2))
 
 PrintTest(f"<|user|>\nHello\n<|end|>\n<|assistant|>\n")
 PrintTest(f"<|user|>Hello\n<|end|><|assistant|>")
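For readers trying the updated snippet: the hunks above elide the lines that wrap `load_model` in a try/except and create the `generator` pipeline, so the minimal end-to-end sketch below fills those in as assumptions. The `pipeline("text-generation", ...)` call and the try/except wiring are inferred from the `pipeline` import, the `generator` name, and the `except Exception as e:` hunk context; they are not part of this commit.

```
# Post-commit usage sketch. The try/except wiring and the pipeline
# construction are assumptions inferred from the diff context.
from transformers import AutoModelForCausalLM, GenerationConfig, pipeline
from transformers import LlamaTokenizer
import torch

REPO_NAME = 'schuler/experimental-JP47D55C'

def load_model(local_repo_name):
    tokenizer = LlamaTokenizer.from_pretrained(local_repo_name, trust_remote_code=True)
    # Loaded for completeness, mirroring the README; not used below.
    generator_conf = GenerationConfig.from_pretrained(local_repo_name)
    # flash_attention_2 needs the flash-attn package, a supported CUDA GPU,
    # and a half-precision dtype (fp16 here; bf16 before this commit).
    model = AutoModelForCausalLM.from_pretrained(
        local_repo_name,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.float16,
    )
    model.to('cuda')
    return tokenizer, generator_conf, model

global_error = None
try:
    tokenizer, generator_conf, model = load_model(REPO_NAME)
    # Assumed: the elided lines build a text-generation pipeline named `generator`.
    generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
except Exception as e:
    global_error = f"Failed to load model: {str(e)}"

def PrintTest(prompt):  # parameter renamed from `str` to avoid shadowing the builtin
    print(generator(prompt, max_new_tokens=256, do_sample=True, top_p=0.5, repetition_penalty=1.2))

PrintTest(f"<|user|>\nHello\n<|end|>\n<|assistant|>\n")
PrintTest(f"<|user|>Hello\n<|end|><|assistant|>")
```

Net effect of the commit: the loader moves from eager attention in bf16 (with the `model.to('cuda')` call commented out, so it could run on CPU) to FlashAttention-2 in fp16 pinned to CUDA, which is why the `flash-attn` install line is added, and sampling widens from `top_p=0.25` to `top_p=0.5`.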