# app.py: FastAPI service exposing the ManojINaik/Strength_weakness fine-tuned LLaMA model
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
import torch
from typing import Optional, List

app = FastAPI(title="LLM API", description="API for interacting with a fine-tuned LLaMA model")

# Model configuration
class ModelConfig:
    model_name = "ManojINaik/Strength_weakness"  # Your fine-tuned model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    max_length = 200
    temperature = 0.7

# Request/Response models
class GenerateRequest(BaseModel):
    prompt: str
    history: Optional[List[str]] = []
    system_prompt: Optional[str] = "You are a very powerful AI assistant."
    max_length: Optional[int] = 200
    temperature: Optional[float] = 0.7
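
# For reference, a request to /generate/ might look like the JSON below
# (an illustrative sketch; values are examples, only "prompt" is required):
# {
#   "prompt": "What are my strengths and weaknesses?",
#   "history": ["Human: Hi", "Assistant: Hello! How can I help you?"],
#   "system_prompt": "You are a very powerful AI assistant.",
#   "max_length": 200,
#   "temperature": 0.7
# }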

class GenerateResponse(BaseModel):
    response: str


# Global variables for model and tokenizer
model = None
tokenizer = None
generator = None

@app.on_event("startup")
async def load_model():
    global model, tokenizer, generator
    try:
        print("Loading model and tokenizer...")

        # Configure 4-bit quantization
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=False
        )

        tokenizer = AutoTokenizer.from_pretrained(ModelConfig.model_name)
        # LLaMA-style tokenizers often ship without a pad token; fall back to EOS
        # so the generation call below has a valid pad_token_id
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            ModelConfig.model_name,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True
        )

        generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device_map="auto"
        )

        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise

@app.post("/generate/", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
    if generator is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    try:
        # Format the prompt with the system prompt and chat history
        formatted_prompt = f"{request.system_prompt}\n\n"
        for msg in request.history:
            formatted_prompt += f"{msg}\n"
        formatted_prompt += f"Human: {request.prompt}\nAssistant:"

        # Generate the response; max_new_tokens counts only newly generated tokens,
        # so a long prompt or history does not eat into the response budget
        outputs = generator(
            formatted_prompt,
            max_new_tokens=request.max_length,
            temperature=request.temperature,
            num_return_sequences=1,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

        # Keep only the text after the final "Assistant:" marker,
        # i.e. drop the echoed prompt from the generated text
        generated_text = outputs[0]["generated_text"]
        response = generated_text.split("Assistant:")[-1].strip()

        return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
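
# Example client call once the server is up (a sketch; the host and port are
# assumptions, not part of this file):
#
#   import requests
#   r = requests.post(
#       "http://localhost:7860/generate/",
#       json={"prompt": "Give me one strength and one weakness of Python."},
#   )
#   print(r.json()["response"])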

@app.get("/")
def root():
    return {"message": "LLM API is running. Use /generate endpoint for text generation."}