OsamaMo committed on
Commit
193a126
·
verified ·
1 Parent(s): c19331a

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +134 -8
README.md CHANGED
@@ -61,21 +61,22 @@ Can be fine-tuned further for specific databases or Arabic dialect adaptations.
61
  - Ensure compatibility with specific database schemas.
62
 
63
  ## How to Get Started with the Model
64
-
65
  ```python
66
  from transformers import AutoModelForCausalLM, AutoTokenizer
67
  import torch
 
68
 
69
  device = "cuda" if torch.cuda.is_available() else "cpu"
70
  base_model_id = "Qwen/Qwen2.5-1.5B-Instruct"
71
  finetuned_model_id = "OsamaMo/Arabic_Text-To-SQL_using_Qwen2.5-1.5B"
72
 
 
73
  model = AutoModelForCausalLM.from_pretrained(
74
  base_model_id,
75
  device_map="auto",
76
  torch_dtype=torch.bfloat16
77
  )
78
-
79
  model.load_adapter(finetuned_model_id)
80
 
81
  tokenizer = AutoTokenizer.from_pretrained(base_model_id)
@@ -86,25 +87,150 @@ def generate_resp(messages):
86
  tokenize=False,
87
  add_generation_prompt=True
88
  )
89
-
90
  model_inputs = tokenizer([text], return_tensors="pt").to(device)
91
-
92
  generated_ids = model.generate(
93
  model_inputs.input_ids,
94
  max_new_tokens=1024,
95
- do_sample=False, top_k=None, temperature=None, top_p=None,
96
  )
97
-
98
  generated_ids = [
99
  output_ids[len(input_ids):]
100
  for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
101
  ]
102
-
103
  response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
104
-
105
  return response
106
  ```
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  ## Training Details
109
 
110
  ### Training Data
 
61
  - Ensure compatibility with specific database schemas.
62
 
63
  ## How to Get Started with the Model
64
+ ### Load Model
65
  ```python
66
  from transformers import AutoModelForCausalLM, AutoTokenizer
67
  import torch
68
+ import re
69
 
70
  device = "cuda" if torch.cuda.is_available() else "cpu"
71
  base_model_id = "Qwen/Qwen2.5-1.5B-Instruct"
72
  finetuned_model_id = "OsamaMo/Arabic_Text-To-SQL_using_Qwen2.5-1.5B"
73
 
74
+ # Load the base model, then attach the fine-tuned adapter for inference
75
  model = AutoModelForCausalLM.from_pretrained(
76
  base_model_id,
77
  device_map="auto",
78
  torch_dtype=torch.bfloat16
79
  )
 
80
  model.load_adapter(finetuned_model_id)
81
 
82
  tokenizer = AutoTokenizer.from_pretrained(base_model_id)
 
87
  tokenize=False,
88
  add_generation_prompt=True
89
  )
 
90
  model_inputs = tokenizer([text], return_tensors="pt").to(device)
 
91
  generated_ids = model.generate(
92
  model_inputs.input_ids,
93
  max_new_tokens=1024,
94
+ do_sample=False, temperature=None,
95
  )
 
96
  generated_ids = [
97
  output_ids[len(input_ids):]
98
  for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
99
  ]
 
100
  response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
101
  return response
102
  ```
103
 
104
+
105
+
106
+ ### Example Usage
107
+ ```python
108
+
109
+ # Production-ready system message for SQL generation
110
+ system_message = (
111
+ "You are a highly advanced Arabic text-to-SQL converter. Your mission is to Understand first the db schema and reltions between it and then accurately transform Arabic "
112
+ "natural language queries into SQL queries with precision and clarity.\n"
113
+ )
114
+
115
def get_sql_query(db_schema, arabic_query):
    """Generate a SQL query for an Arabic question about a database schema.

    Builds the chat prompt (schema, question, and an opened SQL code fence),
    sends it through ``generate_resp`` and returns only the SQL text.

    Args:
        db_schema: Schema text placed under the ``## DB-Schema:`` heading.
        arabic_query: The user's question, placed under ``## User-Prompt:``.

    Returns:
        The SQL string with surrounding whitespace and Markdown code fences
        removed. A response that contains no fences is returned stripped.
    """
    # Construct the instruction message including the DB schema and the Arabic query.
    instruction_message = "\n".join([
        "## DB-Schema:",
        db_schema,
        "",
        "## User-Prompt:",
        arabic_query,
        "# Output SQL:",
        "```SQL",
    ])

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": instruction_message},
    ]

    response = generate_resp(messages).strip()

    # Preferred case: the model emitted a complete fenced sql block.
    match = re.search(r"```sql\s*(.*?)\s*```", response, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()

    # The prompt already ends with an opened SQL fence, so the model may just
    # continue it and emit only the closing fence (or leave an opening fence
    # unterminated). Drop such dangling fences instead of returning them as
    # part of the query.
    response = re.sub(r"^```(?:sql)?\s*", "", response, flags=re.IGNORECASE)
    response = re.sub(r"\s*```$", "", response)
    return response.strip()
141
+
142
+ # Example usage:
143
+ example_db_schema = r'''{
144
+ 'Pharmcy':
145
+ CREATE TABLE `purchase` (
146
+ `BARCODE` varchar(20) NOT NULL,
147
+ `NAME` varchar(50) NOT NULL,
148
+ `TYPE` varchar(20) NOT NULL,
149
+ `COMPANY_NAME` varchar(20) NOT NULL,
150
+ `QUANTITY` int NOT NULL,
151
+ `PRICE` double NOT NULL,
152
+ `AMOUNT` double NOT NULL,
153
+ PRIMARY KEY (`BARCODE`),
154
+ KEY `fkr3` (`COMPANY_NAME`),
155
+ CONSTRAINT `fkr3` FOREIGN KEY (`COMPANY_NAME`) REFERENCES `company` (`NAME`) ON DELETE CASCADE ON UPDATE CASCADE
156
+ ) ENGINE=InnoDB DEFAULT CHARSET=latin1
157
+
158
+ CREATE TABLE `sales` (
159
+ `BARCODE` varchar(20) NOT NULL,
160
+ `NAME` varchar(50) NOT NULL,
161
+ `TYPE` varchar(10) NOT NULL,
162
+ `DOSE` varchar(10) NOT NULL,
163
+ `QUANTITY` int NOT NULL,
164
+ `PRICE` double NOT NULL,
165
+ `AMOUNT` double NOT NULL,
166
+ `DATE` varchar(15) NOT NULL
167
+ ) ENGINE=InnoDB DEFAULT CHARSET=latin1
168
+
169
+ CREATE TABLE `users` (
170
+ `ID` int NOT NULL,
171
+ `NAME` varchar(50) NOT NULL,
172
+ `DOB` varchar(20) NOT NULL,
173
+ `ADDRESS` varchar(100) NOT NULL,
174
+ `PHONE` varchar(20) NOT NULL,
175
+ `SALARY` double NOT NULL,
176
+ `PASSWORD` varchar(20) NOT NULL,
177
+ PRIMARY KEY (`ID`)
178
+ ) ENGINE=InnoDB DEFAULT CHARSET=latin1
179
+
180
+ CREATE TABLE `history_sales` (
181
+ `USER_NAME` varchar(20) NOT NULL,
182
+ `BARCODE` varchar(20) NOT NULL,
183
+ `NAME` varchar(50) NOT NULL,
184
+ `TYPE` varchar(10) NOT NULL,
185
+ `DOSE` varchar(10) NOT NULL,
186
+ `QUANTITY` int NOT NULL,
187
+ `PRICE` double NOT NULL,
188
+ `AMOUNT` double NOT NULL,
189
+ `DATE` varchar(15) NOT NULL,
190
+ `TIME` varchar(20) NOT NULL
191
+ ) ENGINE=InnoDB DEFAULT CHARSET=latin1
192
+
193
+ CREATE TABLE `expiry` (
194
+ `PRODUCT_NAME` varchar(50) NOT NULL,
195
+ `PRODUCT_CODE` varchar(20) NOT NULL,
196
+ `DATE_OF_EXPIRY` varchar(10) NOT NULL,
197
+ `QUANTITY_REMAIN` int NOT NULL
198
+ ) ENGINE=InnoDB DEFAULT CHARSET=latin1
199
+
200
+ CREATE TABLE `drugs` (
201
+ `NAME` varchar(50) NOT NULL,
202
+ `TYPE` varchar(20) NOT NULL,
203
+ `BARCODE` varchar(20) NOT NULL,
204
+ `DOSE` varchar(10) NOT NULL,
205
+ `CODE` varchar(10) NOT NULL,
206
+ `COST_PRICE` double NOT NULL,
207
+ `SELLING_PRICE` double NOT NULL,
208
+ `EXPIRY` varchar(20) NOT NULL,
209
+ `COMPANY_NAME` varchar(50) NOT NULL,
210
+ `PRODUCTION_DATE` date NOT NULL,
211
+ `EXPIRATION_DATE` date NOT NULL,
212
+ `PLACE` varchar(20) NOT NULL,
213
+ `QUANTITY` int NOT NULL,
214
+ PRIMARY KEY (`BARCODE`)
215
+ ) ENGINE=InnoDB DEFAULT CHARSET=latin1
216
+
217
+ CREATE TABLE `company` (
218
+ `NAME` varchar(50) NOT NULL,
219
+ `ADDRESS` varchar(50) NOT NULL,
220
+ `PHONE` varchar(20) NOT NULL,
221
+ PRIMARY KEY (`NAME`)
222
+ ) ENGINE=InnoDB DEFAULT CHARSET=latin1
223
+
224
+ Answer the following questions about this schema:
225
+ }'''
226
+
227
+ example_arabic_query = "اريد الباركود الخاص بدواء يبداء اسمه بحرف 's'"
228
+
229
+ sql_result = get_sql_query(example_db_schema, example_arabic_query)
230
+ print("استعلام SQL الناتج:")
231
+ print(sql_result)
232
+ ```
233
+
234
  ## Training Details
235
 
236
  ### Training Data