DonImages commited on
Commit
0ab8732
·
verified ·
1 Parent(s): 5dbbf26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -20
app.py CHANGED
@@ -1,40 +1,46 @@
1
  from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
 
3
  import base64
4
  import os
5
 
6
- app = FastAPI()
7
-
8
- # Middleware to handle CORS (optional, but useful if Testing4 calls this API)
9
- app.add_middleware(
10
- CORSMiddleware,
11
- allow_origins=["*"], # Adjust as needed for security
12
- allow_credentials=True,
13
- allow_methods=["*"],
14
- allow_headers=["*"],
15
- )
16
-
17
- # Load LoRA weights globally
18
- lora_weights = None
19
-
20
- @app.on_event("startup")
21
- async def startup_event():
22
- global lora_weights
23
  lora_path = "./lora_file.pth"
24
  if os.path.exists(lora_path):
25
  with open(lora_path, "rb") as f:
26
  # Base64 encode the LoRA weights for easy JSON transmission
27
- lora_weights = base64.b64encode(f.read()).decode("utf-8")
28
  print("LoRA weights loaded and preprocessed successfully.")
29
  else:
30
  print("LoRA file not found during startup.")
31
- raise HTTPException(status_code=500, detail="LoRA file not found.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  @app.post("/modify-prompt")
34
  async def modify_prompt(prompt: str):
35
- global lora_weights
36
  if lora_weights is None:
37
  raise HTTPException(status_code=500, detail="LoRA weights not loaded.")
 
38
  # Combine prompt with preprocessed LoRA data
39
  extended_prompt = {
40
  "prompt": prompt,
 
1
  from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from contextlib import asynccontextmanager
4
  import base64
5
  import os
6
 
7
+ # Create the FastAPI app
8
+ @asynccontextmanager
9
+ async def lifespan(app: FastAPI):
10
+ # Load LoRA weights during app startup
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  lora_path = "./lora_file.pth"
12
  if os.path.exists(lora_path):
13
  with open(lora_path, "rb") as f:
14
  # Base64 encode the LoRA weights for easy JSON transmission
15
+ app.state.lora_weights = base64.b64encode(f.read()).decode("utf-8")
16
  print("LoRA weights loaded and preprocessed successfully.")
17
  else:
18
  print("LoRA file not found during startup.")
19
+ raise RuntimeError("LoRA file not found.")
20
+
21
+ # Yield control to the application
22
+ yield
23
+
24
+ # Perform any cleanup (if needed) here
25
+ print("Application shutting down.")
26
+
27
+ app = FastAPI(lifespan=lifespan)
28
+
29
+ # Middleware to handle CORS (optional, but useful for cross-origin requests)
30
+ app.add_middleware(
31
+ CORSMiddleware,
32
+ allow_origins=["*"], # Adjust for security
33
+ allow_credentials=True,
34
+ allow_methods=["*"],
35
+ allow_headers=["*"],
36
+ )
37
 
38
  @app.post("/modify-prompt")
39
  async def modify_prompt(prompt: str):
40
+ lora_weights = getattr(app.state, "lora_weights", None)
41
  if lora_weights is None:
42
  raise HTTPException(status_code=500, detail="LoRA weights not loaded.")
43
+
44
  # Combine prompt with preprocessed LoRA data
45
  extended_prompt = {
46
  "prompt": prompt,