ejschwartz commited on
Commit
2b63412
·
1 Parent(s): 2b3768e

handle no fields

Browse files
Files changed (1) hide show
  1. app.py +23 -21
app.py CHANGED
@@ -94,29 +94,31 @@ def infer(code):
94
  )
95
 
96
  field_prompt_result, fields, field_helper_result = field_prompt(code)
97
- field_input_ids = tokenizer.encode(field_prompt_result, return_tensors="pt").cuda()[
98
- :, : 8192 - 1024
99
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
- field_output = fielddecoder_model.generate(
102
- input_ids=field_input_ids,
103
- max_new_tokens=1024,
104
- num_beams=4,
105
- num_return_sequences=1,
106
- do_sample=False,
107
- early_stopping=False,
108
- pad_token_id=0,
109
- eos_token_id=0,
110
- )[0]
111
- field_output = tokenizer.decode(
112
- field_output[var_input_ids.size(1) :],
113
- skip_special_tokens=True,
114
- clean_up_tokenization_spaces=True,
115
- )
116
-
117
- var_output = first_var + ":" + var_output
118
- if len(fields) > 0:
119
  field_output = fields[0] + ":" + field_output
 
120
  return var_output, field_output, varstring
121
 
122
 
 
94
  )
95
 
96
  field_prompt_result, fields, field_helper_result = field_prompt(code)
97
+ if len(fields) == 0:
98
+ field_output = "No fields"
99
+ else:
100
+ field_input_ids = tokenizer.encode(field_prompt_result, return_tensors="pt").cuda()[
101
+ :, : 8192 - 1024
102
+ ]
103
+
104
+ field_output = fielddecoder_model.generate(
105
+ input_ids=field_input_ids,
106
+ max_new_tokens=1024,
107
+ num_beams=4,
108
+ num_return_sequences=1,
109
+ do_sample=False,
110
+ early_stopping=False,
111
+ pad_token_id=0,
112
+ eos_token_id=0,
113
+ )[0]
114
+ field_output = tokenizer.decode(
115
+ field_output[var_input_ids.size(1) :],
116
+ skip_special_tokens=True,
117
+ clean_up_tokenization_spaces=True,
118
+ )
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  field_output = fields[0] + ":" + field_output
121
+ var_output = first_var + ":" + var_output
122
  return var_output, field_output, varstring
123
 
124