|
from fastapi import FastAPI, HTTPException |
|
import base64 |
|
import os |
|
|
|
# Process-wide cache of the LoRA weights file, stored as base64 text.
# ``None`` means the weights have not been loaded yet; the startup hook
# assigns it and the /modify-prompt handler reads it.
lora_weights: "str | None" = None

app = FastAPI()
|
|
|
@app.on_event("startup")
def load_lora_weights(lora_path: str = "./lora_file.pth") -> None:
    """Read the LoRA weights file and cache it base64-encoded in ``lora_weights``.

    Runs as a FastAPI startup hook. The raw bytes of *lora_path* are
    base64-encoded and decoded to a UTF-8 ``str`` so they can be placed
    directly into JSON responses.

    NOTE(review): ``@app.on_event`` is deprecated in recent FastAPI releases
    in favour of ``lifespan`` handlers; consider migrating.

    Args:
        lora_path: Path to the weights file. A relative path resolves against
            the process's current working directory, not this module's
            directory. Startup handlers are invoked with no arguments, so the
            default is what runs when the app starts.

    Raises:
        FileNotFoundError: If *lora_path* does not exist. Raised from a startup
            hook, this should abort application startup rather than letting the
            service run without weights.
    """
    global lora_weights

    # EAFP: attempt the open directly instead of an exists() pre-check, which
    # could race with the file being removed between the check and the open.
    try:
        with open(lora_path, "rb") as f:
            raw = f.read()
    except FileNotFoundError as err:
        # HTTPException only makes sense inside a request handler; at startup
        # there is no request to answer, so raise a descriptive plain error.
        raise FileNotFoundError(f"LoRA file not found: {lora_path}") from err

    # The whole file is kept in memory as base64 text (~4/3 of its byte size).
    lora_weights = base64.b64encode(raw).decode("utf-8")
    print("LoRA weights loaded and preprocessed successfully.")
|
|
|
@app.post("/modify-prompt")
async def modify_prompt(prompt: str):
    """Return *prompt* bundled with the cached base64-encoded LoRA weights.

    ``prompt`` is a plain ``str`` parameter that is not part of the path, so
    FastAPI reads it from the query string (``?prompt=...``), not the body.

    Returns:
        A dict ``{"prompt": prompt, "lora": lora_weights}``, which FastAPI
        serializes to JSON.

    Raises:
        HTTPException: 500 if the weights have not been loaded.
    """
    # Only reading the module-level cache, so no ``global`` declaration is needed.
    if lora_weights is None:
        raise HTTPException(status_code=500, detail="LoRA weights not loaded.")

    # NOTE(review): this embeds the entire base64-encoded weights file in every
    # response; confirm clients really need the full payload on each call.
    return {"prompt": prompt, "lora": lora_weights}
|
|