teaevo committed on
Commit 8b84431 · 1 Parent(s): 56352a2

Update app.py

Files changed (1)
  1. app.py +22 -15
app.py CHANGED
@@ -41,24 +41,31 @@ def predict(input, history=[]):
     global sql_response
     # Check if the user input is a question
     is_question = "?" in input
-
-    # tokenize the new input sentence
-    new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
-
-    # append the new user input tokens to the chat history
-    bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
 
     if is_question:
-        sql_encoding = sql_tokenizer(table=table, query=input, return_tensors="pt")
-        sql_outputs = sql_model.generate(**sql_encoding)
-        sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
+        sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
+        #sql_outputs = sql_model.generate(**sql_encoding)
+        #sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
+
+        bot_input_ids = torch.cat([torch.LongTensor(history), sql_encoding], dim=-1)
+        history = sql_model.generate(bot_input_ids, max_length=1000, pad_token_id=sql_tokenizer.eos_token_id).tolist()
+        response = sql_tokenizer.decode(history[0]).split("<|endoftext|>")
+        response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
+
+    else:
+        # tokenize the new input sentence
+        new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
+
+        # append the new user input tokens to the chat history
+        bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
+
+        # generate a response
+        history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
+
+        # convert the tokens to text, and then split the responses into the right format
+        response = tokenizer.decode(history[0]).split("<|endoftext|>")
+        response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)] # convert to tuples of list
 
-    # generate a response
-    history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
-
-    # convert the tokens to text, and then split the responses into the right format
-    response = tokenizer.decode(history[0]).split("<|endoftext|>")
-    response = sql_response + " " + [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)] # convert to tuples of list
     return response, history
 
 
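Editor's note on the committed branch: sql_encoding is the mapping returned by the tokenizer (a BatchEncoding), not a tensor, so torch.cat([torch.LongTensor(history), sql_encoding], dim=-1) would raise a TypeError on the first table question; the commented-out sql_model.generate(**sql_encoding) form is the standard table-QA call. Below is a minimal sketch of how predict could be wired up end to end with that call restored. It assumes DialoGPT checkpoints for model/tokenizer and a TAPEX checkpoint for sql_model/sql_tokenizer (the table=/query= signature matches TapexTokenizer in transformers); the checkpoint names and the placeholder table are illustrative and do not appear in this diff.

import torch
import pandas as pd
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BartForConditionalGeneration,
    TapexTokenizer,
)

# Assumed checkpoints; the diff does not name the models it loads.
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
sql_tokenizer = TapexTokenizer.from_pretrained("microsoft/tapex-base-finetuned-wtq")
sql_model = BartForConditionalGeneration.from_pretrained("microsoft/tapex-base-finetuned-wtq")

# Placeholder table; TapexTokenizer expects a pandas DataFrame of strings.
table = pd.DataFrame({"year": ["2020", "2021"], "sales": ["100", "150"]})

def predict(input, history=[]):
    # Check if the user input is a question
    is_question = "?" in input

    if is_question:
        # Table-QA branch. TAPEX is an encoder-decoder: generate() returns a
        # standalone answer sequence, so the chat history is left untouched.
        sql_encoding = sql_tokenizer(table=table, query=input, return_tensors="pt")
        sql_outputs = sql_model.generate(**sql_encoding)
        answer = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)[0]
        response = [(input, answer)]
    else:
        # Chat branch, as in the diff: append the new user turn to the token
        # history and let the causal LM continue the conversation.
        new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors="pt")
        bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
        history = model.generate(
            bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id
        ).tolist()
        # Split the decoded token stream into (user, bot) turn pairs.
        turns = tokenizer.decode(history[0]).split("<|endoftext|>")
        response = [(turns[i], turns[i + 1]) for i in range(0, len(turns) - 1, 2)]
    return response, history

Even with that fix, routing concatenated chat history through the table model (as the committed branch does) would not reproduce the DialoGPT-style turn splitting: TAPEX's BART decoder ends answers with </s>, not <|endoftext|>, so the split would yield a single untagged string. Keeping the two branches fully separate avoids that mismatch.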