import os
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
# Use a writable cache directory (Spaces containers typically only allow writes under /tmp).
# TRANSFORMERS_CACHE is deprecated in recent transformers releases in favor of HF_HOME;
# both are set here for compatibility with older versions.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
# Model setup
MODEL_NAME = "deepseek-ai/deepseek-llm-7b-base"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# float16 on GPU; bfloat16 on CPU (float32 is the safer fallback on CPUs without bfloat16 support)
DTYPE = torch.float16 if DEVICE == "cuda" else torch.bfloat16
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Base models often ship without a pad token; fall back to EOS so padding works
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=DTYPE, device_map="auto"
)
# Set up generation defaults (passed to model.generate below)
generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
generation_config.pad_token_id = generation_config.eos_token_id  # Silence the missing-pad-token warning
generation_config.use_cache = True  # Reuse the KV cache to speed up decoding
# FastAPI app
app = FastAPI()
# Request payload
class TextGenerationRequest(BaseModel):
    prompt: str
    max_tokens: int = 512  # Upper bound on newly generated tokens; larger values increase latency
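# Illustrative request body for POST /generate (example values, not defaults):
#   {"prompt": "Write a haiku about the sea.", "max_tokens": 128}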
@app.post("/generate")
async def generate_text(request: TextGenerationRequest):
    try:
        # Tokenize the prompt and move tensors to the model's device
        inputs = tokenizer(
            request.prompt, return_tensors="pt", padding=True, truncation=True
        ).to(DEVICE)
        # Disable gradient tracking for faster inference
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                generation_config=generation_config,
                max_new_tokens=request.max_tokens,
                do_sample=True,          # Enable sampling (set False for deterministic output)
                temperature=0.7,         # Lower = more conservative, higher = more creative
                top_k=50,                # Sample only from the 50 most likely tokens
                top_p=0.9,               # Nucleus sampling (drops the unlikely tail)
                repetition_penalty=1.1,  # Discourage looping responses
            )
        # Decode the output (includes the prompt, since generate returns the full sequence)
        result = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return {"generated_text": result}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
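# A minimal local entry point (an assumption; not part of the original Space, which
# is launched by its own runtime). Port 7860 matches the Hugging Face Spaces default.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example query once the server is up:
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Explain KV caching in one paragraph.", "max_tokens": 128}'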