Hatman commited on
Commit
f37caca
·
verified ·
1 Parent(s): ea7ea18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -1
app.py CHANGED
@@ -15,6 +15,18 @@ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
15
  pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=dtype, device_map="balanced")
16
  pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
17
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
19
  if randomize_seed:
20
  seed = random.randint(0, 2000)
@@ -47,7 +59,7 @@ def create_image(image_pil,
47
  }
48
  pipeline.set_ip_adapter_scale(scale)
49
 
50
- style_image = load_image(image_pil)
51
  generator = torch.Generator(device=device).manual_seed(randomize_seed_fn(seed, False))
52
 
53
  image = pipeline(
 
15
  pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=dtype, device_map="balanced")
16
  pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
17
 
18
+ def prepare_image(image_path_or_url):
19
+ # Load the image
20
+ image = load_image(image_path_or_url)
21
+
22
+ # Convert to tensor and move to correct device and dtype
23
+ transform = transforms.Compose([
24
+ transforms.ToTensor(),
25
+ transforms.Resize((1024, 1024), interpolation=transforms.InterpolationMode.LANCZOS)
26
+ ])
27
+ image_tensor = transform(image).unsqueeze(0) # Add batch dimension
28
+ return image_tensor.to(device=device, dtype=dtype)
29
+
30
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
31
  if randomize_seed:
32
  seed = random.randint(0, 2000)
 
59
  }
60
  pipeline.set_ip_adapter_scale(scale)
61
 
62
+ style_image = prepare_image(image_pil)
63
  generator = torch.Generator(device=device).manual_seed(randomize_seed_fn(seed, False))
64
 
65
  image = pipeline(