Spaces:
Running
on
Zero
Running
on
Zero
update
Browse files- hf_demo.py +5 -3
hf_demo.py
CHANGED
@@ -10,7 +10,9 @@ from PIL import Image
|
|
10 |
|
11 |
|
12 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
13 |
-
|
|
|
|
|
14 |
|
15 |
from inference import get_lora_network, inference, get_validation_dataloader
|
16 |
lora_map = {
|
@@ -33,7 +35,7 @@ lora_map = {
|
|
33 |
"Henri Matisse": "henri-matisse_subset1",
|
34 |
"Joan Miro": "joan-miro_subset2",
|
35 |
}
|
36 |
-
|
37 |
def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0, steps=50, guidance_scale=7.5):
|
38 |
adapter_path = lora_map[adapter_choice]
|
39 |
if adapter_path not in [None, "None"]:
|
@@ -48,7 +50,7 @@ def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0,
|
|
48 |
start_noise=-1, show=False, style_prompt="sks art", no_load=True,
|
49 |
from_scratch=True, device=device)[0][1.0]
|
50 |
return pred_images
|
51 |
-
|
52 |
def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):
|
53 |
infer_loader = get_validation_dataloader(prompts, image)
|
54 |
network = get_lora_network(pipe.unet, adapter_path,"all_up")["network"]
|
|
|
10 |
|
11 |
|
12 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
13 |
+
dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16
|
14 |
+
pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1",
|
15 |
+
dtype=dtype).to(device)
|
16 |
|
17 |
from inference import get_lora_network, inference, get_validation_dataloader
|
18 |
lora_map = {
|
|
|
35 |
"Henri Matisse": "henri-matisse_subset1",
|
36 |
"Joan Miro": "joan-miro_subset2",
|
37 |
}
|
38 |
+
@spaces.GPU
|
39 |
def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0, steps=50, guidance_scale=7.5):
|
40 |
adapter_path = lora_map[adapter_choice]
|
41 |
if adapter_path not in [None, "None"]:
|
|
|
50 |
start_noise=-1, show=False, style_prompt="sks art", no_load=True,
|
51 |
from_scratch=True, device=device)[0][1.0]
|
52 |
return pred_images
|
53 |
+
@spaces.GPU
|
54 |
def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):
|
55 |
infer_loader = get_validation_dataloader(prompts, image)
|
56 |
network = get_lora_network(pipe.unet, adapter_path,"all_up")["network"]
|