nileshhanotia commited on
Commit
a5e04ba
·
verified ·
1 Parent(s): ca58184

Create sql_generator.py

Browse files
Files changed (1) hide show
  1. sql_generator.py +84 -0
sql_generator.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import requests
4
+ from config import ACCESS_TOKEN, SHOP_NAME
5
+
6
+ class SQLGenerator:
7
+ def __init__(self):
8
+ self.model_name = "premai-io/prem-1B-SQL"
9
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
10
+ self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
11
+
12
+ def generate_query(self, natural_language_query):
13
+ schema_info = """
14
+ CREATE TABLE products (
15
+ id DECIMAL(8,2) PRIMARY KEY,
16
+ title VARCHAR(255),
17
+ body_html VARCHAR(255),
18
+ vendor VARCHAR(255),
19
+ product_type VARCHAR(255),
20
+ created_at VARCHAR(255),
21
+ handle VARCHAR(255),
22
+ updated_at DATE,
23
+ published_at VARCHAR(255),
24
+ template_suffix VARCHAR(255),
25
+ published_scope VARCHAR(255),
26
+ tags VARCHAR(255),
27
+ status VARCHAR(255),
28
+ admin_graphql_api_id DECIMAL(8,2),
29
+ variants VARCHAR(255),
30
+ options VARCHAR(255),
31
+ images VARCHAR(255),
32
+ image VARCHAR(255)
33
+ );
34
+ """
35
+
36
+ prompt = f"""### Task: Generate a SQL query to answer the following question.
37
+ ### Database Schema:
38
+ {schema_info}
39
+ ### Question: {natural_language_query}
40
+ ### SQL Query:"""
41
+
42
+ inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(self.model.device)
43
+ outputs = self.model.generate(
44
+ inputs["input_ids"],
45
+ max_length=256,
46
+ do_sample=True, # Enable sampling to use temperature
47
+ num_return_sequences=1,
48
+ eos_token_id=self.tokenizer.eos_token_id,
49
+ pad_token_id=self.tokenizer.pad_token_id,
50
+ temperature=0.7, # Allow temperature to affect output
51
+ top_k=50 # Consider top k predictions for variability
52
+ )
53
+
54
+ generated_query = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
55
+ return generated_query # Return the generated SQL query
56
+
57
+ def fetch_shopify_data(self, endpoint):
58
+ headers = {
59
+ 'X-Shopify-Access-Token': ACCESS_TOKEN,
60
+ 'Content-Type': 'application/json'
61
+ }
62
+ url = f"https://{SHOP_NAME}/admin/api/2023-10/{endpoint}.json"
63
+ response = requests.get(url, headers=headers)
64
+
65
+ if response.status_code == 200:
66
+ return response.json()
67
+ else:
68
+ print(f"Error fetching {endpoint}: {response.status_code} - {response.text}")
69
+ return None
70
+
71
+ # Example of how to use the SQLGenerator class
72
+ if __name__ == "__main__":
73
+ sql_generator = SQLGenerator()
74
+
75
+ # Example natural language query
76
+ natural_language_query = "Show me shirts with red color"
77
+
78
+ # Generate SQL query
79
+ sql_query = sql_generator.generate_query(natural_language_query)
80
+ print(f"Generated SQL Query: {sql_query}")
81
+
82
+ # Fetch data from Shopify (example endpoint)
83
+ shopify_data = sql_generator.fetch_shopify_data("products")
84
+ print(f"Shopify Data: {shopify_data}")