deepseek-7b / app /model.py
arya-ai-model's picture
First commit
d6f9a33
raw
history blame
722 Bytes
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from app.config import MODEL_NAME, DEVICE
class DeepSeekModel:
    """Wrapper around a Hugging Face causal language model and its tokenizer.

    The checkpoint named by ``MODEL_NAME`` is loaded once, at construction
    time, in float16 with ``device_map="auto"`` (weight placement is delegated
    to Transformers/Accelerate). If the ``HF_TOKEN`` environment variable is
    set, it is forwarded to both the tokenizer and the model loader.
    """

    def __init__(self):
        # Read the token once and hand it to *both* loaders. Previously only the
        # tokenizer received it, so a gated/private checkpoint would load its
        # tokenizer and then fail when fetching the weights.
        hf_token = os.getenv("HF_TOKEN")
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
        self.model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16,
            device_map="auto",
            token=hf_token,
        )

    def generate(self, prompt, max_length=256, max_new_tokens=None):
        """Generate a continuation of ``prompt`` and return it as text.

        Args:
            prompt: Input text to condition on.
            max_length: Upper bound on the *total* sequence length (prompt
                tokens plus generated tokens), as interpreted by
                ``model.generate``. Ignored when ``max_new_tokens`` is given.
            max_new_tokens: Optional bound on the number of *generated* tokens
                only, independent of prompt length. Prefer this for long
                prompts: with ``max_length``, a prompt that already reaches
                ``max_length`` tokens leaves no room for new tokens.

        Returns:
            The first generated sequence decoded with special tokens removed.
            For a causal LM, ``model.generate`` returns the prompt tokens
            followed by the new ones, so the text includes the prompt.
        """
        # With device_map="auto" the weights are placed automatically, so send
        # the inputs to wherever the model actually landed instead of to a
        # separately configured device that may not match it.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        if max_new_tokens is not None:
            length_kwargs = {"max_new_tokens": max_new_tokens}
        else:
            length_kwargs = {"max_length": max_length}

        # Many causal-LM tokenizers define no pad token; generate() then falls
        # back to EOS and logs a warning on every call. Passing it explicitly
        # gives the same result without the warning noise.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                pad_token_id=pad_token_id,
                **length_kwargs,
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
# Module-level instance: importing this module constructs DeepSeekModel, which
# loads (and, if not already cached, downloads) the tokenizer and weights for
# MODEL_NAME. That time and memory cost is paid at import time.
# NOTE(review): presumably other modules import `model` from here as a shared
# instance — confirm callers before making this lazy.
model: DeepSeekModel = DeepSeekModel()