Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -34,14 +34,14 @@ from torchvision import transforms
 from models.controlnet import ControlNetModel
 from models.unet_2d_condition import UNet2DConditionModel
 
-VLM_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"
+# VLM_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"
 
-vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    VLM_NAME,
-    torch_dtype="auto",
-    device_map="auto"  # immediately dispatches layers onto available GPUs
-)
-vlm_processor = AutoProcessor.from_pretrained(VLM_NAME)
+# vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+#     VLM_NAME,
+#     torch_dtype="auto",
+#     device_map="auto"  # immediately dispatches layers onto available GPUs
+# )
+# vlm_processor = AutoProcessor.from_pretrained(VLM_NAME)
 
 def _generate_vlm_prompt(
     vlm_model: Qwen2_5_VLForConditionalGeneration,

@@ -186,18 +186,20 @@ def process(
     latent_tiled_overlap = 4,
     sample_times = 1,
 ) -> List[np.ndarray]:
+
+    input_image = input_image.resize(256, 256)
 
     process_size = 512
     resize_preproc = transforms.Compose([
         transforms.Resize(process_size, interpolation=transforms.InterpolationMode.BILINEAR),
     ])
-    user_prompt = _generate_vlm_prompt(
-        vlm_model=vlm_model,
-        vlm_processor=vlm_processor,
-        process_vision_info=process_vision_info,
-        pil_image=input_image,
-        device=device,
-    )
+    # user_prompt = _generate_vlm_prompt(
+    #     vlm_model=vlm_model,
+    #     vlm_processor=vlm_processor,
+    #     process_vision_info=process_vision_info,
+    #     pil_image=input_image,
+    #     device=device,
+    # )
 
     # with torch.no_grad():
     seed_everything(seed)
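The Space badge above says it is running on ZeroGPU, where a GPU is only attached while a function decorated with @spaces.GPU is executing; the module-level from_pretrained(..., device_map="auto") call that this commit comments out eagerly loads the Qwen2.5-VL weights at import time, which is a poor fit for that setup and is presumably why the whole block was disabled rather than adjusted. If the VLM captioning step is meant to return later, the sketch below shows one lazier alternative under those assumptions; it is not the author's code. `_generate_vlm_prompt` and `process_vision_info` are the names already used in app.py, while `_load_vlm` and `caption_image` are hypothetical helper names.

# Hedged sketch: defer the VLM load so Space startup stays light on ZeroGPU.
# Assumes the `spaces` package is available; `_generate_vlm_prompt` and
# `process_vision_info` are defined elsewhere in app.py.
import spaces
from functools import lru_cache
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

VLM_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"

@lru_cache(maxsize=1)
def _load_vlm():
    # First call pays the download/load cost; later calls reuse the cached pair.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(VLM_NAME, torch_dtype="auto")
    processor = AutoProcessor.from_pretrained(VLM_NAME)
    return model, processor

@spaces.GPU
def caption_image(pil_image, device="cuda"):
    # The GPU exists only inside this decorated call, so the model is moved
    # to it here instead of being dispatched at import time.
    model, processor = _load_vlm()
    model.to(device)
    return _generate_vlm_prompt(
        vlm_model=model,
        vlm_processor=processor,
        process_vision_info=process_vision_info,
        pil_image=pil_image,
        device=device,
    )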
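One of the newly added lines also looks like it would fail at runtime: PIL's Image.resize takes the target size as a single (width, height) tuple, so input_image.resize(256, 256) raises a TypeError. Assuming a fixed 256×256 input is the intent, the call was presumably meant to be:

input_image = input_image.resize((256, 256))  # size is one (width, height) tuple

resize returns a new image rather than modifying in place, which the assignment back to input_image already handles.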