Spaces:
Runtime error
Runtime error
Commit
·
2a072c6
1
Parent(s):
ef5107a
Auto deploy
Browse files- stCompressService.py +9 -9
stCompressService.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import os
|
| 2 |
-
import pathlib
|
| 3 |
import torch
|
| 4 |
import torch.hub
|
| 5 |
from torchvision.transforms.functional import convert_image_dtype, pil_to_tensor
|
|
@@ -29,11 +28,10 @@ def loadModel(device):
|
|
| 29 |
ckpt = torch.hub.load_state_dict_from_url(MODELS_URL, map_location=device, check_hash=True)
|
| 30 |
|
| 31 |
config = Config.deserialize(ckpt["config"])
|
| 32 |
-
model = Compressor(**config.Model.Params).to(device)
|
| 33 |
model.QuantizationParameter = "qp_2_msssim"
|
| 34 |
model.load_state_dict(ckpt["model"])
|
| 35 |
-
return model
|
| 36 |
-
|
| 37 |
|
| 38 |
|
| 39 |
@st.cache
|
|
@@ -46,8 +44,9 @@ def compressImage(image: torch.Tensor, model: BaseCompressor, crop: bool) -> Fil
|
|
| 46 |
# [c, h, w]
|
| 47 |
image = (image - 0.5) * 2
|
| 48 |
|
| 49 |
-
with model.
|
| 50 |
-
codes,
|
|
|
|
| 51 |
|
| 52 |
return File(headers[0], binaries[0])
|
| 53 |
|
|
@@ -56,9 +55,10 @@ def compressImage(image: torch.Tensor, model: BaseCompressor, crop: bool) -> Fil
|
|
| 56 |
def decompressImage(sourceFile: File, model: BaseCompressor) -> torch.ByteTensor:
|
| 57 |
binaries = sourceFile.Content
|
| 58 |
|
| 59 |
-
with model.
|
|
|
|
| 60 |
# [1, c, h, w]
|
| 61 |
-
restored = model.
|
| 62 |
|
| 63 |
# [c, h, w]
|
| 64 |
return DeTransform()(restored[0])
|
|
@@ -71,7 +71,7 @@ def main():
|
|
| 71 |
else:
|
| 72 |
device = torch.device("cuda")
|
| 73 |
|
| 74 |
-
model = loadModel(device)
|
| 75 |
|
| 76 |
st.sidebar.markdown("""
|
| 77 |
<p align="center">
|
|
|
|
| 1 |
import os
|
|
|
|
| 2 |
import torch
|
| 3 |
import torch.hub
|
| 4 |
from torchvision.transforms.functional import convert_image_dtype, pil_to_tensor
|
|
|
|
| 28 |
ckpt = torch.hub.load_state_dict_from_url(MODELS_URL, map_location=device, check_hash=True)
|
| 29 |
|
| 30 |
config = Config.deserialize(ckpt["config"])
|
| 31 |
+
model = Compressor(**config.Model.Params).to(device).eval()
|
| 32 |
model.QuantizationParameter = "qp_2_msssim"
|
| 33 |
model.load_state_dict(ckpt["model"])
|
| 34 |
+
return torch.jit.script(model)
|
|
|
|
| 35 |
|
| 36 |
|
| 37 |
@st.cache
|
|
|
|
| 44 |
# [c, h, w]
|
| 45 |
image = (image - 0.5) * 2
|
| 46 |
|
| 47 |
+
with model.readyForCoding() as cdfs:
|
| 48 |
+
codes, size = model.encode(image[None, ...])
|
| 49 |
+
binaries, headers = model.compress(codes, size, cdfs)
|
| 50 |
|
| 51 |
return File(headers[0], binaries[0])
|
| 52 |
|
|
|
|
| 55 |
def decompressImage(sourceFile: File, model: BaseCompressor) -> torch.ByteTensor:
|
| 56 |
binaries = sourceFile.Content
|
| 57 |
|
| 58 |
+
with model.readyForCoding() as cdfs:
|
| 59 |
+
codes, imageSize = model.decompress([binaries], cdfs, [sourceFile.FileHeader])
|
| 60 |
# [1, c, h, w]
|
| 61 |
+
restored = model.decode(codes, imageSize)
|
| 62 |
|
| 63 |
# [c, h, w]
|
| 64 |
return DeTransform()(restored[0])
|
|
|
|
| 71 |
else:
|
| 72 |
device = torch.device("cuda")
|
| 73 |
|
| 74 |
+
model = loadModel(device)
|
| 75 |
|
| 76 |
st.sidebar.markdown("""
|
| 77 |
<p align="center">
|