Makhinur commited on
Commit
4f2123c
·
verified ·
1 Parent(s): 4bb7c9d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +14 -24
main.py CHANGED
@@ -1,8 +1,8 @@
1
- from fastapi import FastAPI, File, UploadFile, Form
2
  from fastapi.responses import JSONResponse
3
  from gradio_client import Client, handle_file
4
- import shutil
5
  import base64
 
6
  import os
7
 
8
  app = FastAPI()
@@ -14,34 +14,24 @@ client = Client("Makhinur/Bringingoldphotoliveagain", hf_token=HF_TOKEN)
14
 
15
  @app.post("/upload/")
16
  async def upload_image(file: UploadFile = File(...)):
17
- if not file:
18
- raise HTTPException(status_code=400, detail="No file uploaded")
19
-
20
- # Save the uploaded file to a temporary location
21
- temp_file_path = f"temp_{file.filename}"
22
- with open(temp_file_path, "wb") as buffer:
23
- shutil.copyfileobj(file.file, buffer)
24
-
25
  try:
26
- # Use Gradio client to process the image
27
- result = client.predict(
28
- img=handle_file(temp_file_path),
 
 
 
 
29
  api_name="/predict"
30
  )
31
 
32
- # Encode the processed image as base64
33
- with open(result[0], "rb") as image_file:
34
- encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
35
-
36
- # Clean up the temporary file
37
- os.remove(temp_file_path)
38
 
39
- # Return the processed image as a base64 string
40
- return JSONResponse(content={"sketch_image_base64": f"data:image/png;base64,{encoded_image}"})
41
 
42
  except Exception as e:
43
- # Clean up the temporary file in case of an error
44
- if os.path.exists(temp_file_path):
45
- os.remove(temp_file_path)
46
  raise HTTPException(status_code=500, detail=str(e))
47
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
  from fastapi.responses import JSONResponse
3
  from gradio_client import Client, handle_file
 
4
  import base64
5
+
6
  import os
7
 
8
  app = FastAPI()
 
14
 
15
  @app.post("/upload/")
16
  async def upload_image(file: UploadFile = File(...)):
 
 
 
 
 
 
 
 
17
  try:
18
+ # Save the uploaded file temporarily
19
+ with open(file.filename, "wb") as buffer:
20
+ buffer.write(await file.read())
21
+
22
+ # Use the Gradio client to process the image
23
+ result = gradio_client.predict(
24
+ img=handle_file(file.filename),
25
  api_name="/predict"
26
  )
27
 
28
+ # Read the output image and encode it in base64
29
+ with open(result[0], "rb") as result_file:
30
+ encoded_string = base64.b64encode(result_file.read()).decode('utf-8')
 
 
 
31
 
32
+ # Create a JSON response with the base64 encoded image
33
+ return JSONResponse(content={"sketch_image_base64": f"data:image/jpeg;base64,{encoded_string}"})
34
 
35
  except Exception as e:
 
 
 
36
  raise HTTPException(status_code=500, detail=str(e))
37