Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -140,34 +140,42 @@ sp_code = spm.SentencePieceProcessor(model_file="code_tokenizer.model") # C
|
|
140 |
sp_pseudo = spm.SentencePieceProcessor(model_file="pseudocode_tokenizer.model") # Pseudocode tokenizer for output
|
141 |
|
142 |
# Load the model trained for C++ to pseudocode
|
143 |
-
model_path = "c2p.pth" #
|
144 |
model = torch.load(model_path, map_location=device, weights_only=False)
|
145 |
model.eval()
|
146 |
model = model.to(device)
|
147 |
|
148 |
-
# Function to generate pseudocode from C++ code
|
149 |
def generate_pseudocode(cpp_code, max_len=500):
|
150 |
model.eval()
|
151 |
src = torch.tensor([sp_code.encode_as_ids(cpp_code)], dtype=torch.long, device=device) # Tokenize C++ input
|
152 |
tgt = torch.tensor([[2]], dtype=torch.long, device=device) # <BOS> token (ID=2)
|
153 |
|
154 |
generated_tokens = [2] # Start with <BOS>
|
|
|
|
|
|
|
|
|
155 |
with torch.no_grad():
|
156 |
-
for
|
157 |
output = model(src, tgt)
|
158 |
next_token = output[:, -1, :].argmax(-1).item()
|
159 |
generated_tokens.append(next_token)
|
160 |
tgt = torch.cat([tgt, torch.tensor([[next_token]], device=device)], dim=1)
|
161 |
-
|
|
|
|
|
|
|
|
|
162 |
break
|
|
|
163 |
|
164 |
-
response
|
165 |
-
return response
|
166 |
|
167 |
-
# Gradio interface function
|
168 |
def generate_output(cpp_code):
|
169 |
-
|
170 |
-
|
171 |
|
172 |
# Gradio UI setup
|
173 |
with gr.Blocks(title="C++ to Pseudocode Transformer") as demo:
|
@@ -187,7 +195,8 @@ with gr.Blocks(title="C++ to Pseudocode Transformer") as demo:
|
|
187 |
generate_btn.click(
|
188 |
fn=generate_output,
|
189 |
inputs=[cpp_input],
|
190 |
-
outputs=pseudocode_output
|
|
|
191 |
)
|
192 |
|
193 |
demo.launch()
|
|
|
140 |
sp_pseudo = spm.SentencePieceProcessor(model_file="pseudocode_tokenizer.model") # Pseudocode tokenizer for output
|
141 |
|
142 |
# Load the model trained for C++ to pseudocode
|
143 |
+
model_path = "c2p.pth" # Ensure this is the correct model for C++ to pseudocode
|
144 |
model = torch.load(model_path, map_location=device, weights_only=False)
|
145 |
model.eval()
|
146 |
model = model.to(device)
|
147 |
|
148 |
+
# Function to generate pseudocode from C++ code with streaming
|
149 |
def generate_pseudocode(cpp_code, max_len=500):
|
150 |
model.eval()
|
151 |
src = torch.tensor([sp_code.encode_as_ids(cpp_code)], dtype=torch.long, device=device) # Tokenize C++ input
|
152 |
tgt = torch.tensor([[2]], dtype=torch.long, device=device) # <BOS> token (ID=2)
|
153 |
|
154 |
generated_tokens = [2] # Start with <BOS>
|
155 |
+
eos_id = sp_pseudo.eos_id() # Dynamically get <EOS> ID from tokenizer
|
156 |
+
print(f"Input C++ tokens: {sp_code.encode_as_ids(cpp_code)}") # Debug input
|
157 |
+
print(f"Using EOS ID: {eos_id}") # Debug EOS ID
|
158 |
+
|
159 |
with torch.no_grad():
|
160 |
+
for i in range(max_len):
|
161 |
output = model(src, tgt)
|
162 |
next_token = output[:, -1, :].argmax(-1).item()
|
163 |
generated_tokens.append(next_token)
|
164 |
tgt = torch.cat([tgt, torch.tensor([[next_token]], device=device)], dim=1)
|
165 |
+
response = sp_pseudo.decode_ids(generated_tokens) # Decode to pseudocode
|
166 |
+
print(f"Step {i}: Next token = {next_token}, Partial output = {response}") # Debug step
|
167 |
+
yield response # Yield partial output for streaming
|
168 |
+
if next_token == eos_id: # Stop at <EOS>
|
169 |
+
print("EOS detected, stopping generation.")
|
170 |
break
|
171 |
+
print("Generation complete or max length reached.")
|
172 |
|
173 |
+
yield response # Final output
|
|
|
174 |
|
175 |
+
# Gradio interface function with streaming
|
176 |
def generate_output(cpp_code):
|
177 |
+
for response in generate_pseudocode(cpp_code, max_len=500):
|
178 |
+
yield response
|
179 |
|
180 |
# Gradio UI setup
|
181 |
with gr.Blocks(title="C++ to Pseudocode Transformer") as demo:
|
|
|
195 |
generate_btn.click(
|
196 |
fn=generate_output,
|
197 |
inputs=[cpp_input],
|
198 |
+
outputs=pseudocode_output,
|
199 |
+
_js="() => [document.querySelector('#cpp_input textarea').value]" # Ensure input is passed correctly
|
200 |
)
|
201 |
|
202 |
demo.launch()
|