teaevo commited on
Commit
0279b82
·
1 Parent(s): a192269

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -44,14 +44,16 @@ def predict(input, history=[]):
44
 
45
  if is_question:
46
  sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
47
- #sql_outputs = sql_model.generate(**sql_encoding)
48
- #response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
49
-
 
 
50
  bot_input_ids = torch.cat([torch.LongTensor(history), sql_encoding], dim=-1)
51
  history = sql_model.generate(bot_input_ids, max_length=1000, pad_token_id=sql_tokenizer.eos_token_id).tolist()
52
  response = sql_tokenizer.decode(history[0]).split("<|endoftext|>")
53
  response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
54
-
55
  else:
56
  # tokenize the new input sentence
57
  new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
 
44
 
45
  if is_question:
46
  sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
47
+ sql_outputs = sql_model.generate(**sql_encoding)
48
+ response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
49
+
50
+ history.append(response)
51
+ '''
52
  bot_input_ids = torch.cat([torch.LongTensor(history), sql_encoding], dim=-1)
53
  history = sql_model.generate(bot_input_ids, max_length=1000, pad_token_id=sql_tokenizer.eos_token_id).tolist()
54
  response = sql_tokenizer.decode(history[0]).split("<|endoftext|>")
55
  response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
56
+ '''
57
  else:
58
  # tokenize the new input sentence
59
  new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')