AbdulHadi806 committed · verified
Commit be347f9 · 1 Parent(s): b1c1a11

Update app.py

Files changed (1):
  1. app.py (+39 -5)
app.py CHANGED
@@ -5,13 +5,47 @@ from transformers import T5ForConditionalGeneration, RobertaTokenizer
 quantized_model = T5ForConditionalGeneration.from_pretrained("AbdulHadi806/codet5-finetuned-latest-quantized")
 tokenizer = RobertaTokenizer.from_pretrained("AbdulHadi806/codet5-finetuned-latest-quantized")
 
-def inference(input_text):
-    inputs = tokenizer(input_text, return_tensors="pt")
-    outputs = quantized_model.generate(**inputs)
-    return tokenizer.decode(outputs[0], skip_special_tokens=True)
+def generate_code(input_text):
+    print(input_text)
+    input_ids = tokenizer(input_text, return_tensors='pt', padding="max_length", truncation=True, max_length=128).input_ids.to(quantized_model.device)
+    outputs = quantized_model.generate(input_ids, max_length=128, num_beams=4, early_stopping=True)
+    predicted_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    cleaned_code = clean_generated_code(postprocess_output(predicted_text))
+
+    return cleaned_code
+
+def preprocess_infer_input(text):
+    # Assuming the input is already a string, we don't need to access it as a dictionary
+    return f"latex: {text}"
+
+def clean_generated_code(generated_code):
+    # Remove unwanted parts
+    print(':::generated_code::::', generated_code)
+    cleaned_code = generated_code.replace('*convert(latex, python.code)', '').strip()
+
+    # Optionally, format the code for better readability
+    cleaned_code = cleaned_code.replace('\n', '\n').replace(' ', ' ')  # Adjust spacing if needed
+
+    return cleaned_code
+
+def generate_solution(input_text):
+    input_text = preprocess_infer_input(input_text)
+    print(input_text)
+
+    input_ids = tokenizer(input_text, return_tensors='pt', padding="max_length", truncation=True, max_length=128).input_ids
+    input_ids = input_ids.to(quantized_model.device)
+
+    with torch.no_grad():
+        outputs = quantized_model.generate(input_ids, max_length=128, num_beams=4, early_stopping=True)
+
+    predicted_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    cleaned_code = clean_generated_code(postprocess_output(predicted_text))
+    return cleaned_code
+
+
 
 # Create Gradio interface
-iface = gr.Interface(fn=inference, inputs="text", outputs="text")
+iface = gr.Interface(fn=generate_solution, inputs="text", outputs="text")
 
 # Launch the interface
 iface.launch()
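
One caveat for anyone reproducing this Space: `postprocess_output` is called in both `generate_code` and `generate_solution` but is never defined in app.py, so inference will fail with a NameError until it is added. A minimal runnable sketch of the updated app follows; the pass-through `postprocess_output` is a hypothetical stand-in (not part of the commit), and the `gradio`/`torch` imports are assumed to live in the unchanged top of the file:

# Minimal runnable sketch of the updated app.py, under the assumptions above.
import gradio as gr
import torch
from transformers import T5ForConditionalGeneration, RobertaTokenizer

quantized_model = T5ForConditionalGeneration.from_pretrained(
    "AbdulHadi806/codet5-finetuned-latest-quantized"
)
tokenizer = RobertaTokenizer.from_pretrained(
    "AbdulHadi806/codet5-finetuned-latest-quantized"
)

def postprocess_output(text):
    # Hypothetical stand-in: the commit calls this helper but never defines it.
    return text

def preprocess_infer_input(text):
    # Prepend the "latex:" task prefix, as the committed generate_solution does.
    return f"latex: {text}"

def clean_generated_code(generated_code):
    # Strip the task marker that the model tends to echo back.
    return generated_code.replace('*convert(latex, python.code)', '').strip()

def generate_solution(input_text):
    input_text = preprocess_infer_input(input_text)

    # Fixed-length encoding: pad/truncate to 128 tokens, move to the model's device.
    input_ids = tokenizer(
        input_text,
        return_tensors='pt',
        padding="max_length",
        truncation=True,
        max_length=128,
    ).input_ids.to(quantized_model.device)

    # Beam-search decoding; no gradients are needed for inference.
    with torch.no_grad():
        outputs = quantized_model.generate(
            input_ids, max_length=128, num_beams=4, early_stopping=True
        )

    predicted_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return clean_generated_code(postprocess_output(predicted_text))

iface = gr.Interface(fn=generate_solution, inputs="text", outputs="text")
iface.launch()

Compared with the old `inference` function, decoding now uses beam search (num_beams=4, early_stopping=True) instead of the greedy default, and inputs are padded/truncated to a fixed 128 tokens, which keeps latency predictable on the quantized model.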