# generation_fast.py
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


class CodeGenerator:
    def __init__(self, model_name="S-Dreamer/PyCodeT5"):
        # Load the tokenizer and seq2seq model, then move the model to the GPU when available.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def generate_code(self, nl_input, max_length=512, num_beams=5, early_stopping=True):
        # Tokenize the prompt, move it to the model's device, and decode with beam search.
        inputs = self.tokenizer(nl_input, return_tensors="pt").to(self.device)
        outputs = self.model.generate(
            **inputs,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=early_stopping,
            no_repeat_ngram_size=2,  # block repeated bigrams in the output
            length_penalty=1.0,  # neutral length penalty; values > 1.0 favor longer sequences
            temperature=1.0,  # no effect here: temperature is only used when do_sample=True
        )
        generated_code = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated_code


if __name__ == "__main__":
    generator = CodeGenerator()
    nl_input = "Write a Python function to reverse a string."
    generated_code = generator.generate_code(nl_input)
    print(generated_code)