Ryukijano commited on
Commit
eeb1f81
·
verified ·
1 Parent(s): e0dcebd

Update custom_pipeline.py

Browse files
Files changed (1) hide show
  1. custom_pipeline.py +4 -1
custom_pipeline.py CHANGED
@@ -19,7 +19,10 @@ BATCH_SIZE = 4 # Optimal batch size for A100
19
 
20
  @torch.jit.script
21
  def calculate_timestep_shift(image_seq_len: int) -> float:
22
- """Optimized timestep shift calculation using TorchScript"""
 
 
 
23
  m = (MAX_SHIFT - BASE_SHIFT) / (MAX_SEQ_LEN - BASE_SEQ_LEN)
24
  b = BASE_SHIFT - m * BASE_SEQ_LEN
25
  return image_seq_len * m + b
 
19
 
20
  @torch.jit.script
21
  def calculate_timestep_shift(image_seq_len: int) -> float:
22
+ BASE_SEQ_LEN = 256
23
+ MAX_SEQ_LEN = 4096
24
+ BASE_SHIFT = 0.5
25
+ MAX_SHIFT = 1.2
26
  m = (MAX_SHIFT - BASE_SHIFT) / (MAX_SEQ_LEN - BASE_SEQ_LEN)
27
  b = BASE_SHIFT - m * BASE_SEQ_LEN
28
  return image_seq_len * m + b