beingpraveen commited on
Commit
7916c49
·
1 Parent(s): 05672c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -21
app.py CHANGED
@@ -2,28 +2,15 @@ import streamlit as st
2
  import layer
3
  from transformers import AutoModelWithLMHead, AutoTokenizer
4
 
5
- tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
6
- model = AutoModelWithLMHead.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
7
-
8
- def get_sql(query):
9
- input_text = "translate English to SQL: %s </s>" % query
10
- features = tokenizer([input_text], return_tensors='pt')
11
-
12
- output = model.generate(input_ids=features['input_ids'],
13
- attention_mask=features['attention_mask'])
14
-
15
- return tokenizer.decode(output[0])
16
-
17
- # model = layer.get_model('layer/t5-fine-tuning-with-layer/models/t5-english-to-sql').get_train()
18
- # tokenizer = layer.get_model('layer/t5-fine-tuning-with-layer/models/t5-tokenizer').get_train()
19
-
20
- # def convert(query):
21
- # inputs = tokenizer.encode(f"translate English to SQL: {query}", return_tensors="pt")
22
- # outputs = model.generate(inputs, max_length=1024)
23
- # sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
24
- # return sql
25
 
 
 
 
 
 
26
 
27
  query = st.text_input("Enter Text here", value="")
28
- output = get_sql(query)
29
  st.text_area(label="Output Sql Query:", value=output, height=100)
 
2
  import layer
3
  from transformers import AutoModelWithLMHead, AutoTokenizer
4
 
5
+ model = layer.get_model('layer/t5-fine-tuning-with-layer/models/t5-english-to-sql').get_train()
6
+ tokenizer = layer.get_model('layer/t5-fine-tuning-with-layer/models/t5-tokenizer').get_train()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ def convert(query):
9
+ inputs = tokenizer.encode(f"translate English to SQL: {query}", return_tensors="pt")
10
+ outputs = model.generate(inputs, max_length=1024)
11
+ sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
12
+ return sql
13
 
14
  query = st.text_input("Enter Text here", value="")
15
+ output = convert(query)
16
  st.text_area(label="Output Sql Query:", value=output, height=100)