Spaces:
Running
on
Zero
Running
on
Zero
Maitreya Patel
commited on
Commit
·
90ee823
1
Parent(s):
938774e
minor bug fixes
Browse files- app.py +4 -4
- requirements.txt +0 -1
app.py
CHANGED
@@ -43,7 +43,7 @@ class Ours:
|
|
43 |
CLIPTextModelWithProjection.from_pretrained(
|
44 |
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
|
45 |
projection_dim=1280,
|
46 |
-
torch_dtype=torch.
|
47 |
)
|
48 |
.eval()
|
49 |
.requires_grad_(False)
|
@@ -55,7 +55,7 @@ class Ours:
|
|
55 |
|
56 |
prior = PriorTransformer.from_pretrained(
|
57 |
"ECLIPSE-Community/ECLIPSE_KandinskyV22_Prior",
|
58 |
-
torch_dtype=torch.
|
59 |
)
|
60 |
|
61 |
self.pipe_prior = KandinskyPriorPipeline.from_pretrained(
|
@@ -63,11 +63,11 @@ class Ours:
|
|
63 |
prior=prior,
|
64 |
text_encoder=text_encoder,
|
65 |
tokenizer=tokenizer,
|
66 |
-
torch_dtype=torch.
|
67 |
).to(device)
|
68 |
|
69 |
self.pipe = DiffusionPipeline.from_pretrained(
|
70 |
-
"kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.
|
71 |
).to(device)
|
72 |
|
73 |
def inference(self, text, negative_text, steps, guidance_scale):
|
|
|
43 |
CLIPTextModelWithProjection.from_pretrained(
|
44 |
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
|
45 |
projection_dim=1280,
|
46 |
+
torch_dtype=torch.float32,
|
47 |
)
|
48 |
.eval()
|
49 |
.requires_grad_(False)
|
|
|
55 |
|
56 |
prior = PriorTransformer.from_pretrained(
|
57 |
"ECLIPSE-Community/ECLIPSE_KandinskyV22_Prior",
|
58 |
+
torch_dtype=torch.float32,
|
59 |
)
|
60 |
|
61 |
self.pipe_prior = KandinskyPriorPipeline.from_pretrained(
|
|
|
63 |
prior=prior,
|
64 |
text_encoder=text_encoder,
|
65 |
tokenizer=tokenizer,
|
66 |
+
torch_dtype=torch.float32,
|
67 |
).to(device)
|
68 |
|
69 |
self.pipe = DiffusionPipeline.from_pretrained(
|
70 |
+
"kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float32
|
71 |
).to(device)
|
72 |
|
73 |
def inference(self, text, negative_text, steps, guidance_scale):
|
requirements.txt
CHANGED
@@ -9,7 +9,6 @@ torch==2.0.0
|
|
9 |
torchvision==0.15.1
|
10 |
tqdm==4.66.1
|
11 |
transformers==4.34.1
|
12 |
-
gradio
|
13 |
jmespath
|
14 |
opencv-python
|
15 |
PyWavelet
|
|
|
9 |
torchvision==0.15.1
|
10 |
tqdm==4.66.1
|
11 |
transformers==4.34.1
|
|
|
12 |
jmespath
|
13 |
opencv-python
|
14 |
PyWavelet
|