burhan112 commited on
Commit
3a24c67
·
verified ·
1 Parent(s): e393094

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -59
app.py CHANGED
@@ -4,7 +4,7 @@ import torch.nn as nn
4
  import sentencepiece as spm
5
  import math
6
 
7
- # Transformer class definitions (unchanged)
8
  class MultiHeadAttention(nn.Module):
9
  def __init__(self, d_model, num_heads):
10
  super(MultiHeadAttention, self).__init__()
@@ -132,82 +132,55 @@ class Transformer(nn.Module):
132
  output = self.fc(dec_output)
133
  return output
134
 
135
- # Device setup
136
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
137
 
138
  # Load tokenizers
139
- sp_code = spm.SentencePieceProcessor(model_file="code_tokenizer.model") # C++ tokenizer for input
140
- sp_pseudo = spm.SentencePieceProcessor(model_file="pseudocode_tokenizer.model") # Pseudocode tokenizer for output
141
 
142
- # Load the model trained for C++ to pseudocode
143
- model_path = "c2p.pth" # Ensure this is the correct model for C++ to pseudocode
144
  model = torch.load(model_path, map_location=device, weights_only=False)
145
  model.eval()
146
  model = model.to(device)
147
 
148
- # Function to generate pseudocode from C++ code with streaming
149
- def generate_pseudocode(cpp_code, max_len=500):
150
  model.eval()
151
- src = torch.tensor([sp_code.encode_as_ids(cpp_code)], dtype=torch.long, device=device) # Tokenize C++ input
152
- tgt = torch.tensor([[2]], dtype=torch.long, device=device) # <BOS> token (ID=2)
153
-
154
- generated_tokens = [2] # Start with <BOS>
155
- eos_id = sp_pseudo.eos_id() # Dynamically get <EOS> ID from tokenizer
156
- print(f"Input C++ tokens: {sp_code.encode_as_ids(cpp_code)}") # Debug input
157
- print(f"Using EOS ID: {eos_id}") # Debug EOS ID
158
 
 
 
159
  with torch.no_grad():
160
- for i in range(max_len):
161
  output = model(src, tgt)
162
  next_token = output[:, -1, :].argmax(-1).item()
163
  generated_tokens.append(next_token)
164
  tgt = torch.cat([tgt, torch.tensor([[next_token]], device=device)], dim=1)
165
  response = sp_pseudo.decode_ids(generated_tokens) # Decode to pseudocode
166
- print(f"Step {i}: Next token = {next_token}, Partial output = {response}") # Debug step
167
- yield response # Yield partial output for streaming
168
- if next_token == eos_id: # Stop at <EOS>
169
- print("EOS detected, stopping generation.")
170
  break
171
- print("Generation complete or max length reached.")
172
-
173
  yield response # Final output
174
 
175
- # Gradio interface function with streaming
176
- def generate_output(cpp_code):
177
- for response in generate_pseudocode(cpp_code, max_len=500):
178
  yield response
179
 
180
- # Gradio UI setup
181
- with gr.Blocks(title="C++ to Pseudocode Transformer") as demo:
182
- gr.Markdown("## C++ to Pseudocode Converter")
183
- gr.Markdown("Enter C++ code below to generate pseudocode.")
184
- cpp_input = gr.Textbox(
185
- label="C++ Code",
186
- placeholder="e.g., 'int main() { int n; cin >> n; }'",
187
- lines=5
188
- )
189
- generate_btn = gr.Button("Generate", variant="primary", elem_classes="btn-blue")
190
- pseudocode_output = gr.Textbox(
191
- label="Generated Pseudocode",
192
- lines=5
193
- )
194
-
195
- generate_btn.click(
196
- fn=generate_output,
197
- inputs=[cpp_input],
198
- outputs=pseudocode_output
199
- )
200
-
201
- demo.launch()
202
-
203
- # Custom CSS
204
- demo.css = """
205
- .btn-blue {
206
- background-color: #007bff;
207
- color: white;
208
- border: none;
209
- }
210
- .btn-blue:hover {
211
- background-color: #0056b3;
212
- }
213
- """
 
4
  import sentencepiece as spm
5
  import math
6
 
7
+ # Define Transformer components (unchanged)
8
  class MultiHeadAttention(nn.Module):
9
  def __init__(self, d_model, num_heads):
10
  super(MultiHeadAttention, self).__init__()
 
132
  output = self.fc(dec_output)
133
  return output
134
 
135
+ # Set device
136
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
137
 
138
  # Load tokenizers
139
+ sp_pseudo = spm.SentencePieceProcessor(model_file="pseudocode_tokenizer.model") # For decoding pseudocode (target)
140
+ sp_code = spm.SentencePieceProcessor(model_file="code_tokenizer.model") # For encoding C++ (source)
141
 
142
+ # Load the full saved model (architecture + weights)
143
+ model_path = "c2p.pth"
144
  model = torch.load(model_path, map_location=device, weights_only=False)
145
  model.eval()
146
  model = model.to(device)
147
 
148
+ def generate_pseudocode(cpp_code, max_len):
149
+ """Generate pseudocode from C++ code with streaming output."""
150
  model.eval()
151
+ src = torch.tensor([sp_code.encode_as_ids(cpp_code)], dtype=torch.long, device=device) # Tokenize C++ code
152
+ tgt = torch.tensor([[2]], dtype=torch.long, device=device) # <bos_id>=2
 
 
 
 
 
153
 
154
+ generated_tokens = [2] # Start with <START>
155
+ response = ""
156
  with torch.no_grad():
157
+ for _ in range(max_len):
158
  output = model(src, tgt)
159
  next_token = output[:, -1, :].argmax(-1).item()
160
  generated_tokens.append(next_token)
161
  tgt = torch.cat([tgt, torch.tensor([[next_token]], device=device)], dim=1)
162
  response = sp_pseudo.decode_ids(generated_tokens) # Decode to pseudocode
163
+ yield response # Yield partial output
164
+ if next_token == 3: # <END>=3 (adjust if your EOS ID differs)
 
 
165
  break
 
 
166
  yield response # Final output
167
 
168
+ def respond(message, history, max_tokens):
169
+ """Wrapper for Gradio interface."""
170
+ for response in generate_pseudocode(message, max_tokens):
171
  yield response
172
 
173
+ # Gradio interface
174
+ demo = gr.ChatInterface(
175
+ respond,
176
+ chatbot=gr.Chatbot(label="C++ to Pseudocode Generator"),
177
+ textbox=gr.Textbox(placeholder="Enter C++ code (e.g., 'int x = 5; for(int i=0; i<x; i++) cout << i;')", label="C++ Code"),
178
+ additional_inputs=[
179
+ gr.Slider(minimum=10, maximum=1000, value=50, step=1, label="Max tokens"),
180
+ ],
181
+ title="C++ to Pseudocode Transformer",
182
+ description="Convert C++ code to pseudocode using a custom transformer trained on the SPoC dataset.",
183
+ )
184
+
185
+ if __name__ == "__main__":
186
+ demo.launch()