S-Dreamer commited on
Commit
57b3195
·
verified ·
1 Parent(s): 8302e50

Update generation_fast.py

Browse files
Files changed (1) hide show
  1. generation_fast.py +8 -7
generation_fast.py CHANGED
@@ -1,29 +1,30 @@
 
1
  import torch
2
- import torch.nn as nn
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
 
5
  class CodeGenerator:
6
- def __init__(self, model_name):
7
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
9
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
  self.model.to(self.device)
11
 
12
- def generate_code(self, nl_input, max_length=256, num_beams=4, early_stopping=True):
13
  inputs = self.tokenizer(nl_input, return_tensors="pt").to(self.device)
14
  outputs = self.model.generate(
15
  **inputs,
16
  max_length=max_length,
17
  num_beams=num_beams,
18
  early_stopping=early_stopping,
 
 
 
19
  )
20
  generated_code = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
21
  return generated_code
22
 
23
  if __name__ == "__main__":
24
- model_name = "S-Dreamer/PyCodeT5"
25
- generator = CodeGenerator(model_name)
26
-
27
- nl_input = "Write a Python function to calculate the factorial of a number."
28
  generated_code = generator.generate_code(nl_input)
29
  print(generated_code)
 
1
+ # generation_fast.py
2
  import torch
 
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
 
5
  class CodeGenerator:
6
+ def __init__(self, model_name="S-Dreamer/PyCodeT5"):
7
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
9
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
  self.model.to(self.device)
11
 
12
+ def generate_code(self, nl_input, max_length=512, num_beams=5, early_stopping=True):
13
  inputs = self.tokenizer(nl_input, return_tensors="pt").to(self.device)
14
  outputs = self.model.generate(
15
  **inputs,
16
  max_length=max_length,
17
  num_beams=num_beams,
18
  early_stopping=early_stopping,
19
+ no_repeat_ngram_size=2, # Prevents repetition
20
+ length_penalty=1.0, # Adjust length penalty
21
+ temperature=1.0, # Adjust temperature for diversity
22
  )
23
  generated_code = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
24
  return generated_code
25
 
26
  if __name__ == "__main__":
27
+ generator = CodeGenerator()
28
+ nl_input = "Write a Python function to reverse a string."
 
 
29
  generated_code = generator.generate_code(nl_input)
30
  print(generated_code)