import os
import logging
from typing import List, Optional

# The cache location must be set before transformers is imported,
# since the library reads it at import time.
os.makedirs("/app/cache", exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = "/app/cache"

import torch
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    DataCollatorForLanguageModeling,
)

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Output location for fine-tuned checkpoints (assumed path; the name is
# used throughout but was never defined in the listing)
model_output_path = "/app/model"
# Pydantic models for request/response
class GenerateRequest(BaseModel):
    text: str
    max_length: Optional[int] = 512
    temperature: Optional[float] = 0.7
    num_return_sequences: Optional[int] = 1

class GenerateResponse(BaseModel):
    generated_text: List[str]

class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    gpu_available: bool
    device: str

class TrainRequest(BaseModel):
    dataset_path: str
    num_epochs: Optional[int] = 3
    batch_size: Optional[int] = 4
    learning_rate: Optional[float] = 2e-5

class TrainResponse(BaseModel):
    status: str
    message: str
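# Example request bodies for the two POST endpoints (illustrative values;
# field names and defaults come from the models above):
#
#   GenerateRequest: {"text": "Patient presents with chest pain and",
#                     "max_length": 256, "temperature": 0.7}
#   TrainRequest:    {"dataset_path": "/app/data/medical.json",
#                     "num_epochs": 3, "batch_size": 4}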
# Training status tracking (single-process, in-memory)
class TrainingStatus:
    def __init__(self):
        self.is_training = False
        self.current_epoch = 0
        self.current_loss = None
        self.status = "idle"

training_status = TrainingStatus()
# Initialize FastAPI app
app = FastAPI(
    title="Medical LLaMA API",
    description="API for medical text generation using fine-tuned LLaMA model",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Add CORS middleware
# NOTE: wildcard origins combined with credentials is very permissive;
# restrict allow_origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables for model and tokenizer
model = None
tokenizer = None
@app.get("/", response_model=HealthResponse)
async def root():
    """
    Root endpoint to check API health and model status
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return HealthResponse(
        status="online",
        model_loaded=model is not None,
        gpu_available=torch.cuda.is_available(),
        device=device,
    )
@app.post("/generate", response_model=GenerateResponse)  # route path assumed
async def generate_text(request: GenerateRequest):
    """
    Generate medical text based on input prompt

    Parameters:
    - text: Input text prompt
    - max_length: Maximum length of generated text
    - temperature: Sampling temperature (0.0 to 1.0)
    - num_return_sequences: Number of sequences to generate

    Returns:
    - List of generated text sequences
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    try:
        inputs = tokenizer(
            request.text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=request.max_length,
        ).to(model.device)

        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,  # pass input_ids together with attention_mask
                max_length=request.max_length,
                num_return_sequences=request.num_return_sequences,
                do_sample=True,  # temperature has no effect under greedy decoding
                temperature=request.temperature,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        generated_texts = [
            tokenizer.decode(g, skip_special_tokens=True)
            for g in generated_ids
        ]
        return GenerateResponse(generated_text=generated_texts)
    except Exception as e:
        logger.error(f"Generation error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
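# Example call (illustrative; assumes the service is reachable on
# localhost:8000 and the /generate path above):
#
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Common symptoms of anemia include", "max_length": 128}'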
@app.get("/health")  # route path assumed
async def health_check():
    """
    Check the health status of the API and model
    """
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "gpu_available": torch.cuda.is_available(),
        "device": "cuda" if torch.cuda.is_available() else "cpu",
    }
@app.on_event("startup")
async def startup_event():
    logger.info("Starting up application...")
    try:
        global tokenizer, model
        tokenizer, model = init_model()
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
@app.post("/train", response_model=TrainResponse)  # route path assumed
async def train_model(request: TrainRequest, background_tasks: BackgroundTasks):
    """
    Start model training with the specified dataset

    Parameters:
    - dataset_path: Path to the JSON dataset file
    - num_epochs: Number of training epochs
    - batch_size: Training batch size
    - learning_rate: Learning rate for training
    """
    if training_status.is_training:
        raise HTTPException(status_code=400, detail="Training is already in progress")

    # Verify dataset exists before queueing the job (checked outside the
    # catch-all handler so the 404 is not rewrapped as a 500)
    if not os.path.exists(request.dataset_path):
        raise HTTPException(status_code=404, detail="Dataset file not found")

    try:
        # Start training in background
        background_tasks.add_task(
            run_training,
            request.dataset_path,
            request.num_epochs,
            request.batch_size,
            request.learning_rate,
        )
        return TrainResponse(
            status="started",
            message="Training started in background",
        )
    except Exception as e:
        logger.error(f"Training setup error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
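# Example call (illustrative; the /train path above is an assumption and
# the dataset path is hypothetical):
#
#   curl -X POST http://localhost:8000/train \
#        -H "Content-Type: application/json" \
#        -d '{"dataset_path": "/app/data/medical.json", "num_epochs": 3}'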
@app.get("/training-status")  # route path assumed
async def get_training_status():
    """
    Get current training status
    """
    return {
        "is_training": training_status.is_training,
        "current_epoch": training_status.current_epoch,
        "current_loss": training_status.current_loss,
        "status": training_status.status,
    }
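# Example response while a job is running (values illustrative; the fields
# mirror TrainingStatus):
#
#   {"is_training": true, "current_epoch": 1.0,
#    "current_loss": 2.31, "status": "training"}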
# Background training job. Defined as a plain (sync) function so that
# BackgroundTasks runs it in a threadpool instead of blocking the event
# loop for the duration of trainer.train().
def run_training(dataset_path: str, num_epochs: int, batch_size: int, learning_rate: float):
    global model, tokenizer, training_status
    try:
        training_status.is_training = True
        training_status.status = "loading_dataset"

        # Load dataset
        dataset = load_dataset("json", data_files=dataset_path)
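        # Expected input format: each record carries a "text" field, which
        # the preprocessing below tokenizes, e.g. (illustrative):
        #   {"text": "Q: What are common symptoms of anemia? A: Fatigue, ..."}
        # load_dataset("json", ...) accepts a JSON array or a JSON-lines file.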
training_status.status = "preprocessing" | |
# Preprocess function | |
def preprocess_function(examples): | |
return tokenizer( | |
examples["text"], | |
truncation=True, | |
padding="max_length", | |
max_length=512 | |
) | |
# Tokenize dataset | |
tokenized_dataset = dataset.map( | |
preprocess_function, | |
batched=True, | |
remove_columns=dataset["train"].column_names | |
) | |
training_status.status = "training" | |
# Training arguments | |
training_args = TrainingArguments( | |
output_dir=f"{model_output_path}/checkpoints", | |
per_device_train_batch_size=batch_size, | |
gradient_accumulation_steps=4, | |
num_train_epochs=num_epochs, | |
learning_rate=learning_rate, | |
fp16=True, | |
save_steps=500, | |
logging_steps=100, | |
) | |
# Initialize trainer | |
trainer = Trainer( | |
model=model, | |
args=training_args, | |
train_dataset=tokenized_dataset["train"], | |
data_collator=DataCollatorForLanguageModeling( | |
tokenizer=tokenizer, | |
mlm=False | |
), | |
) | |
# Training callback to update status | |
class TrainingCallback(trainer.callback_handler): | |
def on_epoch_begin(self, args, state, control, **kwargs): | |
training_status.current_epoch = state.epoch | |
def on_log(self, args, state, control, logs=None, **kwargs): | |
if logs: | |
training_status.current_loss = logs.get("loss", None) | |
trainer.add_callback(TrainingCallback) | |
# Start training | |
trainer.train() | |
# Save the model | |
training_status.status = "saving" | |
model.save_pretrained(model_output_path) | |
tokenizer.save_pretrained(model_output_path) | |
training_status.status = "completed" | |
logger.info("Training completed successfully") | |
except Exception as e: | |
training_status.status = f"failed: {str(e)}" | |
logger.error(f"Training error: {str(e)}") | |
raise | |
finally: | |
training_status.is_training = False | |
# Model initialization
def init_model():
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Loading model on device: {device}")

        # Prefer the fine-tuned model if one has been saved
        if os.path.exists(model_output_path):
            model_name = model_output_path
        else:
            # Fall back to the base model. NOTE: ONNX exports such as
            # nvidia/Meta-Llama-3.2-3B-Instruct-ONNX-INT4 cannot be loaded
            # with AutoModelForCausalLM; the standard PyTorch checkpoint is
            # used here instead (gated on the Hub; requires access approval).
            model_name = "meta-llama/Llama-3.2-3B-Instruct"

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # LLaMA tokenizers ship without a pad token; reuse EOS so that
        # padding works in generate() and training
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map="auto",
        )
        return tokenizer, model
    except Exception as e:
        logger.error(f"Model initialization error: {str(e)}")
        raise
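# Local entry point, a minimal sketch: Hugging Face Spaces typically starts
# the server itself, so this block only matters when running the file
# directly. Host and port (7860 is the Spaces default) are assumptions.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)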