prithivMLmods committed on
Commit 1d9dc27 · verified · 1 Parent(s): c6a1ef4

Update app.py

Files changed (1)
  1. app.py +111 -104
app.py CHANGED
@@ -79,125 +79,132 @@ def progress_bar_html(label: str) -> str:
     '''
 
 @spaces.GPU
-def generate(text: str, files: list,
-             max_new_tokens: int = 1024,
-             temperature: float = 0.6,
-             top_p: float = 0.9,
-             top_k: int = 50,
-             repetition_penalty: float = 1.2):
+def generate_image(text: str, image: Image.Image,
+                   max_new_tokens: int = 1024,
+                   temperature: float = 0.6,
+                   top_p: float = 0.9,
+                   top_k: int = 50,
+                   repetition_penalty: float = 1.2):
     """
-    Generates responses using the Qwen2VL model for image and video inputs.
-    - If images are provided, performs image inference.
-    - If videos are provided, performs video inference by downsampling to frames.
+    Generates responses using the Cosmos-Reason1 model for image input.
     """
-    if not files:
-        yield "Please upload an image or video for inference."
+    if image is None:
+        yield "Please upload an image."
         return
 
-    # Determine if the files are images or videos
-    image_files = [f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif'))]
-    video_files = [f for f in files if f.lower().endswith(('.mp4', '.avi', '.mov', '.mkv'))]
+    messages = [{
+        "role": "user",
+        "content": [
+            {"type": "image", "image": image},
+            {"type": "text", "text": text},
+        ]
+    }]
+    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    inputs = processor(
+        text=[prompt_full],
+        images=[image],
+        return_tensors="pt",
+        padding=True,
+        truncation=True,
+        max_length=MAX_INPUT_TOKEN_LENGTH
+    ).to("cuda")
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+    thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
+    thread.start()
+    buffer = ""
+    yield progress_bar_html("Processing image with Cosmos-Reason1")
+    for new_text in streamer:
+        buffer += new_text
+        buffer = buffer.replace("<|im_end|>", "")
+        time.sleep(0.01)
+        yield buffer
 
-    if image_files and video_files:
-        yield "Please upload either images or videos, not both."
+@spaces.GPU
+def generate_video(text: str, video_path: str,
+                   max_new_tokens: int = 1024,
+                   temperature: float = 0.6,
+                   top_p: float = 0.9,
+                   top_k: int = 50,
+                   repetition_penalty: float = 1.2):
+    """
+    Generates responses using the Cosmos-Reason1 model for video input.
+    """
+    if video_path is None:
+        yield "Please upload a video."
         return
 
-    if image_files:
-        # Image inference
-        images = [load_image(image) for image in image_files]
-        messages = [{
-            "role": "user",
-            "content": [
-                *[{"type": "image", "image": image} for image in images],
-                {"type": "text", "text": text},
-            ]
-        }]
-        prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = processor(
-            text=[prompt_full],
-            images=images,
-            return_tensors="pt",
-            padding=True,
-            truncation=True,
-            max_length=MAX_INPUT_TOKEN_LENGTH
-        ).to("cuda")
-        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-        thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
-        thread.start()
-        buffer = ""
-        yield progress_bar_html("Processing images with cosmos-reasoning")
-        for new_text in streamer:
-            buffer += new_text
-            buffer = buffer.replace("<|im_end|>", "")
-            time.sleep(0.01)
-            yield buffer
-    elif video_files:
-        # Video inference
-        video_path = video_files[0]  # Assuming only one video is uploaded
-        frames = downsample_video(video_path)
-        messages = [
-            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
-            {"role": "user", "content": [{"type": "text", "text": text}]}
-        ]
-        # Append each frame with its timestamp.
-        for frame in frames:
-            image, timestamp = frame
-            image_path = f"video_frame_{uuid.uuid4().hex}.png"
-            image.save(image_path)
-            messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
-            messages[1]["content"].append({"type": "image", "url": image_path})
-        inputs = processor.apply_chat_template(
-            messages,
-            tokenize=True,
-            add_generation_prompt=True,
-            return_dict=True,
-            return_tensors="pt",
-            truncation=True,
-            max_length=MAX_INPUT_TOKEN_LENGTH
-        ).to("cuda")
-        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {
-            **inputs,
-            "streamer": streamer,
-            "max_new_tokens": max_new_tokens,
-            "do_sample": True,
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "repetition_penalty": repetition_penalty,
-        }
-        thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
-        thread.start()
-        buffer = ""
-        yield progress_bar_html("Processing video with cosmos-reasoning")
-        for new_text in streamer:
-            buffer += new_text
-            buffer = buffer.replace("<|im_end|>", "")
-            time.sleep(0.01)
-            yield buffer
-    else:
-        yield "Unsupported file type. Please upload images or videos."
+    frames = downsample_video(video_path)
+    messages = [
+        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+        {"role": "user", "content": [{"type": "text", "text": text}]}
+    ]
+    # Append each frame with its timestamp.
+    for frame in frames:
+        image, timestamp = frame
+        messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
+        messages[1]["content"].append({"type": "image", "image": image})
+    inputs = processor.apply_chat_template(
+        messages,
+        tokenize=True,
+        add_generation_prompt=True,
+        return_dict=True,
+        return_tensors="pt",
+        truncation=True,
+        max_length=MAX_INPUT_TOKEN_LENGTH
+    ).to("cuda")
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+    generation_kwargs = {
+        **inputs,
+        "streamer": streamer,
+        "max_new_tokens": max_new_tokens,
+        "do_sample": True,
+        "temperature": temperature,
+        "top_p": top_p,
+        "top_k": top_k,
+        "repetition_penalty": repetition_penalty,
+    }
+    thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
+    thread.start()
+    buffer = ""
+    yield progress_bar_html("Processing video with Cosmos-Reason1")
+    for new_text in streamer:
+        buffer += new_text
+        buffer = buffer.replace("<|im_end|>", "")
+        time.sleep(0.01)
+        yield buffer
 
 # Create the Gradio Interface
 with gr.Blocks() as demo:
-    gr.Markdown("# **cosmos-reason1 by nvidia**")
+    gr.Markdown("# **Cosmos-Reason1 by NVIDIA**")
     with gr.Row():
         with gr.Column():
-            text_input = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
-            file_input = gr.File(label="Upload Image or Video", file_types=["image", "video"], file_count="multiple")
-            max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
-            temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
-            top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
-            top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
-            repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
-            submit_btn = gr.Button("Submit")
+            with gr.Tabs():
+                with gr.TabItem("Image Inference"):
+                    image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
+                    image_upload = gr.Image(type="pil", label="Upload Image")
+                    image_submit = gr.Button("Submit")
+                with gr.TabItem("Video Inference"):
+                    video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
+                    video_upload = gr.Video(label="Upload Video")
+                    video_submit = gr.Button("Submit")
+            with gr.Accordion("Advanced options", open=False):
+                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
+                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
+                top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
+                top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
+                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
         with gr.Column():
             output = gr.Textbox(label="Output", interactive=False)
 
-    submit_btn.click(
-        fn=generate,
-        inputs=[text_input, file_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+    image_submit.click(
+        fn=generate_image,
+        inputs=[image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+        outputs=output
+    )
+    video_submit.click(
+        fn=generate_video,
+        inputs=[video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
         outputs=output
     )
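
Note: the new generate_video path relies on a downsample_video helper defined earlier in app.py, outside this hunk. From the call site it must return a list of (PIL.Image, timestamp) pairs. A minimal sketch of what such a helper could look like, assuming OpenCV and evenly spaced frame sampling (the function body and the num_frames default here are illustrative, not the file's actual implementation):

import cv2
from PIL import Image

def downsample_video(video_path, num_frames=10):
    # Hypothetical sketch: the real helper lives elsewhere in app.py.
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS) or 30.0
    frames = []
    # Pick num_frames indices spaced evenly across the whole clip.
    for i in range(num_frames):
        idx = int(i * (total_frames - 1) / max(num_frames - 1, 1))
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ok, frame = vidcap.read()
        if not ok:
            continue
        # OpenCV decodes BGR; convert to RGB for PIL, and keep the
        # timestamp in seconds so each frame can be labeled "Frame {t}:".
        image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frames.append((image, round(idx / fps, 2)))
    vidcap.release()
    return frames

Passing the PIL frames inline via {"type": "image", "image": image}, as the added code does, also avoids the removed code's pattern of saving each frame to a video_frame_*.png file and referencing it by path.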