tsqn committed
Commit 8d37981 · verified · 1 Parent(s): c6ea03d

Update app.py

Files changed (1):
  1. app.py +37 -34
app.py CHANGED
@@ -1,4 +1,5 @@
 import spaces
+
 """
 Copyright NewGenAI
 Code can't be included in commercial app used for monetary gain. No derivative code allowed.
@@ -11,15 +12,14 @@ import time
 from datetime import datetime
 import os

-import torch
 from diffusers.utils import export_to_video
-from diffusers import LTXImageToVideoPipeline
+from diffusers import LTXPipeline
 from transformers import T5EncoderModel, T5Tokenizer
 from pathlib import Path
 from datetime import datetime
 from huggingface_hub import hf_hub_download

-STATE_FILE = "LTX091_I2V_state.json"
+STATE_FILE = "LTX091_state.json"
 queue = []

 def load_state():
@@ -28,15 +28,16 @@ def load_state():
             return json.load(file)
     return {}

+# Function to save the current state
 def save_state(state):
     with open(STATE_FILE, "w") as file:
         json.dump(state, file)

+# Load initial state
 initial_state = load_state()

-def add_to_queue(image, prompt, negative_prompt, height, width, num_frames, num_inference_steps, fps, seed):
+def add_to_queue(prompt, negative_prompt, height, width, num_frames, num_inference_steps, fps, seed):
     task = {
-        "image": image,
         "prompt": prompt,
         "negative_prompt": negative_prompt,
         "height": height,
@@ -59,7 +60,7 @@ def process_queue():

     for i, task in enumerate(queue):
         generate_video(**task)
-        time.sleep(1)
+        time.sleep(1)  # Simulate processing time

     queue.clear()
     return "All tasks in the queue have been processed."
@@ -78,7 +79,6 @@ def save_ui_state(prompt, negative_prompt, height, width, num_frames, num_infere
     save_state(state)
     return "State saved!"

-# [Previous model loading code remains the same...]
 repo_id = "a-r-r-o-w/LTX-Video-0.9.1-diffusers"
 base_path = repo_id
 files_to_download = [
@@ -102,9 +102,11 @@ files_to_download = [
 os.makedirs(base_path, exist_ok=True)
 for file_path in files_to_download:
     try:
+        # Create the full directory path for this file
         full_dir = os.path.join(base_path, os.path.dirname(file_path))
         os.makedirs(full_dir, exist_ok=True)

+        # Download the file
         downloaded_path = hf_hub_download(
             repo_id=repo_id,
             filename=file_path,
@@ -117,20 +119,23 @@ for file_path in files_to_download:
         print(f"Error downloading {file_path}: {str(e)}")
         raise

+# Download model from different repo
 try:
+    # Create the full directory path for this file
     full_dir = os.path.join(base_path, os.path.dirname(file_path))
     os.makedirs(full_dir, exist_ok=True)

+    # Download the file
     downloaded_path = hf_hub_download(
         repo_id="Lightricks/LTX-Video",
         filename="ltx-video-2b-v0.9.1.safetensors",
         local_dir=repo_id,
     )
-    print(f"Successfully downloaded: ltx-video-2b-v0.9.1.safetensors")
 except Exception as e:
     print(f"Error downloading 0.9.1 model: {str(e)}")
     raise

+
 single_file_url = repo_id+"/ltx-video-2b-v0.9.1.safetensors"
 text_encoder = T5EncoderModel.from_pretrained(
     repo_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
@@ -138,7 +143,7 @@ text_encoder = T5EncoderModel.from_pretrained(
 tokenizer = T5Tokenizer.from_pretrained(
     repo_id, subfolder="tokenizer", torch_dtype=torch.bfloat16
 )
-pipe = LTXImageToVideoPipeline.from_single_file(
+pipe = LTXPipeline.from_single_file(
     single_file_url,
     text_encoder=text_encoder,
     tokenizer=tokenizer,
@@ -146,42 +151,42 @@ pipe = LTXImageToVideoPipeline.from_single_file(
 )
 pipe.enable_model_cpu_offload()

-@spaces.GPU()
-def generate_video(image, prompt, negative_prompt, height, width, num_frames, num_inference_steps, fps, seed):
+@spaces.GPU(duration=120)
+@torch.inference_mode()
+def generate_video(prompt, negative_prompt, height, width, num_frames, num_inference_steps, fps, seed):
+    # Randomize seed if seed is 0
     if seed == 0:
         seed = random.randint(0, 999999)
-
-    torch.cuda.synchronize()
-    torch.cuda.empty_cache()
-    with torch.inference_mode():
-        video = pipe(
-            image=image,
-            prompt=prompt,
-            negative_prompt=negative_prompt,
-            width=width,
-            height=height,
-            num_frames=num_frames,
-            num_inference_steps=num_inference_steps,
-            generator=torch.Generator(device='cuda').manual_seed(seed),
-        ).frames[0]
-
+
+    # Generating the video <Does not support seed :( >
+    video = pipe(
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        width=width,
+        height=height,
+        num_frames=num_frames,
+        num_inference_steps=num_inference_steps,
+        generator=torch.Generator(device='cuda').manual_seed(seed),
+    ).frames[0]
+
+    # Create output filename based on prompt and timestamp
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     filename = f"{prompt[:10]}_{timestamp}.mp4"

-    os.makedirs("output_LTX091_i2v", exist_ok=True)
-    output_path = f"./output_LTX091_i2v/{filename}"
+    # Save the video to the output folder
+    os.makedirs("output_LTX091", exist_ok=True)
+    output_path = f"./output_LTX091/{filename}"
     export_to_video(video, output_path, fps=fps)

     return output_path

+# Gradio UI setup
 def randomize_seed():
     return random.randint(0, 999999)

 with gr.Blocks() as demo:
     with gr.Tabs():
         with gr.Tab("Generate Video"):
-            with gr.Row():
-                input_image = gr.Image(label="Input Image", type="pil")
             with gr.Row():
                 prompt = gr.Textbox(label="Prompt", lines=3, value=initial_state.get("prompt", "A dramatic view of the pyramids at Giza during sunset."))
                 negative_prompt = gr.Textbox(label="Negative Prompt", lines=3, value=initial_state.get("negative_prompt", "worst quality, blurry, distorted"))
@@ -203,7 +208,7 @@ with gr.Blocks() as demo:
             random_seed_button.click(lambda: random.randint(0, 999999), outputs=seed)
             generate_button.click(
                 generate_video,
-                inputs=[input_image, prompt, negative_prompt, height, width, num_frames, num_inference_steps, fps, seed],
+                inputs=[prompt, negative_prompt, height, width, num_frames, num_inference_steps, fps, seed],
                 outputs=output_video
             )
             save_state_button.click(
@@ -213,8 +218,6 @@ with gr.Blocks() as demo:
             )

         with gr.Tab("Batch Processing"):
-            with gr.Row():
-                batch_input_image = gr.Image(label="Input Image", type="pil")
             with gr.Row():
                 batch_prompt = gr.Textbox(label="Prompt", lines=3, value="A batch of videos depicting different landscapes.")
                 batch_negative_prompt = gr.Textbox(label="Negative Prompt", lines=3, value="low quality, inconsistent, jittery")
@@ -238,7 +241,7 @@ with gr.Blocks() as demo:
             random_seed_batch_button.click(lambda: random.randint(0, 999999), outputs=batch_seed)
             add_to_queue_button.click(
                 add_to_queue,
-                inputs=[batch_input_image, batch_prompt, batch_negative_prompt, batch_height, batch_width, batch_num_frames, batch_num_inference_steps, batch_fps, batch_seed],
+                inputs=[batch_prompt, batch_negative_prompt, batch_height, batch_width, batch_num_frames, batch_num_inference_steps, batch_fps, batch_seed],
                 outputs=queue_status
             )
             clear_queue_button.click(clear_queue, outputs=queue_status)
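
For reference, the text-to-video path that app.py switches to in this commit reduces to the short sketch below. This is not part of the commit: it assumes a diffusers release that ships LTXPipeline (0.32 or newer) and a CUDA-capable GPU, and the resolution, frame count, seed, and output filename are illustrative values only.

# Minimal text-to-video sketch mirroring the updated generate_video() path.
# Assumes diffusers >= 0.32 (LTXPipeline) and a CUDA GPU; values are illustrative.
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # same memory-saving setting as the app

video = pipe(
    prompt="A dramatic view of the pyramids at Giza during sunset.",
    negative_prompt="worst quality, blurry, distorted",
    width=704,
    height=480,
    num_frames=161,
    num_inference_steps=50,
    generator=torch.Generator(device="cuda").manual_seed(42),
).frames[0]

export_to_video(video, "sample.mp4", fps=24)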
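
The batch tab's flow can be sketched the same way. This is a hypothetical usage example, assuming the helpers defined in app.py (add_to_queue, process_queue, generate_video) are in scope; the real app drives them through the Gradio buttons, and the parameter values here are made up.

# Hypothetical queue usage; assumes add_to_queue()/process_queue() from app.py are in scope.
add_to_queue(
    prompt="A foggy forest at dawn",
    negative_prompt="low quality, inconsistent, jittery",
    height=480, width=704,
    num_frames=97, num_inference_steps=40,
    fps=24, seed=0,      # seed == 0 -> generate_video() picks a random seed
)
add_to_queue(
    prompt="Waves crashing on a rocky shore",
    negative_prompt="low quality, inconsistent, jittery",
    height=480, width=704,
    num_frames=97, num_inference_steps=40,
    fps=24, seed=1234,
)
print(process_queue())  # runs generate_video(**task) for each queued task, then clears the queue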