arya-ai-model committed on
Commit c2d0dc7 · 1 Parent(s): 59e3ffd

fixing app.py

Files changed (1)
  1. app.py +23 -19
app.py CHANGED
@@ -2,27 +2,33 @@ import os
 import torch
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig

-# Set a writable cache directory
+# Set cache directory
 os.environ["HF_HOME"] = "/tmp/huggingface"
 os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"

 # Model setup
 MODEL_NAME = "deepseek-ai/deepseek-llm-7b-base"
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-DTYPE = torch.float16 if DEVICE == "cuda" else torch.bfloat16

-# Load model and tokenizer
+# Load 4-bit quantized model (for speed & efficiency)
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,  # Enable 4-bit inference
+    bnb_4bit_compute_dtype=torch.float16,
+    bnb_4bit_use_double_quant=True,
+)
+
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 model = AutoModelForCausalLM.from_pretrained(
-    MODEL_NAME, torch_dtype=DTYPE, device_map="auto"
+    MODEL_NAME,
+    quantization_config=bnb_config,
+    device_map="auto",
+    attn_implementation="flash_attention_2"  # Enables Flash Attention
 )

-# Set up generation config
-generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
-generation_config.pad_token_id = generation_config.eos_token_id
-generation_config.use_cache = True  # Speed up decoding
+# Compile for even faster inference (PyTorch 2.0+)
+model = torch.compile(model)

 # FastAPI app
 app = FastAPI()
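
The new loader stacks three speed-ups: 4-bit bitsandbytes quantization, Flash Attention 2, and torch.compile. Each has extra requirements (the bitsandbytes and flash-attn packages, a recent CUDA GPU, PyTorch 2.0+), and as committed the app fails at startup if any of them is missing. Below is a minimal, more defensive sketch of the same loading step; the "sdpa" fallback and the availability checks are assumptions, not part of this commit.

# Sketch only: load the same model, but degrade gracefully when optional
# dependencies are missing. The fallback choices below are assumptions.
import importlib.util

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_NAME = "deepseek-ai/deepseek-llm-7b-base"

# 4-bit quantization via bitsandbytes (needs a CUDA GPU and the bitsandbytes package)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# flash_attention_2 requires the flash-attn package; otherwise fall back to
# PyTorch's built-in scaled-dot-product attention ("sdpa")
attn_impl = "flash_attention_2" if importlib.util.find_spec("flash_attn") else "sdpa"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_impl,
)

# torch.compile exists only on PyTorch 2.0+; skip it elsewhere
if hasattr(torch, "compile"):
    model = torch.compile(model)

Whether torch.compile pays off for a 4-bit bitsandbytes model is workload-dependent, and the first request after startup absorbs the compilation warm-up.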
 
@@ -30,28 +36,26 @@ app = FastAPI()
 # Request payload
 class TextGenerationRequest(BaseModel):
     prompt: str
-    max_tokens: int = 512  # Default to 512 for better performance
+    max_tokens: int = 512  # Default to 512

 @app.post("/generate")
 async def generate_text(request: TextGenerationRequest):
     try:
-        # Tokenize input and move tensors to the correct device
         inputs = tokenizer(request.prompt, return_tensors="pt", padding=True, truncation=True).to(DEVICE)

-        # Use no_grad() for faster inference
         with torch.no_grad():
             outputs = model.generate(
                 **inputs,
                 max_new_tokens=request.max_tokens,
-                do_sample=True,  # Enables sampling (use False for deterministic results)
-                temperature=0.7,  # Adjust for creativity (lower = more conservative)
-                top_k=50,  # Consider top 50 token choices
-                top_p=0.9,  # Nucleus sampling (reduces unlikely words)
-                repetition_penalty=1.1,  # Prevents looping responses
+                do_sample=True,
+                temperature=0.7,
+                top_k=50,
+                top_p=0.9,
+                repetition_penalty=1.05,
+                use_cache=True,
             )

-        # Decode generated tokens
-        result = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
         return {"generated_text": result}

     except Exception as e:
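
For reference, the /generate route above expects a JSON body matching TextGenerationRequest. A hypothetical client call, assuming the app is served locally with uvicorn app:app --port 8000 (the host, port, prompt, and use of the requests package are assumptions, not part of the commit):

# Hypothetical client for the /generate endpoint defined in app.py
import requests

resp = requests.post(
    "http://localhost:8000/generate",  # assumed host/port
    json={"prompt": "Explain what a KV cache is.", "max_tokens": 128},
    timeout=300,  # generation with a 7B model can take a while
)
resp.raise_for_status()
print(resp.json()["generated_text"])

Note that because the handler decodes outputs[0] in full, the returned generated_text starts with an echo of the prompt; decoding only outputs[0][inputs["input_ids"].shape[1]:] would return just the newly generated tokens.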