teaevo committed on
Commit
a192269
·
1 Parent(s): 0deb7d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -44,15 +44,14 @@ def predict(input, history=[]):
44
 
45
  if is_question:
46
  sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
47
- sql_outputs = sql_model.generate(**sql_encoding)
48
- response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
49
-
50
- '''
51
  bot_input_ids = torch.cat([torch.LongTensor(history), sql_encoding], dim=-1)
52
  history = sql_model.generate(bot_input_ids, max_length=1000, pad_token_id=sql_tokenizer.eos_token_id).tolist()
53
  response = sql_tokenizer.decode(history[0]).split("<|endoftext|>")
54
  response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
55
- '''
56
  else:
57
  # tokenize the new input sentence
58
  new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
 
44
 
45
  if is_question:
46
  sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
47
+ #sql_outputs = sql_model.generate(**sql_encoding)
48
+ #response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
49
+
 
50
  bot_input_ids = torch.cat([torch.LongTensor(history), sql_encoding], dim=-1)
51
  history = sql_model.generate(bot_input_ids, max_length=1000, pad_token_id=sql_tokenizer.eos_token_id).tolist()
52
  response = sql_tokenizer.decode(history[0]).split("<|endoftext|>")
53
  response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
54
+
55
  else:
56
  # tokenize the new input sentence
57
  new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')