arya-ai-model committed on
Commit
98db4b3
·
1 Parent(s): c2d0dc7

fixing app.py

Files changed (1)
  1. app.py +18 -30
app.py CHANGED
@@ -1,34 +1,35 @@
 import os
 import torch
 from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig
 
-# Set cache directory
+# Set a writable cache directory
 os.environ["HF_HOME"] = "/tmp/huggingface"
 os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
 
 # Model setup
-MODEL_NAME = "deepseek-ai/deepseek-llm-7b-base"
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+MODEL_NAME = "google/gemma-2b"  # Smaller, CPU-friendly model
+DEVICE = "cpu"
 
-# Load 4-bit quantized model (for speed & efficiency)
-bnb_config = BitsAndBytesConfig(
-    load_in_4bit=True,  # Enable 4-bit inference
+# 4-bit Quantization for CPU
+quantization_config = BitsAndBytesConfig(
+    load_in_4bit=True,
     bnb_4bit_compute_dtype=torch.float16,
-    bnb_4bit_use_double_quant=True,
+    bnb_4bit_use_double_quant=True
 )
 
+# Load model & tokenizer
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 model = AutoModelForCausalLM.from_pretrained(
-    MODEL_NAME,
-    quantization_config=bnb_config,
-    device_map="auto",
-    attn_implementation="flash_attention_2"  # Enables Flash Attention
+    MODEL_NAME,
+    quantization_config=quantization_config,
+    device_map="cpu"
 )
 
-# Compile for even faster inference (PyTorch 2.0+)
-model = torch.compile(model)
+# Set generation config
+model.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
+model.generation_config.pad_token_id = model.generation_config.eos_token_id
 
 # FastAPI app
 app = FastAPI()
@@ -36,27 +37,14 @@ app = FastAPI()
 # Request payload
 class TextGenerationRequest(BaseModel):
     prompt: str
-    max_tokens: int = 512  # Default to 512
+    max_tokens: int = Field(default=100, ge=1, le=512)  # Prevent too large token requests
 
 @app.post("/generate")
 async def generate_text(request: TextGenerationRequest):
     try:
-        inputs = tokenizer(request.prompt, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
-
-        with torch.no_grad():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=request.max_tokens,
-                do_sample=True,
-                temperature=0.7,
-                top_k=50,
-                top_p=0.9,
-                repetition_penalty=1.05,
-                use_cache=True,
-            )
-
+        inputs = tokenizer(request.prompt, return_tensors="pt").to(DEVICE)
+        outputs = model.generate(**inputs, max_new_tokens=request.max_tokens, do_sample=True)
         result = tokenizer.decode(outputs[0], skip_special_tokens=True)
         return {"generated_text": result}
-
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
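
For reference, a minimal client sketch for the updated /generate endpoint. This is not part of the commit: the base URL and prompt are placeholder assumptions (for example, the app being served locally with uvicorn app:app --port 8000), and the only contract it relies on is the TextGenerationRequest payload shown in the diff.

import requests

# Hypothetical local URL; adjust to wherever the FastAPI app is actually served.
API_URL = "http://localhost:8000/generate"

payload = {
    "prompt": "Explain 4-bit quantization in one sentence.",
    "max_tokens": 128,  # must fall within the 1..512 range enforced by Field(ge=1, le=512)
}

resp = requests.post(API_URL, json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["generated_text"])

A request whose max_tokens falls outside the declared range is rejected by FastAPI's validation with a 422 response before the model is ever invoked.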