burhan112 commited on
Commit
12724c3
·
verified ·
1 Parent(s): a2b554f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -21
app.py CHANGED
@@ -145,42 +145,69 @@ model = torch.load(model_path, map_location=device, weights_only=False)
145
  model.eval()
146
  model = model.to(device)
147
 
148
- def generate_pseudocode(cpp_code, max_len):
149
  """Generate pseudocode from C++ code with streaming output."""
150
  model.eval()
151
  src = torch.tensor([sp_code.encode_as_ids(cpp_code)], dtype=torch.long, device=device) # Tokenize C++ code
152
  tgt = torch.tensor([[2]], dtype=torch.long, device=device) # <bos_id>=2
153
 
154
  generated_tokens = [2] # Start with <START>
155
- response = ""
 
 
 
156
  with torch.no_grad():
157
- for _ in range(max_len):
158
  output = model(src, tgt)
159
  next_token = output[:, -1, :].argmax(-1).item()
160
  generated_tokens.append(next_token)
161
  tgt = torch.cat([tgt, torch.tensor([[next_token]], device=device)], dim=1)
162
  response = sp_pseudo.decode_ids(generated_tokens) # Decode to pseudocode
 
163
  yield response # Yield partial output
164
- if next_token == 3: # <END>=3 (adjust if your EOS ID differs)
 
165
  break
 
 
166
  yield response # Final output
167
 
168
- def respond(message, history, max_tokens):
169
- """Wrapper for Gradio interface."""
170
- for response in generate_pseudocode(message, max_tokens):
171
  yield response
172
 
173
- # Gradio interface
174
- demo = gr.ChatInterface(
175
- respond,
176
- chatbot=gr.Chatbot(label="C++ to Pseudocode Generator"),
177
- textbox=gr.Textbox(placeholder="Enter C++ code (e.g., 'int x = 5; for(int i=0; i<x; i++) cout << i;')", label="C++ Code"),
178
- additional_inputs=[
179
- gr.Slider(minimum=10, maximum=1000, value=50, step=1, label="Max tokens"),
180
- ],
181
- title="C++ to Pseudocode Transformer",
182
- description="Convert C++ code to pseudocode using a custom transformer trained on the SPoC dataset.",
183
- )
184
-
185
- if __name__ == "__main__":
186
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  model.eval()
146
  model = model.to(device)
147
 
148
+ def generate_pseudocode(cpp_code, max_len=500):
149
  """Generate pseudocode from C++ code with streaming output."""
150
  model.eval()
151
  src = torch.tensor([sp_code.encode_as_ids(cpp_code)], dtype=torch.long, device=device) # Tokenize C++ code
152
  tgt = torch.tensor([[2]], dtype=torch.long, device=device) # <bos_id>=2
153
 
154
  generated_tokens = [2] # Start with <START>
155
+ eos_id = sp_pseudo.eos_id() # Dynamically get <EOS> ID
156
+ print(f"Input C++ tokens: {sp_code.encode_as_ids(cpp_code)}") # Debug input
157
+ print(f"Using EOS ID: {eos_id}") # Debug EOS ID
158
+
159
  with torch.no_grad():
160
+ for i in range(max_len):
161
  output = model(src, tgt)
162
  next_token = output[:, -1, :].argmax(-1).item()
163
  generated_tokens.append(next_token)
164
  tgt = torch.cat([tgt, torch.tensor([[next_token]], device=device)], dim=1)
165
  response = sp_pseudo.decode_ids(generated_tokens) # Decode to pseudocode
166
+ print(f"Step {i}: Next token = {next_token}, Partial output = {response}") # Debug step
167
  yield response # Yield partial output
168
+ if next_token == eos_id: # Stop at <EOS>
169
+ print("EOS detected, stopping generation.")
170
  break
171
+ print("Generation complete or max length reached.")
172
+
173
  yield response # Final output
174
 
175
+ def generate_output(cpp_code):
176
+ """Wrapper for Gradio interface with streaming."""
177
+ for response in generate_pseudocode(cpp_code, max_len=500):
178
  yield response
179
 
180
+ # Gradio UI setup with Blocks
181
+ with gr.Blocks(title="C++ to Pseudocode Transformer") as demo:
182
+ gr.Markdown("## C++ to Pseudocode Converter")
183
+ gr.Markdown("Enter C++ code below and press Submit to generate pseudocode.")
184
+ cpp_input = gr.Textbox(
185
+ label="C++ Code",
186
+ placeholder="e.g., 'int x = 5; for(int i=0; i<x; i++) cout << i;'",
187
+ lines=5
188
+ )
189
+ submit_btn = gr.Button("Submit", variant="primary", elem_classes="btn-blue")
190
+ pseudocode_output = gr.Textbox(
191
+ label="Generated Pseudocode",
192
+ lines=5
193
+ )
194
+
195
+ submit_btn.click(
196
+ fn=generate_output,
197
+ inputs=[cpp_input],
198
+ outputs=pseudocode_output
199
+ )
200
+
201
+ demo.launch()
202
+
203
+ # Custom CSS
204
+ demo.css = """
205
+ .btn-blue {
206
+ background-color: #007bff;
207
+ color: white;
208
+ border: none;
209
+ }
210
+ .btn-blue:hover {
211
+ background-color: #0056b3;
212
+ }
213
+ """