karimouda committed on
Commit
002f3bc
·
verified ·
1 Parent(s): cf713c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -3
app.py CHANGED
@@ -36,8 +36,11 @@ model.eval()
36
  def generate(
37
  message: str,
38
  chat_history: list[dict],
39
- max_new_tokens: int = 1048,
40
- temperature: float = 0.01,
 
 
 
41
  ) -> Iterator[str]:
42
  conversation = chat_history.copy()
43
  conversation.append({"role": "user", "content": message})
@@ -54,8 +57,11 @@ def generate(
54
  streamer=streamer,
55
  max_new_tokens=max_new_tokens,
56
  do_sample=True,
 
 
57
  temperature=temperature,
58
  num_beams=1,
 
59
  )
60
  t = Thread(target=model.generate, kwargs=generate_kwargs)
61
  t.start()
@@ -69,12 +75,40 @@ def generate(
69
  demo = gr.ChatInterface(
70
  fn=generate,
71
  additional_inputs=[
 
 
 
 
 
 
 
72
  gr.Slider(
73
  label="Temperature",
74
  minimum=0.1,
75
  maximum=4.0,
76
  step=0.1,
77
- value=0.01,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  ),
79
  ],
80
  stop_btn=None,
 
36
  def generate(
37
  message: str,
38
  chat_history: list[dict],
39
+ max_new_tokens: int = 1024,
40
+ temperature: float = 0.6,
41
+ top_p: float = 0.9,
42
+ top_k: int = 50,
43
+ repetition_penalty: float = 1.2,
44
  ) -> Iterator[str]:
45
  conversation = chat_history.copy()
46
  conversation.append({"role": "user", "content": message})
 
57
  streamer=streamer,
58
  max_new_tokens=max_new_tokens,
59
  do_sample=True,
60
+ top_p=top_p,
61
+ top_k=top_k,
62
  temperature=temperature,
63
  num_beams=1,
64
+ repetition_penalty=repetition_penalty
65
  )
66
  t = Thread(target=model.generate, kwargs=generate_kwargs)
67
  t.start()
 
75
  demo = gr.ChatInterface(
76
  fn=generate,
77
  additional_inputs=[
78
+ gr.Slider(
79
+ label="Max new tokens",
80
+ minimum=1,
81
+ maximum=MAX_MAX_NEW_TOKENS,
82
+ step=1,
83
+ value=DEFAULT_MAX_NEW_TOKENS,
84
+ ),
85
  gr.Slider(
86
  label="Temperature",
87
  minimum=0.1,
88
  maximum=4.0,
89
  step=0.1,
90
+ value=0.6,
91
+ ),
92
+ gr.Slider(
93
+ label="Top-p (nucleus sampling)",
94
+ minimum=0.05,
95
+ maximum=1.0,
96
+ step=0.05,
97
+ value=0.9,
98
+ ),
99
+ gr.Slider(
100
+ label="Top-k",
101
+ minimum=1,
102
+ maximum=1000,
103
+ step=1,
104
+ value=50,
105
+ ),
106
+ gr.Slider(
107
+ label="Repetition penalty",
108
+ minimum=1.0,
109
+ maximum=2.0,
110
+ step=0.05,
111
+ value=1.2,
112
  ),
113
  ],
114
  stop_btn=None,