prithivMLmods committed
Commit 8110123 · verified · 1 Parent(s): 8aa0ea7

Update app.py

Files changed (1)
  1. app.py +227 -165
app.py CHANGED
@@ -1,172 +1,234 @@
  import gradio as gr
- from transformers.image_utils import load_image
- from threading import Thread
- import time
- import torch
  import spaces
- import cv2
  import numpy as np
  from PIL import Image
- from transformers import (
-     Qwen2VLForConditionalGeneration,
-     AutoProcessor,
-     TextIteratorStreamer,
- )
- from transformers import Qwen2_5_VLForConditionalGeneration
-
- # Helper Functions
- def progress_bar_html(label: str, primary_color: str = "#FF69B4", secondary_color: str = "#FFB6C1") -> str:
-     """
-     Returns an HTML snippet for a thin animated progress bar with a label.
-     Colors can be customized; default colors are used for Qwen2VL/Aya-Vision.
-     """
-     return f'''
-     <div style="display: flex; align-items: center;">
-         <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-         <div style="width: 110px; height: 5px; background-color: {secondary_color}; border-radius: 2px; overflow: hidden;">
-             <div style="width: 100%; height: 100%; background-color: {primary_color}; animation: loading 1.5s linear infinite;"></div>
-         </div>
-     </div>
-     <style>
-     @keyframes loading {{
-         0% {{ transform: translateX(-100%); }}
-         100% {{ transform: translateX(100%); }}
-     }}
-     </style>
-     '''
-
- def downsample_video(video_path):
-     """
-     Downsamples a video file by extracting 10 evenly spaced frames.
-     Returns a list of tuples (PIL.Image, timestamp).
-     """
-     vidcap = cv2.VideoCapture(video_path)
-     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-     fps = vidcap.get(cv2.CAP_PROP_FPS)
-     frames = []
-     if total_frames <= 0 or fps <= 0:
-         vidcap.release()
-         return frames
-     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
-     for i in frame_indices:
-         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-         success, image = vidcap.read()
-         if success:
-             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-             pil_image = Image.fromarray(image)
-             timestamp = round(i / fps, 2)
-             frames.append((pil_image, timestamp))
-     vidcap.release()
-     return frames
-
- # Model and Processor Setup
- QV_MODEL_ID = "prithivMLmods/Qwen2-VL-Ocrtest-2B-Instruct"
- qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
- qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
-     QV_MODEL_ID,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to("cuda").eval()
-
- DOCSCOPEOCR_MODEL_ID = "prithivMLmods/docscopeOCR-7B-050425-exp"
- docscopeocr_processor = AutoProcessor.from_pretrained(DOCSCOPEOCR_MODEL_ID, trust_remote_code=True)
- docscopeocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     DOCSCOPEOCR_MODEL_ID,
-     trust_remote_code=True,
-     torch_dtype=torch.bfloat16
- ).to("cuda").eval()
-
- # Main Inference Function
- @spaces.GPU
- def model_inference(message, history, use_docscopeocr):
-     text = message["text"].strip()
-     files = message.get("files", [])
-
-     if not text and not files:
-         yield "Error: Please input a text query or provide image or video files."
-         return
-
-     # Process files: images and videos
-     image_list = []
-     for idx, file in enumerate(files):
-         if file.lower().endswith((".mp4", ".avi", ".mov")):
-             frames = downsample_video(file)
-             if not frames:
-                 yield "Error: Could not extract frames from the video."
-                 return
-             for frame, timestamp in frames:
-                 label = f"Video {idx+1} Frame {timestamp}:"
-                 image_list.append((label, frame))
-         else:
-             try:
-                 img = load_image(file)
-                 label = f"Image {idx+1}:"
-                 image_list.append((label, img))
-             except Exception as e:
-                 yield f"Error loading image: {str(e)}"
-                 return
-
-     # Build content list
-     content = [{"type": "text", "text": text}]
-     for label, img in image_list:
-         content.append({"type": "text", "text": label})
-         content.append({"type": "image", "image": img})
-
-     messages = [{"role": "user", "content": content}]
-
-     # Select processor and model
-     if use_docscopeocr:
-         processor = docscopeocr_processor
-         model = docscopeocr_model
-         model_name = "DocScopeOCR"
-     else:
-         processor = qwen_processor
-         model = qwen_model
-         model_name = "Qwen2VL OCR"
-
-     prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-     all_images = [item["image"] for item in content if item["type"] == "image"]
-     inputs = processor(
-         text=[prompt_full],
-         images=all_images if all_images else None,
-         return_tensors="pt",
-         padding=True,
-     ).to("cuda")
-
-     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-     buffer = ""
-     yield progress_bar_html(f"Processing with {model_name}")
-     for new_text in streamer:
-         buffer += new_text
-         buffer = buffer.replace("<|im_end|>", "")
-         time.sleep(0.01)
-         yield buffer
-
- # Gradio Interface
  examples = [
-     [{"text": "OCR the text in the image", "files": ["example/image1.jpg"]}],
-     [{"text": "Describe the content of the image", "files": ["example/image2.jpg"]}],
-     [{"text": "Extract the image content", "files": ["example/image3.jpg"]}],
  ]

- demo = gr.ChatInterface(
-     fn=model_inference,
-     description="# **DocScope OCR `VL/OCR`**",
-     examples=examples,
-     textbox=gr.MultimodalTextbox(
-         label="Query Input",
-         file_types=["image", "video"],
-         file_count="multiple",
-         placeholder="Input your query and optionally upload image(s) or video(s). Select the model using the checkbox."
-     ),
-     stop_btn="Stop Generation",
-     multimodal=True,
-     cache_examples=False,
-     theme="bethecloud/storj_theme",
-     additional_inputs=[gr.Checkbox(label="Use DocScopeOCR", value=True, info="Check to use DocScopeOCR, uncheck to use Qwen2VL OCR")],
- )
-
- demo.launch(debug=True, ssr_mode=False)
  import gradio as gr
  import spaces
  import numpy as np
+ import random
+ from diffusers import DiffusionPipeline
+ import torch
  from PIL import Image
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model_repo_id = "stabilityai/stable-diffusion-3.5-large-turbo"
+
+ torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
+ pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
+ pipe = pipe.to(device)
+
+ pipe.load_lora_weights("strangerzonehf/SD3.5-Turbo-Portrait-LoRA", weight_name="SD3.5-Turbo-Portrait.safetensors")
+ trigger_word = "Turbo Portrait"
+ pipe.fuse_lora(lora_scale=1.0)
+
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 1024
+
+ # Define styles
+ style_list = [
+     {
+         "name": "3840 x 2160",
+         "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
+         "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
+     },
+     {
+         "name": "2560 x 1440",
+         "prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
+         "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
+     },
+     {
+         "name": "HD+",
+         "prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
+         "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
+     },
+     {
+         "name": "Style Zero",
+         "prompt": "{prompt}",
+         "negative_prompt": "",
+     },
+ ]
+
+ STYLE_NAMES = [style["name"] for style in style_list]
+ DEFAULT_STYLE_NAME = STYLE_NAMES[0]
+
+ grid_sizes = {
+     "2x1": (2, 1),
+     "1x2": (1, 2),
+     "2x2": (2, 2),
+     "2x3": (2, 3),
+     "3x2": (3, 2),
+     "1x1": (1, 1)
+ }
+
+ @spaces.GPU(duration=60)
+ def infer(
+     prompt,
+     negative_prompt="",
+     seed=42,
+     randomize_seed=False,
+     width=1024,
+     height=1024,
+     guidance_scale=7.5,
+     num_inference_steps=10,
+     style="Style Zero",
+     grid_size="1x1",
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     selected_style = next(s for s in style_list if s["name"] == style)
+     styled_prompt = selected_style["prompt"].format(prompt=prompt)
+     styled_negative_prompt = selected_style["negative_prompt"]
+
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     generator = torch.Generator().manual_seed(seed)
+
+     grid_size_x, grid_size_y = grid_sizes.get(grid_size, (1, 1))
+     num_images = grid_size_x * grid_size_y
+
+     options = {
+         "prompt": styled_prompt,
+         "negative_prompt": styled_negative_prompt,
+         "guidance_scale": guidance_scale,
+         "num_inference_steps": num_inference_steps,
+         "width": width,
+         "height": height,
+         "generator": generator,
+         "num_images_per_prompt": num_images,
+     }
+
+     torch.cuda.empty_cache()  # Clear GPU memory
+     result = pipe(**options)
+
+     grid_img = Image.new('RGB', (width * grid_size_x, height * grid_size_y))
+
+     for i, img in enumerate(result.images[:num_images]):
+         grid_img.paste(img, (i % grid_size_x * width, i // grid_size_x * height))
+
+     return grid_img, seed
+
  examples = [
+     "A tiny astronaut hatching from an egg on the moon, 4k, planet theme",
+     "An anime-style illustration of a delicious, golden-brown wiener schnitzel on a plate, served with fresh lemon slices, parsley --style raw5",
+     "Cold coffee in a cup bokeh --ar 85:128 --v 6.0 --style raw5, 4K, Photo-Realistic",
+     "A cat holding a sign that says hello world --ar 85:128 --v 6.0 --style raw"
  ]

+ css = '''
+ .gradio-container{max-width: 585px !important}
+ h1{text-align:center}
+ footer {
+     visibility: hidden
+ }
+ '''
+
+ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown("## GRID 6X🪨")
+
+         with gr.Row():
+             prompt = gr.Text(
+                 label="Prompt",
+                 show_label=False,
+                 max_lines=1,
+                 placeholder="Enter your prompt",
+                 container=False,
+             )
+
+             run_button = gr.Button("Run", scale=0, variant="primary")
+
+         result = gr.Image(label="Result", show_label=False)
+
+         with gr.Row(visible=True):
+             grid_size_selection = gr.Dropdown(
+                 choices=["2x1", "1x2", "2x2", "2x3", "3x2", "1x1"],
+                 value="1x1",
+                 label="Grid Size"
+             )
+
+         with gr.Accordion("Advanced Settings", open=False):
+             negative_prompt = gr.Text(
+                 label="Negative prompt",
+                 max_lines=1,
+                 placeholder="Enter a negative prompt",
+                 value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
+                 visible=False,
+             )
+
+             seed = gr.Slider(
+                 label="Seed",
+                 minimum=0,
+                 maximum=MAX_SEED,
+                 step=1,
+                 value=0,
+             )
+
+             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+             with gr.Row():
+                 width = gr.Slider(
+                     label="Width",
+                     minimum=512,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=32,
+                     value=1024,
+                 )
+
+                 height = gr.Slider(
+                     label="Height",
+                     minimum=512,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=32,
+                     value=1024,
+                 )
+
+             with gr.Row():
+                 guidance_scale = gr.Slider(
+                     label="Guidance scale",
+                     minimum=0.0,
+                     maximum=7.5,
+                     step=0.1,
+                     value=0.0,
+                 )
+
+                 num_inference_steps = gr.Slider(
+                     label="Number of inference steps",
+                     minimum=1,
+                     maximum=50,
+                     step=1,
+                     value=8,
+                 )
+
+             style_selection = gr.Radio(
+                 show_label=True,
+                 container=True,
+                 interactive=True,
+                 choices=STYLE_NAMES,
+                 value=DEFAULT_STYLE_NAME,
+                 label="Quality Style",
+             )
+
+         gr.Examples(examples=examples,
+                     inputs=[prompt],
+                     outputs=[result, seed],
+                     fn=infer,
+                     cache_examples=False)
+
+     gr.on(
+         triggers=[run_button.click, prompt.submit],
+         fn=infer,
+         inputs=[
+             prompt,
+             negative_prompt,
+             seed,
+             randomize_seed,
+             width,
+             height,
+             guidance_scale,
+             num_inference_steps,
+             style_selection,
+             grid_size_selection,
+         ],
+         outputs=[result, seed],
+     )
+
+ if __name__ == "__main__":
+     demo.launch(ssr_mode=False, show_error=True)
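
For context on the grid feature introduced by this change: the new infer() pastes its num_images outputs row-major onto one canvas via grid_img.paste(img, (i % grid_size_x * width, i // grid_size_x * height)). Below is a minimal, self-contained sketch of just that placement step; the solid-color placeholder tiles, the 3x2 grid choice, the 256-pixel tile size, and the output filename are illustrative assumptions rather than values taken from the commit.

    from PIL import Image

    # Illustrative values; the app derives these from the "Grid Size" dropdown
    # and the width/height sliders (defaults 1024x1024).
    grid_size_x, grid_size_y = 3, 2   # a "3x2" grid
    width, height = 256, 256          # per-tile size for this sketch

    # Placeholder tiles standing in for pipe(**options).images
    colors = ["red", "green", "blue", "yellow", "purple", "orange"]
    images = [Image.new("RGB", (width, height), c) for c in colors]

    # Row-major placement: column index is i % grid_size_x, row index is i // grid_size_x
    grid_img = Image.new("RGB", (width * grid_size_x, height * grid_size_y))
    for i, img in enumerate(images[:grid_size_x * grid_size_y]):
        grid_img.paste(img, (i % grid_size_x * width, i // grid_size_x * height))

    grid_img.save("grid_preview.png")  # hypothetical output path for the sketch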