stmnk committed on
Commit
94677a0
·
1 Parent(s): 0aa1779

add default values for inference params

Browse files
Files changed (1) hide show
  1. app.py +14 -9
app.py CHANGED
@@ -160,20 +160,20 @@ def pygen_func(nl_code_intent):
160
  # return str(answer)
161
  # CT5_URL = "https://api-inference.huggingface.co/models/nielsr/codet5-small-code-summarization-ruby"
162
 
163
- def docgen_func(function_code, temp):
164
- t = float(temp)
165
  req_data = {
166
  "inputs": function_code,
167
  "parameters": {
168
- "min_length": 50, # (Default: None). Integer to define the minimum length in tokens of the output summary.
169
- "max_length": 500, # (Default: None). Integer to define the maximum length in tokens of the output summary.
170
- "top_k": 3, # (Default: None). Integer to define the top tokens considered within the sample operation to create new text.
171
- "top_p": 0.8, # (Default: None). Float to define the tokens that are within the `sample` operation of text generation.
172
  # Add tokens in the sample for more probable to least probable until the sum of the probabilities is greater than top_p.
173
  "temperature": t, # (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation.
174
  # 1 means regular sampling, 0 means top_k=1, 100.0 is getting closer to uniform probability.
175
- "repetition_penalty": 50.0, # (Default: None). Float (0.0-100.0). The more a token is used within generation
176
- # the more it is penalized to not be picked in successive generation passes.
177
  "max_time": 80, # (Default: None). Float (0-120.0). The amount of time in seconds that the query should take maximum.
178
  # Network can cause some overhead so it will be a soft limit.
179
  },
@@ -195,7 +195,12 @@ iface = gr.Interface(
195
  [
196
  # gr.inputs.Textbox(lines=7, label="Code Intent (NL)", default=task_code),
197
  gr.inputs.Textbox(lines=10, label="Enter Task + Code in Python (PL)", default=task_code),
198
- gr.inputs.Slider(0, 100, label="Temperature"),
 
 
 
 
 
199
  ],
200
  # gr.outputs.Textbox(label="Code Generated PL"))
201
  gr.outputs.Textbox(label="Docstring Generated (NL)"),
 
160
  # return str(answer)
161
  # CT5_URL = "https://api-inference.huggingface.co/models/nielsr/codet5-small-code-summarization-ruby"
162
 
163
+ def docgen_func(function_code, min_length, max_length, top_k, top_p, temp, repetition_penalty):
164
+ m, M, k, p, t, r = int(min_length), int(max_length), int(top_k), float(top_p/100), float(temp), float(repetition_penalty)
165
  req_data = {
166
  "inputs": function_code,
167
  "parameters": {
168
+ "min_length": m, # (Default: None). Integer to define the minimum length in tokens of the output summary.
169
+ "max_length": M, # (Default: None). Integer to define the maximum length in tokens of the output summary.
170
+ "top_k": k, # (Default: None). Integer to define the top tokens considered within the sample operation to create new text.
171
+ "top_p": p, # (Default: None). Float to define the tokens that are within the `sample` operation of text generation.
172
  # Add tokens in the sample for more probable to least probable until the sum of the probabilities is greater than top_p.
173
  "temperature": t, # (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation.
174
  # 1 means regular sampling, 0 means top_k=1, 100.0 is getting closer to uniform probability.
175
+ "repetition_penalty": r, # (Default: None). Float (0.0-100.0). The more a token is used within generation
176
+ # the more it is penalized to not be picked in successive generation passes.
177
  "max_time": 80, # (Default: None). Float (0-120.0). The amount of time in seconds that the query should take maximum.
178
  # Network can cause some overhead so it will be a soft limit.
179
  },
 
195
  [
196
  # gr.inputs.Textbox(lines=7, label="Code Intent (NL)", default=task_code),
197
  gr.inputs.Textbox(lines=10, label="Enter Task + Code in Python (PL)", default=task_code),
198
+ gr.inputs.Slider(30, 200, default=100, label="Minimum Length (of the output summary, in tokens)"),
199
+ gr.inputs.Slider(200, 500, default=350, label="Maximum Length (of the output summary, in tokens)"),
200
+ gr.inputs.Slider(1, 7, default=3, label="Top K (tokens considered within the sample operation to create new text)"),
201
+ gr.inputs.Slider(0, 100, default=80, label="Top P (probability threshold for next tokens in sample of new text, cumulative"),
202
+ gr.inputs.Slider(0, 100, default=1, label="Temperature (of the sampling operation)"),
203
+ gr.inputs.Slider(0, 100, default=70, label="Repetition Penalty (frequently previously used tokens are downsized)"),
204
  ],
205
  # gr.outputs.Textbox(label="Code Generated PL"))
206
  gr.outputs.Textbox(label="Docstring Generated (NL)"),