Update src/pipeline.py
Browse files- src/pipeline.py +8 -1
src/pipeline.py
CHANGED
@@ -53,9 +53,16 @@ def _load_vae_model():
|
|
53 |
|
54 |
def _load_transformer_model():
|
55 |
"""Load the transformer model from a specific cached path."""
|
|
|
|
|
|
|
|
|
56 |
transformer_path = os.path.join(
|
57 |
HF_HUB_CACHE,
|
58 |
-
"models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146"
|
|
|
|
|
|
|
59 |
)
|
60 |
return FluxTransformer2DModel.from_pretrained(
|
61 |
transformer_path,
|
|
|
53 |
|
54 |
def _load_transformer_model():
|
55 |
"""Load the transformer model from a specific cached path."""
|
56 |
+
# transformer_path = os.path.join(
|
57 |
+
# HF_HUB_CACHE,"models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
|
58 |
+
|
59 |
+
# )
|
60 |
transformer_path = os.path.join(
|
61 |
HF_HUB_CACHE,
|
62 |
+
"models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
|
63 |
+
"snapshots",
|
64 |
+
model_revision,
|
65 |
+
"transformer"
|
66 |
)
|
67 |
return FluxTransformer2DModel.from_pretrained(
|
68 |
transformer_path,
|