Update inferencer.py
inferencer.py  +10 -10  CHANGED
@@ -228,8 +228,8 @@ class InterleaveInferencer:
         image_shapes=(1024, 1024), # Default, can be overridden by actual input image
     ):
         gen_context = self.init_gen_context()
-        cfg_text_context =
-        cfg_img_context =
+        cfg_text_context = deepcopy(gen_context)
+        cfg_img_context = deepcopy(gen_context)
 
         current_image_shapes = image_shapes
 
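The point of `deepcopy(gen_context)` on the added lines is that the two classifier-free-guidance contexts must start as independent snapshots of the generation context, not as extra references to the same object. A minimal, self-contained sketch of the difference (the dict keys are made up for illustration and are not the actual structure returned by `init_gen_context()`):

    from copy import deepcopy

    # Stand-in for whatever init_gen_context() returns.
    gen_context = {"kv_lens": [0], "tokens": []}

    # Aliasing: both names refer to one dict, so updating gen_context
    # silently mutates the "CFG" context too.
    cfg_alias = gen_context
    gen_context["tokens"].append("a cat")
    print(cfg_alias["tokens"])      # ['a cat']  <- leaked into the CFG branch

    # Deep copy: the CFG context stays a frozen snapshot taken before the update.
    gen_context = {"kv_lens": [0], "tokens": []}
    cfg_snapshot = deepcopy(gen_context)
    gen_context["tokens"].append("a cat")
    print(cfg_snapshot["tokens"])   # []  <- unchanged, which is what CFG needs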
@@ -243,15 +243,16 @@ class InterleaveInferencer:
 
         for input_term in input_lists:
             if isinstance(input_term, str):
+                cfg_text_context = deepcopy(gen_context)
                 gen_context = self.update_context_text(input_term, gen_context)
-                cfg_text_context = self.update_context_text(input_term, cfg_text_context)
                 cfg_img_context = self.update_context_text(input_term, cfg_img_context)
+
             elif isinstance(input_term, Image.Image):
-
-
-
-                cfg_text_context =
-
+                input_term = self.vae_transform.resize_transform(pil_img2rgb(input_term))
+                gen_context = self.update_context_image(input_term, gen_context, vae=not understanding_output)
+                image_shapes = input_term.size[::-1]
+                cfg_text_context = deepcopy(gen_context)
+
             else:
                 raise ValueError(f"Unsupported input type: {type(input_term)}")
 
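A side note on the new image branch: PIL's `Image.size` is `(width, height)`, so `input_term.size[::-1]` yields `(height, width)`, the order in which `image_shapes` is used for generation. A quick standalone check with a dummy image:

    from PIL import Image

    img = Image.new("RGB", (1024, 768))  # PIL constructors take (width, height)
    print(img.size)        # (1024, 768)  -> (W, H)
    print(img.size[::-1])  # (768, 1024)  -> (H, W), the value stored in image_shapes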
@@ -266,10 +267,9 @@ class InterleaveInferencer:
                 full_thought_text = "".join(thought_text_parts)
                 if full_thought_text: # Only update if thought was generated
                     gen_context = self.update_context_text(full_thought_text, gen_context)
-                    cfg_text_context = self.update_context_text(full_thought_text, cfg_text_context)
 
             img = self.gen_image(
-                image_shape=
+                image_shape=image_shapes,
                 gen_context=gen_context,
                 cfg_text_precontext=cfg_text_context,
                 cfg_img_precontext=cfg_img_context,
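With `image_shape=image_shapes` wired through, an image supplied in the input list now determines the output resolution instead of the `(1024, 1024)` default. Roughly what that looks like from the calling side, assuming this method is the class's `interleave_inference` entry point and that an `InterleaveInferencer` instance has already been built; neither of those details is shown in this diff:

    from PIL import Image

    # inferencer: an already-constructed InterleaveInferencer (model, VAE, tokenizer, ...).
    source = Image.open("room.jpg")  # e.g. a landscape photo

    # The Image branch above resizes the input and records its (H, W) as image_shapes,
    # so gen_image targets a shape matching the input rather than 1024x1024.
    outputs = inferencer.interleave_inference(["Make it nighttime.", source])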