akameswa commited on
Commit
95a9773
·
verified ·
1 Parent(s): 6854c35

Update src/util/base.py

Browse files
Files changed (1) hide show
  1. src/util/base.py +6 -5
src/util/base.py CHANGED
@@ -2,6 +2,7 @@ import io
2
  import os
3
  import torch
4
  import zipfile
 
5
  import numpy as np
6
  import gradio as gr
7
  from PIL import Image
@@ -10,7 +11,7 @@ from src.util.params import *
10
  from src.util.clip_config import *
11
  import matplotlib.pyplot as plt
12
 
13
-
14
  def get_text_embeddings(
15
  prompt,
16
  tokenizer=tokenizer,
@@ -42,7 +43,7 @@ def get_text_embeddings(
42
 
43
  return text_embeddings
44
 
45
-
46
  def generate_latents(
47
  seed,
48
  height=imageHeight,
@@ -60,7 +61,7 @@ def generate_latents(
60
 
61
  return latents
62
 
63
-
64
  def generate_modified_latents(
65
  poke,
66
  seed,
@@ -98,7 +99,7 @@ def convert_to_pil_image(image):
98
  pil_images = [Image.fromarray(image) for image in images]
99
  return pil_images[0]
100
 
101
-
102
  def generate_images(
103
  latents,
104
  text_embeddings,
@@ -147,7 +148,7 @@ def generate_images(
147
 
148
  return images
149
 
150
-
151
  def get_word_embeddings(
152
  prompt, tokenizer=tokenizer, text_encoder=text_encoder, torch_device=torch_device
153
  ):
 
2
  import os
3
  import torch
4
  import zipfile
5
+ import spaces
6
  import numpy as np
7
  import gradio as gr
8
  from PIL import Image
 
11
  from src.util.clip_config import *
12
  import matplotlib.pyplot as plt
13
 
14
+ @spaces.GPU(enable_queue=True)
15
  def get_text_embeddings(
16
  prompt,
17
  tokenizer=tokenizer,
 
43
 
44
  return text_embeddings
45
 
46
+ @spaces.GPU(enable_queue=True)
47
  def generate_latents(
48
  seed,
49
  height=imageHeight,
 
61
 
62
  return latents
63
 
64
+ @spaces.GPU(enable_queue=True)
65
  def generate_modified_latents(
66
  poke,
67
  seed,
 
99
  pil_images = [Image.fromarray(image) for image in images]
100
  return pil_images[0]
101
 
102
+ @spaces.GPU(enable_queue=True)
103
  def generate_images(
104
  latents,
105
  text_embeddings,
 
148
 
149
  return images
150
 
151
+ @spaces.GPU(enable_queue=True)
152
  def get_word_embeddings(
153
  prompt, tokenizer=tokenizer, text_encoder=text_encoder, torch_device=torch_device
154
  ):