# deepseek-7b / app.py
import os
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# Cache models under /tmp (writable in containerized environments such as HF Spaces)
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"  # kept for older transformers versions; superseded by HF_HOME
# Model setup
MODEL_NAME = "deepseek-ai/deepseek-llm-7b-base"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Load the model 4-bit quantized via bitsandbytes (for speed & memory efficiency)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # enable 4-bit inference
    bnb_4bit_compute_dtype=torch.float16,  # run matmuls in fp16
    bnb_4bit_use_double_quant=True,        # nested quantization for extra memory savings
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    # The base model ships without a pad token; reuse EOS so padding=True below works
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",                        # place layers automatically on available devices
    attn_implementation="flash_attention_2",  # requires the flash-attn package and a supported GPU
)
# Compile for even faster inference (PyTorch 2.0+)
model = torch.compile(model)
# FastAPI app
app = FastAPI()
# Request payload
class TextGenerationRequest(BaseModel):
    prompt: str
    max_tokens: int = 512  # default number of new tokens to generate
@app.post("/generate")
async def generate_text(request: TextGenerationRequest):
    try:
        inputs = tokenizer(
            request.prompt, return_tensors="pt", padding=True, truncation=True
        ).to(DEVICE)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_tokens,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.9,
                repetition_penalty=1.05,
                use_cache=True,
            )
        # Note: the decoded text includes the prompt followed by the completion
        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": result}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
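
# Example client call (a minimal sketch; assumes the server is reachable on
# localhost:7860, the port commonly used by Hugging Face Spaces, and that the
# `requests` package is installed -- adjust the URL for your deployment):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/generate",
#       json={"prompt": "Write a haiku about GPUs.", "max_tokens": 64},
#   )
#   resp.raise_for_status()
#   print(resp.json()["generated_text"])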