hshetty commited on
Commit
e518a26
·
1 Parent(s): e6b42d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -5,14 +5,14 @@ import base64
5
  import torch
6
  import os
7
 
8
-
9
  device = "cuda"
10
  generator = torch.Generator(device=device)
11
 
12
  seed = 496012807434005 #generator.seed()
13
  generator = generator.manual_seed(seed)
14
- HF_TOKEN = os.getenv('HF_TOKEN')
15
- hf_writer =gr.HuggingFaceDatasetSaver(HF_TOKEN, "dst-movie-poster-demo")
16
 
17
 
18
  def improve_image(img):
@@ -23,7 +23,7 @@ def improve_image(img):
23
  resp_img = gr.processing_utils.decode_base64_to_image((resp_obj.json())['data'][0])
24
  return resp_img
25
 
26
- pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4",use_auth_token= HF_TOKEN)
27
  pipe = pipe.to("cuda")
28
 
29
  def generate(celebrity, setting):
@@ -41,6 +41,6 @@ gr.Interface(
41
  outputs = gr.Image(type='pill'),
42
  allow_flagging="manual",
43
  flagging_options = ['Incorrect movie poster','Incorrect Actor','Other Problem'],
44
- #flagging_callback=hf_writer,
45
  flagging_dir='flagged_data'
46
  ).launch()
 
5
  import torch
6
  import os
7
 
8
+ auth_token = os.environ.get("auth_token")
9
  device = "cuda"
10
  generator = torch.Generator(device=device)
11
 
12
  seed = 496012807434005 #generator.seed()
13
  generator = generator.manual_seed(seed)
14
+ #HF_TOKEN = os.getenv('HF_TOKEN')
15
+ hf_writer =gr.HuggingFaceDatasetSaver(auth_token, "dst-movie-poster-demo")
16
 
17
 
18
  def improve_image(img):
 
23
  resp_img = gr.processing_utils.decode_base64_to_image((resp_obj.json())['data'][0])
24
  return resp_img
25
 
26
+ pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4",use_auth_token= auth_token)
27
  pipe = pipe.to("cuda")
28
 
29
  def generate(celebrity, setting):
 
41
  outputs = gr.Image(type='pill'),
42
  allow_flagging="manual",
43
  flagging_options = ['Incorrect movie poster','Incorrect Actor','Other Problem'],
44
+ flagging_callback=hf_writer,
45
  flagging_dir='flagged_data'
46
  ).launch()