Zaid committed on
Commit
f073a1c
·
verified ·
1 Parent(s): 1f2b852

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -8,7 +8,7 @@ import spaces
8
  import torch
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
10
 
11
- MAX_MAX_NEW_TOKENS = 2048
12
  DEFAULT_MAX_NEW_TOKENS = 1024
13
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
14
  model = None
@@ -19,7 +19,8 @@ def load_model():
19
  model_id = "stabilityai/ar-stablelm-2-chat"
20
  model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", trust_remote_code=True)
21
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
22
- tokenizer.use_default_system_prompt = False
 
23
 
24
 
25
  def generate(
@@ -49,8 +50,11 @@ def generate(
49
  {"input_ids": input_ids},
50
  streamer=streamer,
51
  max_new_tokens=max_new_tokens,
52
- do_sample=False,
 
53
  temperature=temperature,
 
 
54
  )
55
  t = Thread(target=model.generate, kwargs=generate_kwargs)
56
  t.start()
@@ -115,7 +119,7 @@ with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
115
  try:
116
  login(token = token)
117
  load_model()
118
- return f"Authenticated successfully"
119
  except:
120
  return "Invalid token. Please try again."
121
 
@@ -129,4 +133,4 @@ with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
129
 
130
 
131
  if __name__ == "__main__":
132
- demo.queue(max_size=20).launch()
 
8
  import torch
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
10
 
11
+ MAX_MAX_NEW_TOKENS = 128
12
  DEFAULT_MAX_NEW_TOKENS = 1024
13
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
14
  model = None
 
19
  model_id = "stabilityai/ar-stablelm-2-chat"
20
  model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", trust_remote_code=True)
21
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
22
+ model.generation_config.pad_token_id = model.generation_config.eos_token_id
23
+
24
 
25
 
26
  def generate(
 
50
  {"input_ids": input_ids},
51
  streamer=streamer,
52
  max_new_tokens=max_new_tokens,
53
+ do_sample=True,
54
+ eos_token_id=tokenizer.eos_token_id, # Stop generation at <EOS>
55
  temperature=temperature,
56
+ top_p=top_p,
57
+ top_k=top_k
58
  )
59
  t = Thread(target=model.generate, kwargs=generate_kwargs)
60
  t.start()
 
119
  try:
120
  login(token = token)
121
  load_model()
122
+ return "Authenticated successfully"
123
  except:
124
  return "Invalid token. Please try again."
125
 
 
133
 
134
 
135
  if __name__ == "__main__":
136
+ demo.queue(max_size=20).launch()