meepmoo commited on
Commit
ddc765d
·
verified ·
1 Parent(s): 1999e9d

Update worker_runpod.py

Browse files
Files changed (1) hide show
  1. worker_runpod.py +8 -2
worker_runpod.py CHANGED
@@ -110,6 +110,8 @@ else:
110
  def generate(input):
111
  values = input["input"]
112
  prompt = values["prompt"]
 
 
113
  negative_prompt = values.get("negative_prompt", "blurry, blurred, blurry face")
114
  guidance_scale = values.get("guidance_scale", 6.0)
115
  seed = values.get("seed", 42)
@@ -123,10 +125,12 @@ def generate(input):
123
  partial_video_length = values.get("partial_video_length", None)
124
  overlap_video_length = values.get("overlap_video_length", 4)
125
  validation_image_start = values.get("validation_image_start", "asset/1.png")
 
126
  downloaded_image_path = download_image(validation_image_start)
127
  validation_image_end = values.get("validation_image_end", None)
128
 
129
  generator = torch.Generator(device="cuda").manual_seed(seed)
 
130
  if lora_path is not None:
131
  pipeline = merge_lora(pipeline, lora_path, lora_weight)
132
 
@@ -136,6 +140,8 @@ def generate(input):
136
  closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
137
  height, width = [int(x / 16) * 16 for x in closest_size]
138
  sample_size = [height, width]
 
 
139
  video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
140
  input_video, input_video_mask, clip_image = get_image_to_video_latent(downloaded_image_path, validation_image_end, video_length=video_length, sample_size=sample_size)
141
 
@@ -150,7 +156,7 @@ def generate(input):
150
  video_path = os.path.join(save_path, f"{prefix}.mp4")
151
  save_videos_grid(sample, video_path, fps=fps)
152
 
153
-
154
  hf_api = HfApi()
155
  repo_id = "meepmoo/h4h4jejdf" # Set your HF repo
156
  hf_api.upload_file(
@@ -161,7 +167,7 @@ def generate(input):
161
  repo_type="model"
162
  )
163
 
164
-
165
  result_url = f"https://huggingface.co/{repo_id}/blob/main/{prefix}.mp4"
166
  result_url = ""
167
  job_id = values.get("job_id", "default-job-id") # For RunPod job tracking
 
110
  def generate(input):
111
  values = input["input"]
112
  prompt = values["prompt"]
113
+ print("starting Generate function")
114
+ print(prompt)
115
  negative_prompt = values.get("negative_prompt", "blurry, blurred, blurry face")
116
  guidance_scale = values.get("guidance_scale", 6.0)
117
  seed = values.get("seed", 42)
 
125
  partial_video_length = values.get("partial_video_length", None)
126
  overlap_video_length = values.get("overlap_video_length", 4)
127
  validation_image_start = values.get("validation_image_start", "asset/1.png")
128
+ print(validation_image_start)
129
  downloaded_image_path = download_image(validation_image_start)
130
  validation_image_end = values.get("validation_image_end", None)
131
 
132
  generator = torch.Generator(device="cuda").manual_seed(seed)
133
+ print("Generator started")
134
  if lora_path is not None:
135
  pipeline = merge_lora(pipeline, lora_path, lora_weight)
136
 
 
140
  closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
141
  height, width = [int(x / 16) * 16 for x in closest_size]
142
  sample_size = [height, width]
143
+ print("Getting closest ratio")
144
+ print(closest_ratio)
145
  video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
146
  input_video, input_video_mask, clip_image = get_image_to_video_latent(downloaded_image_path, validation_image_end, video_length=video_length, sample_size=sample_size)
147
 
 
156
  video_path = os.path.join(save_path, f"{prefix}.mp4")
157
  save_videos_grid(sample, video_path, fps=fps)
158
 
159
+ print("Video saved to grid, uploading to huggingface")
160
  hf_api = HfApi()
161
  repo_id = "meepmoo/h4h4jejdf" # Set your HF repo
162
  hf_api.upload_file(
 
167
  repo_type="model"
168
  )
169
 
170
+ print("Video uploaded to huggingface returing output")
171
  result_url = f"https://huggingface.co/{repo_id}/blob/main/{prefix}.mp4"
172
  result_url = ""
173
  job_id = values.get("job_id", "default-job-id") # For RunPod job tracking