# deepseek-7b / app.py
import os
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
# Set a writable cache directory
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
# Model setup
MODEL_NAME = "deepseek-ai/deepseek-llm-7b-base"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.bfloat16
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    # The base tokenizer may not define a pad token; fall back to EOS so padding=True works below
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME, torch_dtype=DTYPE, device_map="auto"
)
# Set up generation config
generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
generation_config.pad_token_id = generation_config.eos_token_id
generation_config.use_cache = True # Speed up decoding
# FastAPI app
app = FastAPI()
# Request payload
class TextGenerationRequest(BaseModel):
    prompt: str
    max_tokens: int = 512  # Cap on the number of new tokens generated per request
@app.post("/generate")
async def generate_text(request: TextGenerationRequest):
try:
        # Tokenize the prompt and move tensors to the same device as the model
        inputs = tokenizer(request.prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)
# Use no_grad() for faster inference
with torch.no_grad():
            outputs = model.generate(
                **inputs,
                generation_config=generation_config,  # Apply the pad/eos and cache settings defined above
                max_new_tokens=request.max_tokens,
                do_sample=True,          # Enable sampling (set False for deterministic output)
                temperature=0.7,         # Lower values give more conservative text
                top_k=50,                # Consider only the top 50 token choices
                top_p=0.9,               # Nucleus sampling (drops unlikely tokens)
                repetition_penalty=1.1,  # Discourages looping responses
            )
# Decode generated tokens
result = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
return {"generated_text": result}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
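

# --- Local entrypoint (a minimal sketch, not part of the original file) ---
# Assumption: running the app directly with `python app.py` outside the hosted
# environment; port 7860 is a guess based on the usual Hugging Face Spaces default.
# Example request once the server is up:
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello", "max_tokens": 64}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)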