burhan112 committed
Commit 0a17c6b · verified · 1 Parent(s): 62882a0

Update app.py

Files changed (1)
  1. app.py +15 -33
app.py CHANGED
@@ -136,45 +136,38 @@ class Transformer(nn.Module):
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # Load tokenizers
-sp_pseudo = spm.SentencePieceProcessor(model_file="pseudocode_tokenizer.model")  # For decoding pseudocode (target)
-sp_code = spm.SentencePieceProcessor(model_file="code_tokenizer.model")  # For encoding C++ (source)
+sp_pseudo = spm.SentencePieceProcessor(model_file="pseudo.model")  # For decoding pseudocode (target)
+sp_code = spm.SentencePieceProcessor(model_file="code.model")  # For encoding C++ (source)
 
 # Load the full saved model (architecture + weights)
-model_path = "code2pseudo.pth"
+model_path = "transformer_cpp_to_pseudo.pth"
 model = torch.load(model_path, map_location=device, weights_only=False)
 model.eval()
 model = model.to(device)
 
-def generate_pseudocode(cpp_code, max_len=500):
+def generate_pseudocode(cpp_code, max_len):
     """Generate pseudocode from C++ code with streaming output."""
     model.eval()
     src = torch.tensor([sp_code.encode_as_ids(cpp_code)], dtype=torch.long, device=device)  # Tokenize C++ code
     tgt = torch.tensor([[2]], dtype=torch.long, device=device)  # <bos_id>=2
 
     generated_tokens = [2]  # Start with <START>
-    eos_id = sp_pseudo.eos_id()  # Dynamically get <EOS> ID
-    print(f"Input C++ tokens: {sp_code.encode_as_ids(cpp_code)}")  # Debug input
-    print(f"Using EOS ID: {eos_id}")  # Debug EOS ID
-
+    response = ""
     with torch.no_grad():
-        for i in range(max_len):
+        for _ in range(max_len):
             output = model(src, tgt)
             next_token = output[:, -1, :].argmax(-1).item()
             generated_tokens.append(next_token)
             tgt = torch.cat([tgt, torch.tensor([[next_token]], device=device)], dim=1)
             response = sp_pseudo.decode_ids(generated_tokens)  # Decode to pseudocode
-            print(f"Step {i}: Next token = {next_token}, Partial output = {response}")  # Debug step
             yield response  # Yield partial output
-            if next_token == eos_id:  # Stop at <EOS>
-                print("EOS detected, stopping generation.")
+            if next_token == 3:  # <END>=3 (adjust if your EOS ID differs)
                 break
-    print("Generation complete or max length reached.")
-
     yield response  # Final output
 
-def generate_output(cpp_code):
-    """Wrapper for Gradio interface with streaming."""
-    for response in generate_pseudocode(cpp_code, max_len=500):
+def respond(message, history, max_tokens):
+    """Wrapper for Gradio interface."""
+    for response in generate_pseudocode(message, max_tokens):
         yield response
 
 # Gradio UI setup with Blocks
@@ -186,28 +179,17 @@ with gr.Blocks(title="C++ to Pseudocode Transformer") as demo:
         placeholder="e.g., 'int x = 5; for(int i=0; i<x; i++) cout << i;'",
         lines=5
     )
-    submit_btn = gr.Button("Submit", variant="primary", elem_classes="btn-blue")
+    submit_btn = gr.Button("Submit", variant="primary")
     pseudocode_output = gr.Textbox(
         label="Generated Pseudocode",
         lines=5
    )
 
     submit_btn.click(
-        fn=generate_output,
-        inputs=[cpp_input],
+        fn=respond,
+        inputs=[cpp_input, gr.State(value=[]), gr.Slider(minimum=10, maximum=1000, value=50, step=1, visible=False)],
         outputs=pseudocode_output
     )
 
-demo.launch()
-
-# Custom CSS
-demo.css = """
-.btn-blue {
-    background-color: #007bff;
-    color: white;
-    border: none;
-}
-.btn-blue:hover {
-    background-color: #0056b3;
-}
-"""
+if __name__ == "__main__":
+    demo.launch()
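
The generator yields the full decoded string on every step (not a delta), which is what lets the Gradio Textbox show live progress; the hidden gr.State and gr.Slider inputs exist only to satisfy the respond(message, history, max_tokens) signature. Below is a minimal sketch, not part of the commit, for exercising the streaming generator outside the UI. It assumes app.py plus the model and tokenizer files from this commit are in the working directory; the new if __name__ == "__main__": guard is what keeps demo.launch() from firing on import.

# Illustrative smoke test (hypothetical, not in the commit): drives the
# greedy decoder directly and prints the final hypothesis.
from app import generate_pseudocode  # importing app.py loads the model and tokenizers

# Sample input taken from the UI placeholder text.
cpp_snippet = "int x = 5; for(int i=0; i<x; i++) cout << i;"

final = ""
for partial in generate_pseudocode(cpp_snippet, max_len=50):
    final = partial  # each yield is the full decode so far, not an increment
print(final)

Note that decoding is greedy argmax with hard-coded special-token IDs (BOS=2, EOS=3, per the code comments); those must match the IDs the SentencePiece models were trained with, or generation will always run to max_len.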