burhan112 commited on
Commit
b9eae41
·
verified ·
1 Parent(s): 6f656b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -10
app.py CHANGED
@@ -140,34 +140,42 @@ sp_code = spm.SentencePieceProcessor(model_file="code_tokenizer.model") # C
140
  sp_pseudo = spm.SentencePieceProcessor(model_file="pseudocode_tokenizer.model") # Pseudocode tokenizer for output
141
 
142
  # Load the model trained for C++ to pseudocode
143
- model_path = "c2p.pth" # Assuming retrained model for C++ to pseudocode
144
  model = torch.load(model_path, map_location=device, weights_only=False)
145
  model.eval()
146
  model = model.to(device)
147
 
148
- # Function to generate pseudocode from C++ code
149
  def generate_pseudocode(cpp_code, max_len=500):
150
  model.eval()
151
  src = torch.tensor([sp_code.encode_as_ids(cpp_code)], dtype=torch.long, device=device) # Tokenize C++ input
152
  tgt = torch.tensor([[2]], dtype=torch.long, device=device) # <BOS> token (ID=2)
153
 
154
  generated_tokens = [2] # Start with <BOS>
 
 
 
 
155
  with torch.no_grad():
156
- for _ in range(max_len):
157
  output = model(src, tgt)
158
  next_token = output[:, -1, :].argmax(-1).item()
159
  generated_tokens.append(next_token)
160
  tgt = torch.cat([tgt, torch.tensor([[next_token]], device=device)], dim=1)
161
- if next_token == 3: # <EOS> token (ID=3)
 
 
 
 
162
  break
 
163
 
164
- response = sp_pseudo.decode_ids(generated_tokens) # Decode using pseudocode tokenizer
165
- return response
166
 
167
- # Gradio interface function
168
  def generate_output(cpp_code):
169
- pseudocode = generate_pseudocode(cpp_code)
170
- return pseudocode
171
 
172
  # Gradio UI setup
173
  with gr.Blocks(title="C++ to Pseudocode Transformer") as demo:
@@ -187,7 +195,8 @@ with gr.Blocks(title="C++ to Pseudocode Transformer") as demo:
187
  generate_btn.click(
188
  fn=generate_output,
189
  inputs=[cpp_input],
190
- outputs=pseudocode_output
 
191
  )
192
 
193
  demo.launch()
 
140
  sp_pseudo = spm.SentencePieceProcessor(model_file="pseudocode_tokenizer.model") # Pseudocode tokenizer for output
141
 
142
  # Load the model trained for C++ to pseudocode
143
+ model_path = "c2p.pth" # Ensure this is the correct model for C++ to pseudocode
144
  model = torch.load(model_path, map_location=device, weights_only=False)
145
  model.eval()
146
  model = model.to(device)
147
 
148
+ # Function to generate pseudocode from C++ code with streaming
149
  def generate_pseudocode(cpp_code, max_len=500):
150
  model.eval()
151
  src = torch.tensor([sp_code.encode_as_ids(cpp_code)], dtype=torch.long, device=device) # Tokenize C++ input
152
  tgt = torch.tensor([[2]], dtype=torch.long, device=device) # <BOS> token (ID=2)
153
 
154
  generated_tokens = [2] # Start with <BOS>
155
+ eos_id = sp_pseudo.eos_id() # Dynamically get <EOS> ID from tokenizer
156
+ print(f"Input C++ tokens: {sp_code.encode_as_ids(cpp_code)}") # Debug input
157
+ print(f"Using EOS ID: {eos_id}") # Debug EOS ID
158
+
159
  with torch.no_grad():
160
+ for i in range(max_len):
161
  output = model(src, tgt)
162
  next_token = output[:, -1, :].argmax(-1).item()
163
  generated_tokens.append(next_token)
164
  tgt = torch.cat([tgt, torch.tensor([[next_token]], device=device)], dim=1)
165
+ response = sp_pseudo.decode_ids(generated_tokens) # Decode to pseudocode
166
+ print(f"Step {i}: Next token = {next_token}, Partial output = {response}") # Debug step
167
+ yield response # Yield partial output for streaming
168
+ if next_token == eos_id: # Stop at <EOS>
169
+ print("EOS detected, stopping generation.")
170
  break
171
+ print("Generation complete or max length reached.")
172
 
173
+ yield response # Final output
 
174
 
175
+ # Gradio interface function with streaming
176
  def generate_output(cpp_code):
177
+ for response in generate_pseudocode(cpp_code, max_len=500):
178
+ yield response
179
 
180
  # Gradio UI setup
181
  with gr.Blocks(title="C++ to Pseudocode Transformer") as demo:
 
195
  generate_btn.click(
196
  fn=generate_output,
197
  inputs=[cpp_input],
198
+ outputs=pseudocode_output,
199
+ _js="() => [document.querySelector('#cpp_input textarea').value]" # Ensure input is passed correctly
200
  )
201
 
202
  demo.launch()