Hatman commited on
Commit
566b8e0
·
verified ·
1 Parent(s): 1a187f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -25
app.py CHANGED
@@ -13,30 +13,8 @@ from torchvision import transforms
13
 
14
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
15
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
16
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(
17
- "h94/IP-Adapter",
18
- subfolder="models/image_encoder",
19
- torch_dtype=dtype
20
- )
21
- pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0",image_encoder=image_encoder, torch_dtype=dtype).to(device)
22
- pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus_sdxl_vit-h.safetensors")
23
-
24
- def prepare_image(image_path_or_url):
25
- # Load the image
26
- image = load_image(image_path_or_url)
27
-
28
- # Convert to tensor and move to correct device and dtype
29
- transform = transforms.Compose([
30
- transforms.ToTensor(),
31
- transforms.Resize((1024, 1024), interpolation=transforms.InterpolationMode.BICUBIC)
32
- ])
33
- image_tensor = transform(image).unsqueeze(0) # Add batch dimension
34
- return image_tensor.to(device=device, dtype=dtype)
35
-
36
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
37
- if randomize_seed:
38
- seed = random.randint(0, 2000)
39
- return seed
40
 
41
  @spaces.GPU(enable_queue=True)
42
  def create_image(image_pil,
@@ -65,7 +43,7 @@ def create_image(image_pil,
65
  }
66
  pipeline.set_ip_adapter_scale(scale)
67
 
68
- style_image = prepare_image(image_pil)
69
  generator = torch.Generator(device=device).manual_seed(randomize_seed_fn(seed, False))
70
 
71
  image = pipeline(
 
13
 
14
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
15
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
16
+ pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=dtype).to(device)
17
+ pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  @spaces.GPU(enable_queue=True)
20
  def create_image(image_pil,
 
43
  }
44
  pipeline.set_ip_adapter_scale(scale)
45
 
46
+ style_image = load_image(image_pil)
47
  generator = torch.Generator(device=device).manual_seed(randomize_seed_fn(seed, False))
48
 
49
  image = pipeline(