prithivMLmods commited on
Commit
ec8d7fa
·
verified ·
1 Parent(s): 40afddd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -170
app.py CHANGED
@@ -1,181 +1,139 @@
1
  import os
2
- import random
3
  import uuid
 
 
 
 
4
  import gradio as gr
 
 
5
  import numpy as np
6
  from PIL import Image
7
- import torch
8
- from diffusers import DiffusionPipeline
9
- import spaces
10
-
11
- # Setup
12
- device = "cuda" if torch.cuda.is_available() else "cpu"
13
- model_repo_id = "stabilityai/stable-diffusion-3.5-large-turbo"
14
- torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
15
-
16
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
17
- pipe = pipe.to(device)
18
- pipe.load_lora_weights("strangerzonehf/SD3.5-Turbo-Portrait-LoRA", weight_name="SD3.5-Turbo-Portrait.safetensors")
19
- pipe.fuse_lora(lora_scale=1.0)
20
-
21
- MAX_SEED = np.iinfo(np.int32).max
22
- MAX_IMAGE_SIZE = 1024
23
-
24
- # Style presets
25
- style_list = [
26
- {
27
- "name": "3840 x 2160",
28
- "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
29
- "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
30
- },
31
- {
32
- "name": "2560 x 1440",
33
- "prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
34
- "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
35
- },
36
- {
37
- "name": "HD+",
38
- "prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
39
- "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
40
- },
41
- {
42
- "name": "Style Zero",
43
- "prompt": "{prompt}",
44
- "negative_prompt": "",
45
- },
46
- ]
47
-
48
- STYLE_NAMES = [s["name"] for s in style_list]
49
-
50
- def randomize_seed_fn(seed, randomize):
51
- return random.randint(0, MAX_SEED) if randomize else seed
52
-
53
- def save_image(img):
54
- filename = str(uuid.uuid4()) + ".png"
55
- img.save(filename)
56
- return filename
57
 
58
  @spaces.GPU
59
- def generate_images(
60
- prompt,
61
- style,
62
- negative_prompt,
63
- seed,
64
- randomize_seed,
65
- width,
66
- height,
67
- guidance_scale,
68
- num_inference_steps,
69
- num_images,
70
- progress=gr.Progress(track_tqdm=True)
71
- ):
72
- seed = randomize_seed_fn(seed, randomize_seed)
73
- generator = torch.Generator(device=device).manual_seed(seed)
74
-
75
- selected_style = next(s for s in style_list if s["name"] == style)
76
- styled_prompt = selected_style["prompt"].format(prompt=prompt)
77
- styled_negative_prompt = selected_style["negative_prompt"] if not negative_prompt else negative_prompt
78
-
79
- images = []
80
- for _ in range(num_images):
81
- image = pipe(
82
- prompt=styled_prompt,
83
- negative_prompt=styled_negative_prompt,
84
- width=width,
85
- height=height,
86
- guidance_scale=guidance_scale,
87
- num_inference_steps=num_inference_steps,
88
- generator=generator
89
- ).images[0]
90
- images.append(image)
91
-
92
- image_paths = [save_image(img) for img in images]
93
- return image_paths, seed
94
-
95
- # CSS & Interface
96
- css = '''
97
- .gradio-container {
98
- max-width: 150%;
99
- margin: 0 auto;
100
- }
101
- h1 { text-align: center; }
102
- footer { visibility: hidden; }
103
- '''
104
-
105
- examples = [
106
- "portrait photo of a futuristic astronaut",
107
- "macro shot of a water droplet on a leaf",
108
- "hyper-realistic food photography of a burger",
109
- "cyberpunk city at night, rain, neon lights",
110
- "ultra detailed fantasy landscape with dragons",
111
- ]
112
-
113
- with gr.Blocks(css=css, theme="YTheme/GMaterial") as demo:
114
- gr.Markdown("## SD3.5 Turbo Portrait")
115
-
116
- with gr.Row():
117
- with gr.Column(scale=1):
118
- with gr.Row():
119
- prompt = gr.Text(
120
- show_label=False,
121
- max_lines=1,
122
- placeholder="Enter your prompt",
123
- container=False,
124
- )
125
- run_button = gr.Button("Run", scale=0, variant="primary")
126
-
127
- result_gallery = gr.Gallery(show_label=False, format="png", columns=2, object_fit="contain")
128
-
129
- with gr.Accordion("Advanced Settings", open=False):
130
- num_images = gr.Slider(
131
- label="Number of Images",
132
- minimum=1,
133
- maximum=10,
134
- value=5,
135
- step=1,
136
- )
137
- style = gr.Dropdown(label="Select Style", choices=STYLE_NAMES, value=STYLE_NAMES[0])
138
-
139
- negative_prompt = gr.Text(
140
- label="Negative Prompt",
141
- max_lines=4,
142
- lines=3,
143
- value="cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly"
144
- )
145
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
146
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
147
- with gr.Row():
148
- width = gr.Slider(label="Width", minimum=512, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
149
- height = gr.Slider(label="Height", minimum=512, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
150
- with gr.Row():
151
- guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=15, step=0.5, value=0.0)
152
- num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=30, step=1, value=4)
153
-
154
- with gr.Column(scale=1):
155
- gr.Examples(
156
- examples=examples,
157
- inputs=prompt,
158
- cache_examples=False,
159
- )
160
-
161
- gr.on(
162
- triggers=[prompt.submit, run_button.click],
163
- fn=generate_images,
164
- inputs=[
165
- prompt,
166
- style,
167
- negative_prompt,
168
- seed,
169
- randomize_seed,
170
- width,
171
- height,
172
- guidance_scale,
173
- num_inference_steps,
174
- num_images
175
  ],
176
- outputs=[result_gallery, seed],
177
- api_name="generate"
 
 
 
 
178
  )
 
179
 
180
  if __name__ == "__main__":
181
- demo.queue(max_size=40).launch(ssr_mode=False)
 
1
  import os
 
2
  import uuid
3
+ import time
4
+ import asyncio
5
+ from threading import Thread
6
+
7
  import gradio as gr
8
+ import spaces
9
+ import torch
10
  import numpy as np
11
  from PIL import Image
12
+ import cv2
13
+ import edge_tts
14
+
15
+ from transformers import (
16
+ Qwen2_5_VLForConditionalGeneration,
17
+ AutoProcessor,
18
+ TextIteratorStreamer
19
+ )
20
+
21
+ # Constants
22
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
23
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
24
+
25
+ # Load multimodal processor and model (Callisto OCR3)
26
+ MODEL_ID = "nvidia/Cosmos-Reason1-7B"
27
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
28
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
29
+ MODEL_ID,
30
+ trust_remote_code=True,
31
+ torch_dtype=torch.float16
32
+ ).to(device).eval()
33
+
34
+
35
+ def downsample_video(video_path: str, num_frames: int = 10):
36
+ vidcap = cv2.VideoCapture(video_path)
37
+ total = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
38
+ fps = vidcap.get(cv2.CAP_PROP_FPS)
39
+ idxs = np.linspace(0, total - 1, num_frames, dtype=int)
40
+ frames = []
41
+ for i in idxs:
42
+ vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
43
+ ok, img = vidcap.read()
44
+ if ok:
45
+ rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
46
+ pil = Image.fromarray(rgb)
47
+ timestamp = round(i / fps, 2)
48
+ frames.append((pil, timestamp))
49
+ vidcap.release()
50
+ return frames
51
+
52
+
53
+ def progress_bar_html(label: str) -> str:
54
+ return f'''<div style="display:flex; align-items:center;">
55
+ <span style="margin-right:10px; font-size:14px;">{label}</span>
56
+ <div style="width:110px; height:5px; background:#B0E0E6; border-radius:2px; overflow:hidden;">
57
+ <div style="width:100%; height:100%; background:#00FFFF; animation:load 1.5s linear infinite;"></div>
58
+ </div>
59
+ </div>
60
+ <style>@keyframes load{{0%{{transform:translateX(-100%)}}100%{{transform:translateX(100%)}}}}</style>'''
 
61
 
62
  @spaces.GPU
63
+ def generate(prompt: str, files: list[str] = None):
64
+ files = files or []
65
+ # Determine mode
66
+ is_video = any(f.lower().endswith(('.mp4', '.avi', '.mov')) for f in files)
67
+ is_image = any(f.lower().endswith(('.jpg', '.png', '.jpeg', '.bmp')) for f in files)
68
+
69
+ if is_video:
70
+ yield progress_bar_html("Processing video with cosmos-reason1")
71
+ video = files[0]
72
+ frames = downsample_video(video)
73
+ # Build messages
74
+ messages = [
75
+ {"role": "system", "content": [{"type":"text","text":"You are a helpful assistant."}]},
76
+ {"role": "user", "content": [{"type":"text","text": prompt}]}
77
+ ]
78
+ for img, ts in frames:
79
+ path = f"frame_{uuid.uuid4().hex}.png"
80
+ img.save(path)
81
+ messages[1]["content"].extend([
82
+ {"type":"text","text": f"Frame {ts}:"},
83
+ {"type":"image","url": path}
84
+ ])
85
+ inputs = processor.apply_chat_template(
86
+ messages, tokenize=True, add_generation_prompt=True,
87
+ return_dict=True, return_tensors="pt",
88
+ truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH
89
+ ).to(device)
90
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
91
+ Thread(target=model.generate, kwargs={**inputs, "streamer": streamer}).start()
92
+ buffer = ""
93
+ for txt in streamer:
94
+ buffer += txt.replace("<|im_end|>", "")
95
+ time.sleep(0.01)
96
+ yield buffer
97
+ return
98
+
99
+ if is_image:
100
+ yield progress_bar_html("Processing image with cosmos-reason1")
101
+ imgs = [Image.open(f) for f in files]
102
+ messages = [
103
+ {"role":"user","content":[*[{"type":"image","image":i} for i in imgs],{"type":"text","text":prompt}]}]
104
+ prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
105
+ inputs = processor(
106
+ text=[prompt_full], images=imgs,
107
+ return_tensors="pt", padding=True,
108
+ truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH
109
+ ).to(device)
110
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
111
+ Thread(target=model.generate, kwargs={**inputs, "streamer": streamer}).start()
112
+ out = ""
113
+ for txt in streamer:
114
+ out += txt.replace("<|im_end|>", "")
115
+ time.sleep(0.01)
116
+ yield out
117
+ return
118
+
119
+ # No valid media
120
+ yield "Please upload at least one image or a video for inference."
121
+
122
+
123
+ def main():
124
+ demo = gr.ChatInterface(
125
+ fn=generate,
126
+ additional_inputs=[
127
+ gr.File(label="Upload Images/Videos", file_types=["image", "video"], file_count="multiple")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  ],
129
+ description="# **cosmos-reason1 by nvidia**",
130
+ textbox=gr.Textbox(label="Prompt"),
131
+ cache_examples=False,
132
+ type="messages",
133
+ multimodal=True,
134
+ stop_btn="Stop Generation"
135
  )
136
+ demo.queue(max_size=10).launch(share=True)
137
 
138
  if __name__ == "__main__":
139
+ main()