Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -5,14 +5,14 @@ import base64
|
|
5 |
import torch
|
6 |
import os
|
7 |
|
8 |
-
|
9 |
device = "cuda"
|
10 |
generator = torch.Generator(device=device)
|
11 |
|
12 |
seed = 496012807434005 #generator.seed()
|
13 |
generator = generator.manual_seed(seed)
|
14 |
-
HF_TOKEN = os.getenv('HF_TOKEN')
|
15 |
-
hf_writer =gr.HuggingFaceDatasetSaver(
|
16 |
|
17 |
|
18 |
def improve_image(img):
|
@@ -23,7 +23,7 @@ def improve_image(img):
|
|
23 |
resp_img = gr.processing_utils.decode_base64_to_image((resp_obj.json())['data'][0])
|
24 |
return resp_img
|
25 |
|
26 |
-
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4",use_auth_token=
|
27 |
pipe = pipe.to("cuda")
|
28 |
|
29 |
def generate(celebrity, setting):
|
@@ -41,6 +41,6 @@ gr.Interface(
|
|
41 |
outputs = gr.Image(type='pill'),
|
42 |
allow_flagging="manual",
|
43 |
flagging_options = ['Incorrect movie poster','Incorrect Actor','Other Problem'],
|
44 |
-
|
45 |
flagging_dir='flagged_data'
|
46 |
).launch()
|
|
|
5 |
import torch
|
6 |
import os
|
7 |
|
8 |
+
auth_token = os.environ.get("auth_token")
|
9 |
device = "cuda"
|
10 |
generator = torch.Generator(device=device)
|
11 |
|
12 |
seed = 496012807434005 #generator.seed()
|
13 |
generator = generator.manual_seed(seed)
|
14 |
+
#HF_TOKEN = os.getenv('HF_TOKEN')
|
15 |
+
hf_writer =gr.HuggingFaceDatasetSaver(auth_token, "dst-movie-poster-demo")
|
16 |
|
17 |
|
18 |
def improve_image(img):
|
|
|
23 |
resp_img = gr.processing_utils.decode_base64_to_image((resp_obj.json())['data'][0])
|
24 |
return resp_img
|
25 |
|
26 |
+
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4",use_auth_token= auth_token)
|
27 |
pipe = pipe.to("cuda")
|
28 |
|
29 |
def generate(celebrity, setting):
|
|
|
41 |
outputs = gr.Image(type='pill'),
|
42 |
allow_flagging="manual",
|
43 |
flagging_options = ['Incorrect movie poster','Incorrect Actor','Other Problem'],
|
44 |
+
flagging_callback=hf_writer,
|
45 |
flagging_dir='flagged_data'
|
46 |
).launch()
|