Zaid committed
Commit 2fe7e62 · verified · 1 parent: 592570e

Update app.py

Files changed (1): app.py (+4 -8)
app.py CHANGED
@@ -31,7 +31,7 @@ def generate(
     message: str,
     chat_history: list[dict],
     system_prompt: str = "",
-    max_new_tokens: int = 1024,
+    max_new_tokens: int = 128,
     temperature: float = 0.6,
     top_p: float = 0.9,
     top_k: int = 50,
@@ -43,7 +43,7 @@ def generate(
     conversation += chat_history
     conversation.append({"role": "user", "content": message})
 
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
@@ -54,12 +54,8 @@ def generate(
         {"input_ids": input_ids},
         streamer=streamer,
         max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
+        do_sample=False,
         temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
@@ -86,7 +82,7 @@ chat_interface = gr.ChatInterface(
             minimum=0.1,
             maximum=4.0,
             step=0.1,
-            value=0.6,
+            value=0.7,
         ),
         gr.Slider(
             label="Top-p (nucleus sampling)",