"""FastAPI service that serves a Mistral base model with a PEFT adapter.

Endpoints:
    GET /          -- fixed greeting payload.
    GET /generate  -- generate text for the ``text`` query parameter.
"""

import logging

from fastapi import FastAPI, HTTPException
from transformers import AutoModelForCausalLM, pipeline
from peft import PeftModel, PeftConfig
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
# NOTE(review): could not confirm that ``mistral_common`` ships a ``client``
# module exporting ``MistralChain``; if it does not, this import fails before
# the app is even created. Verify against the installed mistral_common version.
from mistral_common.client import MistralChain

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face Hub identifiers, named once so the base model and the adapter
# cannot drift apart between the calls that use them.
BASE_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
ADAPTER_ID = "frankmorales2020/Mistral-7B-text-to-sql-flash-attention-2-dataeval"

# Generation budget passed to ``mistral_chain.generate`` for every request.
MAX_NEW_TOKENS = 100

# Initialize FastAPI app
app = FastAPI()

# Global variables for model, tokenizer, and pipeline.
# All stay ``None`` until ``load_model`` has completed successfully.
model = None
tokenizer = None
pipe = None
mistral_chain = None


# NOTE(review): ``on_event`` is marked deprecated in recent FastAPI releases in
# favour of lifespan handlers; kept here so the startup hook stays unchanged.
@app.on_event("startup")
async def load_model():
    """Load the model, tokenizer, chain and pipeline into the module globals.

    Registered as a FastAPI startup handler. On success the globals
    ``model``, ``tokenizer``, ``mistral_chain`` and ``pipe`` are populated.

    Raises:
        ImportError: Re-raised after logging if a required module is missing.
        Exception: Any other loading failure is logged and re-raised, so the
            error surfaces during application startup instead of being hidden.
    """
    global model, tokenizer, pipe, mistral_chain
    try:
        logger.info("Loading PEFT configuration...")
        config = PeftConfig.from_pretrained(ADAPTER_ID)
        # The adapter records the base model it was trained against; warn when
        # it differs from the hard-coded base instead of ignoring the config.
        declared_base = getattr(config, "base_model_name_or_path", None)
        if declared_base and declared_base != BASE_MODEL_ID:
            logger.warning(
                "Adapter %s declares base model %s but %s is being loaded.",
                ADAPTER_ID,
                declared_base,
                BASE_MODEL_ID,
            )

        logger.info("Loading base model...")
        base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID)

        logger.info("Loading PEFT model...")
        model = PeftModel.from_pretrained(base_model, ADAPTER_ID)

        logger.info("Loading tokenizer...")
        # NOTE(review): could not confirm ``MistralTokenizer.from_pretrained``
        # exists in mistral_common; verify the constructor name for the
        # installed version.
        tokenizer = MistralTokenizer.from_pretrained(BASE_MODEL_ID)

        logger.info("Creating MistralChain...")
        mistral_chain = MistralChain(model, tokenizer)

        logger.info("Creating pipeline...")
        # The model is a causal LM, so the matching task is "text-generation";
        # "text2text-generation" is the sequence-to-sequence task.
        # NOTE(review): transformers pipelines normally expect a transformers
        # tokenizer; confirm a mistral_common tokenizer is accepted here.
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

        logger.info("Model, tokenizer, and pipeline loaded successfully.")
    except ImportError as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception(
            "Error importing required modules. "
            "Please check your installation: %s",
            e,
        )
        raise
    except Exception as e:
        logger.exception("Error loading model or creating pipeline: %s", e)
        raise


@app.get("/")
def home():
    """Return a fixed greeting payload."""
    return {"message": "Hello World"}


@app.get("/generate")
async def generate(text: str):
    """Generate text for ``text`` using the loaded chain.

    Args:
        text: Prompt taken from the ``text`` query parameter.

    Returns:
        A dict ``{"output": <result of mistral_chain.generate>}``.

    Raises:
        HTTPException: 503 if the chain has not been loaded; 500 if
            generation raises, with the original error chained as the cause.
    """
    # Identity check: an object that happens to be falsy must not be mistaken
    # for "not loaded".
    if mistral_chain is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # NOTE(review): this call is synchronous inside an ``async def`` handler,
    # so the event loop is occupied for the duration of generation; consider
    # offloading it to a worker thread if concurrent requests matter.
    try:
        output = mistral_chain.generate(text, max_tokens=MAX_NEW_TOKENS)
    except Exception as e:
        logger.exception("Error during text generation: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Error during text generation: {e}",
        ) from e
    return {"output": output}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)