|
import torch |
|
import torch.nn as nn |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
class CodeGenerator:
    """Generate source code from a natural-language prompt using a seq2seq model."""

    def __init__(self, model_name):
        """Load the tokenizer and model, moving the model to GPU when available.

        Args:
            model_name: Hugging Face model identifier or local checkpoint path.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Inference-only usage: eval() disables dropout/other train-mode layers
        # so generation is deterministic given fixed decoding settings.
        self.model.eval()

    def generate_code(self, nl_input, max_length=256, num_beams=4, early_stopping=True):
        """Generate code text for a natural-language description.

        Args:
            nl_input: The natural-language prompt to condition generation on.
            max_length: Maximum number of generated tokens (passed to `generate`).
            num_beams: Beam-search width (passed to `generate`).
            early_stopping: Whether beam search stops when all beams finish.

        Returns:
            The decoded generated text with special tokens stripped.
        """
        inputs = self.tokenizer(nl_input, return_tensors="pt").to(self.device)
        # no_grad: generation needs no autograd graph — avoids holding
        # activations for backprop during beam search (memory + speed).
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=early_stopping,
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
def main():
    """Demo entry point: load the model and print code for a sample prompt."""
    checkpoint = "S-Dreamer/PyCodeT5"
    code_gen = CodeGenerator(checkpoint)

    prompt = "Write a Python function to calculate the factorial of a number."
    result = code_gen.generate_code(prompt)
    print(result)


if __name__ == "__main__":
    main()