mpatel57 commited on
Commit
a822226
·
verified ·
1 Parent(s): ba1943f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -27,7 +27,11 @@ from src.pipelines.pipeline_kandinsky_subject_prior import KandinskyPriorPipelin
27
  from diffusers import DiffusionPipeline
28
  from PIL import Image
29
 
30
- __device__ = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
31
 
32
  class Model:
33
  def __init__(self):
@@ -37,7 +41,7 @@ class Model:
37
  CLIPTextModelWithProjection.from_pretrained(
38
  "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
39
  projection_dim=1280,
40
- torch_dtype=torch.float16,
41
  )
42
  .eval()
43
  .requires_grad_(False)
@@ -49,17 +53,17 @@ class Model:
49
 
50
  prior = PriorTransformer.from_pretrained(
51
  "ECLIPSE-Community/Lambda-ECLIPSE-Prior-v1.0",
52
- torch_dtype=torch.float16,
53
  )
54
 
55
  self.pipe_prior = KandinskyPriorPipeline.from_pretrained(
56
  "kandinsky-community/kandinsky-2-2-prior",
57
  prior=prior,
58
- torch_dtype=torch.float16,
59
  ).to(self.device)
60
 
61
  self.pipe = DiffusionPipeline.from_pretrained(
62
- "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
63
  ).to(self.device)
64
 
65
  def inference(self, raw_data):
 
27
  from diffusers import DiffusionPipeline
28
  from PIL import Image
29
 
30
+ __device__ = "cpu"
31
+ __dtype__ = torch.float32
32
+ if torch.cuda.is_available():
33
+ __device__ = "cuda"
34
+ __dtype__ = torch.float16
35
 
36
  class Model:
37
  def __init__(self):
 
41
  CLIPTextModelWithProjection.from_pretrained(
42
  "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
43
  projection_dim=1280,
44
+ torch_dtype=__dtype__,
45
  )
46
  .eval()
47
  .requires_grad_(False)
 
53
 
54
  prior = PriorTransformer.from_pretrained(
55
  "ECLIPSE-Community/Lambda-ECLIPSE-Prior-v1.0",
56
+ torch_dtype=__dtype__,
57
  )
58
 
59
  self.pipe_prior = KandinskyPriorPipeline.from_pretrained(
60
  "kandinsky-community/kandinsky-2-2-prior",
61
  prior=prior,
62
+ torch_dtype=__dtype__,
63
  ).to(self.device)
64
 
65
  self.pipe = DiffusionPipeline.from_pretrained(
66
+ "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=__dtype__
67
  ).to(self.device)
68
 
69
  def inference(self, raw_data):