rhfeiyang committed
Commit feb4f05 · 1 Parent(s): efbbb9d
Files changed (1): hf_demo.py (+5 -3)
hf_demo.py CHANGED
@@ -10,7 +10,9 @@ from PIL import Image
 
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1",).to(device)
+dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16
+pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1",
+                                         dtype=dtype).to(device)
 
 from inference import get_lora_network, inference, get_validation_dataloader
 lora_map = {
@@ -33,7 +35,7 @@ lora_map = {
     "Henri Matisse": "henri-matisse_subset1",
     "Joan Miro": "joan-miro_subset2",
 }
-
+@spaces.GPU
 def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0, steps=50, guidance_scale=7.5):
     adapter_path = lora_map[adapter_choice]
     if adapter_path not in [None, "None"]:
@@ -48,7 +50,7 @@ def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0,
                         start_noise=-1, show=False, style_prompt="sks art", no_load=True,
                         from_scratch=True, device=device)[0][1.0]
     return pred_images
-
+@spaces.GPU
 def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):
     infer_loader = get_validation_dataloader(prompts, image)
     network = get_lora_network(pipe.unet, adapter_path,"all_up")["network"]
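
The commit loads the pipeline in bfloat16 when a GPU is available and wraps both demo entry points with @spaces.GPU, the decorator from the Hugging Face spaces package that allocates a ZeroGPU device for the duration of each call. Below is a minimal sketch of the same pattern, not the commit's code: it assumes `import spaces` exists elsewhere in hf_demo.py, the `generate` helper and its example prompt are illustrative only, and precision is passed via `torch_dtype`, the kwarg documented by diffusers.

import torch
import spaces  # Hugging Face Spaces helper; provides the ZeroGPU decorator
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
# Assumption for this sketch: bfloat16 on GPU as in the commit, float32 fallback on CPU
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

pipe = DiffusionPipeline.from_pretrained(
    "rhfeiyang/art-free-diffusion-v1", torch_dtype=dtype
).to(device)

@spaces.GPU  # requests a GPU slot for the duration of each call on ZeroGPU Spaces
def generate(prompt: str, steps: int = 50, guidance_scale: float = 7.5):
    # Hypothetical helper, not part of the commit: runs the pipeline directly
    return pipe(prompt, num_inference_steps=steps, guidance_scale=guidance_scale).images

# Illustrative usage:
# images = generate("an oil painting of a lighthouse at dusk")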