Upload 948 files
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- .gitattributes +5 -0
- TryYours-Virtual-Try-On/.gitignore +5 -0
- TryYours-Virtual-Try-On/Demo.ipynb +267 -0
- TryYours-Virtual-Try-On/Graphonomy-master/LICENSE +21 -0
- TryYours-Virtual-Try-On/Graphonomy-master/README.md +124 -0
- TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/__init__.py +0 -0
- TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/__pycache__/__init__.cpython-310.pyc +0 -0
- TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/__pycache__/__init__.cpython-39.pyc +0 -0
- TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/__pycache__/custom_transforms.cpython-310.pyc +0 -0
- TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/__pycache__/custom_transforms.cpython-39.pyc +0 -0
- TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/atr.py +109 -0
- TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/cihp.py +107 -0
- TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/cihp_pascal_atr.py +219 -0
- TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/custom_transforms.py +491 -0
- TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/mypath_atr.py +8 -0
- TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/mypath_cihp.py +8 -0
- TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/mypath_pascal.py +8 -0
- TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/pascal.py +106 -0
- TryYours-Virtual-Try-On/Graphonomy-master/eval_cihp.sh +5 -0
- TryYours-Virtual-Try-On/Graphonomy-master/eval_pascal.sh +5 -0
- TryYours-Virtual-Try-On/Graphonomy-master/exp/inference/.ipynb_checkpoints/inference-checkpoint.py +203 -0
- TryYours-Virtual-Try-On/Graphonomy-master/exp/inference/inference.py +206 -0
- TryYours-Virtual-Try-On/Graphonomy-master/exp/test/__init__.py +3 -0
- TryYours-Virtual-Try-On/Graphonomy-master/exp/test/eval_show_cihp2pascal.py +268 -0
- TryYours-Virtual-Try-On/Graphonomy-master/exp/test/eval_show_pascal2cihp.py +268 -0
- TryYours-Virtual-Try-On/Graphonomy-master/exp/test/test_from_disk.py +65 -0
- TryYours-Virtual-Try-On/Graphonomy-master/exp/transfer/train_cihp_from_pascal.py +331 -0
- TryYours-Virtual-Try-On/Graphonomy-master/exp/universal/pascal_atr_cihp_uni.py +493 -0
- TryYours-Virtual-Try-On/Graphonomy-master/inference.sh +1 -0
- TryYours-Virtual-Try-On/Graphonomy-master/networks/__init__.py +3 -0
- TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/__init__.cpython-310.pyc +0 -0
- TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/__init__.cpython-39.pyc +0 -0
- TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception.cpython-310.pyc +0 -0
- TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception.cpython-39.pyc +0 -0
- TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_synBN.cpython-310.pyc +0 -0
- TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_synBN.cpython-39.pyc +0 -0
- TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_transfer.cpython-310.pyc +0 -0
- TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_transfer.cpython-39.pyc +0 -0
- TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_universal.cpython-310.pyc +0 -0
- TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_universal.cpython-39.pyc +0 -0
- TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/gcn.cpython-310.pyc +0 -0
- TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/gcn.cpython-39.pyc +0 -0
- TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/graph.cpython-310.pyc +0 -0
- TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/graph.cpython-39.pyc +0 -0
- TryYours-Virtual-Try-On/Graphonomy-master/networks/deeplab_xception.py +684 -0
- TryYours-Virtual-Try-On/Graphonomy-master/networks/deeplab_xception_synBN.py +596 -0
- TryYours-Virtual-Try-On/Graphonomy-master/networks/deeplab_xception_transfer.py +1003 -0
- TryYours-Virtual-Try-On/Graphonomy-master/networks/deeplab_xception_universal.py +1077 -0
- TryYours-Virtual-Try-On/Graphonomy-master/networks/gcn.py +279 -0
- TryYours-Virtual-Try-On/Graphonomy-master/networks/graph.py +261 -0
.gitattributes
CHANGED
```diff
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+TryYours-Virtual-Try-On/HR-VITON-main/figures/fig.jpg filter=lfs diff=lfs merge=lfs -text
+TryYours-Virtual-Try-On/HR-VITON-main/Output/00001_00_00001_00.png filter=lfs diff=lfs merge=lfs -text
+TryYours-Virtual-Try-On/static/finalimg.png filter=lfs diff=lfs merge=lfs -text
+TryYours-Virtual-Try-On/static/origin_web.jpg filter=lfs diff=lfs merge=lfs -text
+TryYours-Virtual-Try-On/TryYours_presentation_kr.pdf filter=lfs diff=lfs merge=lfs -text
```
TryYours-Virtual-Try-On/.gitignore
ADDED
```diff
@@ -0,0 +1,5 @@
+# don't upload large pth files
+
+HR-VITON-main/gen.pth
+HR-VITON-main/mtviton.pth
+Graphonomy-master/inference.pth
```
TryYours-Virtual-Try-On/Demo.ipynb
ADDED
@@ -0,0 +1,267 @@

A Jupyter notebook (nbformat 4, Python 3 ipykernel, Python 3.9.13) that prepares the demo environment. Its four code cells were executed on 2022-12-29, and the stored outputs are the pip/ngrok logs from that run: tensorboardX 2.5.1, detectron2 0.6 (built from commit 857d5de21a7789d1bba46694cf608b1cb2ea128a), and pyngrok 4.1.1 were installed successfully, and ngrok reported "Authtoken saved to configuration file: /root/.ngrok2/ngrok.yml". The cell sources are:

```python
# Cell 1: install the Python dependencies for the try-on pipeline and the Flask demo
!pip install tensorboardX av torchgeometry flask flask-ngrok iglovikov_helper_functions cloths_segmentation albumentations

# Cell 2: install detectron2 from source
!python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'

# Cell 3: install a pinned pyngrok
!pip install pyngrok==4.1.1

# Cell 4: register the ngrok auth token ("yourtoken" is a placeholder)
!ngrok authtoken yourtoken
```

The notebook ends with an empty, unexecuted code cell.
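Since the notebook installs flask and flask-ngrok and registers an ngrok auth token, the demo web server is presumably exposed through an ngrok tunnel. Below is a minimal, illustrative sketch of that pattern only; the route, template name, and structure are assumptions for illustration, not the repository's actual app:

```python
# Illustrative sketch (not from this repository): exposing a local Flask app with flask-ngrok
# once `ngrok authtoken <token>` has been configured, as in the notebook cell above.
from flask import Flask, render_template
from flask_ngrok import run_with_ngrok

app = Flask(__name__)
run_with_ngrok(app)  # prints a public *.ngrok.io URL when app.run() is called


@app.route("/")
def index():
    # A real try-on demo would render the upload form and the generated result image here.
    return render_template("index.html")  # "index.html" is a placeholder template name


if __name__ == "__main__":
    app.run()  # flask-ngrok starts the tunnel alongside the development server
```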
TryYours-Virtual-Try-On/Graphonomy-master/LICENSE
ADDED
@@ -0,0 +1,21 @@

```
MIT License

Copyright (c) 2016 Vladimir Nekrasov

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```
TryYours-Virtual-Try-On/Graphonomy-master/README.md
ADDED
@@ -0,0 +1,124 @@

````markdown
# Graphonomy: Universal Human Parsing via Graph Transfer Learning

This repository contains the code for the paper:

[**Graphonomy: Universal Human Parsing via Graph Transfer Learning**](https://arxiv.org/abs/1904.04536)
by Ke Gong, Yiming Gao, Xiaodan Liang, Xiaohui Shen, Meng Wang, and Liang Lin.


# Environment and installation
+ PyTorch 0.4.0
+ torchvision
+ scipy
+ tensorboardX
+ numpy
+ opencv-python
+ matplotlib
+ networkx

You can install the packages above with `pip install -r requirements.txt`.

# Getting Started
### Data Preparation
+ Download the human parsing datasets, prepare the images, and store them in `/data/datasets/dataset_name/`.
  We recommend symlinking the dataset paths into `data/datasets/` as follows:

```
# symlink the Pascal-Person-Part dataset, for example
ln -s /path_to_Pascal_Person_Part/* data/datasets/pascal/
```
+ The file structure should look like:
```
/Graphonomy
    /data
        /datasets
            /pascal
                /JPEGImages
                /list
                /SegmentationPart
            /CIHP_4w
                /Images
                /lists
                ...
```
+ The datasets (CIHP & ATR) are available on [google drive](https://drive.google.com/drive/folders/0BzvH3bSnp3E9ZW9paE9kdkJtM3M?usp=sharing)
  and [baidu drive](http://pan.baidu.com/s/1nvqmZBN).
  You also need to download the flipped labels:
  download [cihp_flipped](https://drive.google.com/file/d/1aaJyQH-hlZEAsA7iH-mYeK1zLfQi8E2j/view?usp=sharing), unzip it, and store it in `data/datasets/CIHP_4w/`;
  download [atr_flip](https://drive.google.com/file/d/1iR8Tn69IbDSM7gq_GG-_s11HCnhPkyG3/view?usp=sharing), unzip it, and store it in `data/datasets/ATR/`.

### Inference
We provide a simple script that produces visualization results on the CIHP dataset using a [trained](https://drive.google.com/file/d/1O9YD4kHgs3w2DUcWxtHiEFyWjCBeS_Vc/view?usp=sharing) model:
```shell
# Example of inference
python exp/inference/inference.py \
    --loadmodel /path_to_inference_model \
    --img_path ./img/messi.jpg \
    --output_path ./img/ \
    --output_name /output_file_name
```

### Training
#### Transfer learning
1. Download the Pascal pretrained model (available soon).
2. Run `sh train_transfer_cihp.sh`.
3. The results and models are saved in `exp/transfer/run/`.
4. The evaluation and visualization script is `eval_cihp.sh`; change the `--loadmodel` argument before you run it.

#### Universal training
1. Download the [pretrained](https://drive.google.com/file/d/18WiffKnxaJo50sCC9zroNyHjcnTxGCbk/view?usp=sharing) model and store it in `/data/pretrained_model/`.
2. Run `sh train_universal.sh`.
3. The results and models are saved in `exp/universal/run/`.

### Testing
To evaluate a pre-trained model on the PASCAL-Person-Part or CIHP val/test set,
run `sh eval_cihp.sh` or `sh eval_pascal.sh` and specify the model to load.
We also provide the final models, which you can download and store in `/data/pretrained_model/`.

### Models
**Pascal-Person-Part trained model**

|Model|Google Cloud|Baidu Yun|
|--------|--------------|-----------|
|Graphonomy(CIHP)| [Download](https://drive.google.com/file/d/1E_V_gVDWfAJFPfe-LLu2RQaYQMdhjv9h/view?usp=sharing)| Available soon|

**CIHP trained model**

|Model|Google Cloud|Baidu Yun|
|--------|--------------|-----------|
|Graphonomy(PASCAL)| [Download](https://drive.google.com/file/d/1eUe18HoH05p0yFUd_sN6GXdTj82aW0m9/view?usp=sharing)| Available soon|

**Universal trained model**

|Model|Google Cloud|Baidu Yun|
|--------|--------------|-----------|
|Universal| [Download](https://drive.google.com/file/d/1sWJ54lCBFnzCNz5RTCGQmkVovkY9x8_D/view?usp=sharing)| Available soon|

### Todo:
- [ ] release pretrained and trained models
- [ ] update universal eval code & script

# Citation

```
@inproceedings{Gong2019Graphonomy,
  author    = {Ke Gong and Yiming Gao and Xiaodan Liang and Xiaohui Shen and Meng Wang and Liang Lin},
  title     = {Graphonomy: Universal Human Parsing via Graph Transfer Learning},
  booktitle = {CVPR},
  year      = {2019},
}
```

# Contact
If you have any questions about this repo, please feel free to contact
[[email protected]](mailto:[email protected]).

## Related work
+ Self-supervised Structure-sensitive Learning [SSL](https://github.com/Engineering-Course/LIP_SSL)
+ Joint Body Parsing & Pose Estimation Network [JPPNet](https://github.com/Engineering-Course/LIP_JPPNet)
+ Instance-level Human Parsing via Part Grouping Network [PGN](https://github.com/Engineering-Course/CIHP_PGN)
+ Graphonomy: Universal Image Parsing via Graph Reasoning and Transfer [paper](https://arxiv.org/abs/2101.10620) [code](https://github.com/Gaoyiminggithub/Graphonomy-Panoptic)
````
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/__init__.py
ADDED
File without changes (empty file)

TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (177 Bytes)

TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (153 Bytes)

TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/__pycache__/custom_transforms.cpython-310.pyc
ADDED
Binary file (14.3 kB)

TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/__pycache__/custom_transforms.cpython-39.pyc
ADDED
Binary file (15 kB)
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/atr.py
ADDED
@@ -0,0 +1,109 @@

```python
from __future__ import print_function, division
import os
from PIL import Image
from torch.utils.data import Dataset
from .mypath_atr import Path
import random
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


class VOCSegmentation(Dataset):
    """
    ATR dataset
    """

    def __init__(self,
                 base_dir=Path.db_root_dir('atr'),
                 split='train',
                 transform=None,
                 flip=False,
                 ):
        """
        :param base_dir: path to ATR dataset directory
        :param split: train/val
        :param transform: transform to apply
        """
        super(VOCSegmentation).__init__()
        self._flip_flag = flip

        self._base_dir = base_dir
        self._image_dir = os.path.join(self._base_dir, 'JPEGImages')
        self._cat_dir = os.path.join(self._base_dir, 'SegmentationClassAug')
        self._flip_dir = os.path.join(self._base_dir, 'SegmentationClassAug_rev')

        if isinstance(split, str):
            self.split = [split]
        else:
            split.sort()
            self.split = split

        self.transform = transform

        _splits_dir = os.path.join(self._base_dir, 'list')

        self.im_ids = []
        self.images = []
        self.categories = []
        self.flip_categories = []

        for splt in self.split:
            with open(os.path.join(os.path.join(_splits_dir, splt + '_id.txt')), "r") as f:
                lines = f.read().splitlines()

            for ii, line in enumerate(lines):
                _image = os.path.join(self._image_dir, line + '.jpg')
                _cat = os.path.join(self._cat_dir, line + '.png')
                _flip = os.path.join(self._flip_dir, line + '.png')
                assert os.path.isfile(_image)
                assert os.path.isfile(_cat)
                assert os.path.isfile(_flip)
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)
                self.flip_categories.append(_flip)

        assert (len(self.images) == len(self.categories))
        assert len(self.flip_categories) == len(self.categories)

        # Display stats
        print('Number of images in {}: {:d}'.format(split, len(self.images)))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        _img, _target = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}

        if self.transform is not None:
            sample = self.transform(sample)

        return sample

    def _make_img_gt_point_pair(self, index):
        # Read Image and Target
        _img = Image.open(self.images[index]).convert('RGB')  # return is RGB pic
        if self._flip_flag:
            if random.random() < 0.5:
                _target = Image.open(self.flip_categories[index])
                _img = _img.transpose(Image.FLIP_LEFT_RIGHT)
            else:
                _target = Image.open(self.categories[index])
        else:
            _target = Image.open(self.categories[index])

        return _img, _target

    def __str__(self):
        return 'ATR(split=' + str(self.split) + ')'
```
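For reference, the dataset classes in this folder are consumed the same way the repository's own self-test in `cihp_pascal_atr.py` does: wrap them in a `torch.utils.data.DataLoader` together with the dict-based transforms from `custom_transforms.py`. A minimal sketch, not part of the diff, assuming the ATR images and annotations are already at the path returned by `Path.db_root_dir('atr')`:

```python
# Minimal usage sketch (not part of the diff): loading the ATR dataset defined above.
from torch.utils.data import DataLoader
from torchvision import transforms

from dataloaders import custom_transforms as tr
from dataloaders.atr import VOCSegmentation

composed = transforms.Compose([
    tr.RandomSized_new(512),  # transforms used by the repo's own self-test in cihp_pascal_atr.py
    tr.RandomRotate(15),
    tr.ToTensor_(),
])

atr_train = VOCSegmentation(split='train', transform=composed, flip=True)
loader = DataLoader(atr_train, batch_size=4, shuffle=True, num_workers=2)

for ii, sample in enumerate(loader):
    images, labels = sample['image'], sample['label']  # each batch is a dict of tensors
    if ii > 2:
        break
```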
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/cihp.py
ADDED
@@ -0,0 +1,107 @@

```python
from __future__ import print_function, division
import os
from PIL import Image
from torch.utils.data import Dataset
from .mypath_cihp import Path
import random


class VOCSegmentation(Dataset):
    """
    CIHP dataset
    """

    def __init__(self,
                 base_dir=Path.db_root_dir('cihp'),
                 split='train',
                 transform=None,
                 flip=False,
                 ):
        """
        :param base_dir: path to CIHP dataset directory
        :param split: train/val/test
        :param transform: transform to apply
        """
        super(VOCSegmentation).__init__()
        self._flip_flag = flip

        self._base_dir = base_dir
        self._image_dir = os.path.join(self._base_dir, 'Images')
        self._cat_dir = os.path.join(self._base_dir, 'Category_ids')
        self._flip_dir = os.path.join(self._base_dir, 'Category_rev_ids')

        if isinstance(split, str):
            self.split = [split]
        else:
            split.sort()
            self.split = split

        self.transform = transform

        _splits_dir = os.path.join(self._base_dir, 'lists')

        self.im_ids = []
        self.images = []
        self.categories = []
        self.flip_categories = []

        for splt in self.split:
            with open(os.path.join(os.path.join(_splits_dir, splt + '_id.txt')), "r") as f:
                lines = f.read().splitlines()

            for ii, line in enumerate(lines):
                _image = os.path.join(self._image_dir, line + '.jpg')
                _cat = os.path.join(self._cat_dir, line + '.png')
                _flip = os.path.join(self._flip_dir, line + '.png')
                assert os.path.isfile(_image)
                assert os.path.isfile(_cat)
                assert os.path.isfile(_flip)
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)
                self.flip_categories.append(_flip)

        assert (len(self.images) == len(self.categories))
        assert len(self.flip_categories) == len(self.categories)

        # Display stats
        print('Number of images in {}: {:d}'.format(split, len(self.images)))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        _img, _target = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}

        if self.transform is not None:
            sample = self.transform(sample)

        return sample

    def _make_img_gt_point_pair(self, index):
        # Read Image and Target
        _img = Image.open(self.images[index]).convert('RGB')  # return is RGB pic
        if self._flip_flag:
            if random.random() < 0.5:
                _target = Image.open(self.flip_categories[index])
                _img = _img.transpose(Image.FLIP_LEFT_RIGHT)
            else:
                _target = Image.open(self.categories[index])
        else:
            _target = Image.open(self.categories[index])

        return _img, _target

    def __str__(self):
        return 'CIHP(split=' + str(self.split) + ')'
```
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/cihp_pascal_atr.py
ADDED
@@ -0,0 +1,219 @@

```python
from __future__ import print_function, division
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
from .mypath_cihp import Path
from .mypath_pascal import Path as PP
from .mypath_atr import Path as PA
import random
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


class VOCSegmentation(Dataset):
    """
    Pascal dataset
    """

    def __init__(self,
                 cihp_dir=Path.db_root_dir('cihp'),
                 split='train',
                 transform=None,
                 flip=False,
                 pascal_dir=PP.db_root_dir('pascal'),
                 atr_dir=PA.db_root_dir('atr'),
                 ):
        """
        :param cihp_dir: path to CIHP dataset directory
        :param pascal_dir: path to PASCAL dataset directory
        :param atr_dir: path to ATR dataset directory
        :param split: train/val
        :param transform: transform to apply
        """
        super(VOCSegmentation).__init__()
        ## for cihp
        self._flip_flag = flip
        self._base_dir = cihp_dir
        self._image_dir = os.path.join(self._base_dir, 'Images')
        self._cat_dir = os.path.join(self._base_dir, 'Category_ids')
        self._flip_dir = os.path.join(self._base_dir, 'Category_rev_ids')
        ## for Pascal
        self._base_dir_pascal = pascal_dir
        self._image_dir_pascal = os.path.join(self._base_dir_pascal, 'JPEGImages')
        self._cat_dir_pascal = os.path.join(self._base_dir_pascal, 'SegmentationPart')
        ## for atr
        self._base_dir_atr = atr_dir
        self._image_dir_atr = os.path.join(self._base_dir_atr, 'JPEGImages')
        self._cat_dir_atr = os.path.join(self._base_dir_atr, 'SegmentationClassAug')
        self._flip_dir_atr = os.path.join(self._base_dir_atr, 'SegmentationClassAug_rev')

        if isinstance(split, str):
            self.split = [split]
        else:
            split.sort()
            self.split = split

        self.transform = transform

        _splits_dir = os.path.join(self._base_dir, 'lists')
        _splits_dir_pascal = os.path.join(self._base_dir_pascal, 'list')
        _splits_dir_atr = os.path.join(self._base_dir_atr, 'list')

        self.im_ids = []
        self.images = []
        self.categories = []
        self.flip_categories = []
        self.datasets_lbl = []

        # num
        self.num_cihp = 0
        self.num_pascal = 0
        self.num_atr = 0

        # for cihp is 0
        for splt in self.split:
            with open(os.path.join(os.path.join(_splits_dir, splt + '_id.txt')), "r") as f:
                lines = f.read().splitlines()
            self.num_cihp += len(lines)
            for ii, line in enumerate(lines):
                _image = os.path.join(self._image_dir, line + '.jpg')
                _cat = os.path.join(self._cat_dir, line + '.png')
                _flip = os.path.join(self._flip_dir, line + '.png')
                assert os.path.isfile(_image)
                assert os.path.isfile(_cat)
                assert os.path.isfile(_flip)
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)
                self.flip_categories.append(_flip)
                self.datasets_lbl.append(0)

        # for pascal is 1
        for splt in self.split:
            if splt == 'test':
                splt = 'val'
            with open(os.path.join(os.path.join(_splits_dir_pascal, splt + '_id.txt')), "r") as f:
                lines = f.read().splitlines()
            self.num_pascal += len(lines)
            for ii, line in enumerate(lines):
                _image = os.path.join(self._image_dir_pascal, line + '.jpg')
                _cat = os.path.join(self._cat_dir_pascal, line + '.png')
                assert os.path.isfile(_image)
                assert os.path.isfile(_cat)
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)
                self.flip_categories.append([])
                self.datasets_lbl.append(1)

        # for atr is 2
        for splt in self.split:
            with open(os.path.join(os.path.join(_splits_dir_atr, splt + '_id.txt')), "r") as f:
                lines = f.read().splitlines()
            self.num_atr += len(lines)
            for ii, line in enumerate(lines):
                _image = os.path.join(self._image_dir_atr, line + '.jpg')
                _cat = os.path.join(self._cat_dir_atr, line + '.png')
                _flip = os.path.join(self._flip_dir_atr, line + '.png')
                assert os.path.isfile(_image)
                assert os.path.isfile(_cat)
                assert os.path.isfile(_flip)
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)
                self.flip_categories.append(_flip)
                self.datasets_lbl.append(2)

        assert (len(self.images) == len(self.categories))

        # Display stats
        print('Number of images in {}: {:d}'.format(split, len(self.images)))

    def __len__(self):
        return len(self.images)

    def get_class_num(self):
        return self.num_cihp, self.num_pascal, self.num_atr

    def __getitem__(self, index):
        _img, _target, _lbl = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}

        if self.transform is not None:
            sample = self.transform(sample)
        sample['pascal'] = _lbl
        return sample

    def _make_img_gt_point_pair(self, index):
        # Read Image and Target
        _img = Image.open(self.images[index]).convert('RGB')  # return is RGB pic
        type_lbl = self.datasets_lbl[index]
        if self._flip_flag:
            if random.random() < 0.5:
                _img = _img.transpose(Image.FLIP_LEFT_RIGHT)
                if type_lbl == 0 or type_lbl == 2:
                    _target = Image.open(self.flip_categories[index])
                else:
                    _target = Image.open(self.categories[index])
                    _target = _target.transpose(Image.FLIP_LEFT_RIGHT)
            else:
                _target = Image.open(self.categories[index])
        else:
            _target = Image.open(self.categories[index])

        return _img, _target, type_lbl

    def __str__(self):
        return 'datasets(split=' + str(self.split) + ')'


if __name__ == '__main__':
    from dataloaders import custom_transforms as tr
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import matplotlib.pyplot as plt

    composed_transforms_tr = transforms.Compose([
        # tr.RandomHorizontalFlip(),
        tr.RandomSized_new(512),
        tr.RandomRotate(15),
        tr.ToTensor_()])

    voc_train = VOCSegmentation(split='train',
                                transform=composed_transforms_tr)

    dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=1)

    for ii, sample in enumerate(dataloader):
        if ii > 10:
            break
```
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/custom_transforms.py
ADDED
@@ -0,0 +1,491 @@

```python
import torch
import math
import numbers
import random
import numpy as np

from PIL import Image, ImageOps
from torchvision import transforms


class RandomCrop(object):
    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size  # h, w
        self.padding = padding

    def __call__(self, sample):
        img, mask = sample['image'], sample['label']

        if self.padding > 0:
            img = ImageOps.expand(img, border=self.padding, fill=0)
            mask = ImageOps.expand(mask, border=self.padding, fill=0)

        assert img.size == mask.size
        w, h = img.size
        th, tw = self.size  # target size
        if w == tw and h == th:
            return {'image': img,
                    'label': mask}
        if w < tw or h < th:
            img = img.resize((tw, th), Image.BILINEAR)
            mask = mask.resize((tw, th), Image.NEAREST)
            return {'image': img,
                    'label': mask}

        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        img = img.crop((x1, y1, x1 + tw, y1 + th))
        mask = mask.crop((x1, y1, x1 + tw, y1 + th))

        return {'image': img,
                'label': mask}


class RandomCrop_new(object):
    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size  # h, w
        self.padding = padding

    def __call__(self, sample):
        img, mask = sample['image'], sample['label']

        if self.padding > 0:
            img = ImageOps.expand(img, border=self.padding, fill=0)
            mask = ImageOps.expand(mask, border=self.padding, fill=0)

        assert img.size == mask.size
        w, h = img.size
        th, tw = self.size  # target size
        if w == tw and h == th:
            return {'image': img,
                    'label': mask}

        new_img = Image.new('RGB', (tw, th), 'black')  # size is w x h; and 'white' is 255
        new_mask = Image.new('L', (tw, th), 'white')  # same as above

        # if w > tw or h > th
        x1 = y1 = 0
        if w > tw:
            x1 = random.randint(0, w - tw)
        if h > th:
            y1 = random.randint(0, h - th)
        # crop
        img = img.crop((x1, y1, x1 + tw, y1 + th))
        mask = mask.crop((x1, y1, x1 + tw, y1 + th))
        new_img.paste(img, (0, 0))
        new_mask.paste(mask, (0, 0))

        return {'image': new_img,
                'label': new_mask}


class Paste(object):
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size  # h, w

    def __call__(self, sample):
        img, mask = sample['image'], sample['label']

        assert img.size == mask.size
        w, h = img.size
        th, tw = self.size  # target size
        assert (w <= tw) and (h <= th)
        if w == tw and h == th:
            return {'image': img,
                    'label': mask}

        new_img = Image.new('RGB', (tw, th), 'black')
        new_mask = Image.new('L', (tw, th), 'white')

        new_img.paste(img, (0, 0))
        new_mask.paste(mask, (0, 0))

        return {'image': new_img,
                'label': new_mask}


class CenterCrop(object):
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        assert img.size == mask.size
        w, h = img.size
        th, tw = self.size
        x1 = int(round((w - tw) / 2.))
        y1 = int(round((h - th) / 2.))
        img = img.crop((x1, y1, x1 + tw, y1 + th))
        mask = mask.crop((x1, y1, x1 + tw, y1 + th))

        return {'image': img,
                'label': mask}


class RandomHorizontalFlip(object):
    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

        return {'image': img,
                'label': mask}


class HorizontalFlip(object):
    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        img = img.transpose(Image.FLIP_LEFT_RIGHT)
        mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

        return {'image': img,
                'label': mask}


class HorizontalFlip_only_img(object):
    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        img = img.transpose(Image.FLIP_LEFT_RIGHT)

        return {'image': img,
                'label': mask}


class RandomHorizontalFlip_cihp(object):
    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)

        return {'image': img,
                'label': mask}


class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.
    Args:
        mean (tuple): means for each channel.
        std (tuple): standard deviations for each channel.
    """
    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
        self.mean = mean
        self.std = std
```

(The diff view is truncated at this point; the file contains 491 lines in total.)
189 |
+
|
190 |
+
def __call__(self, sample):
|
191 |
+
img = np.array(sample['image']).astype(np.float32)
|
192 |
+
mask = np.array(sample['label']).astype(np.float32)
|
193 |
+
img /= 255.0
|
194 |
+
img -= self.mean
|
195 |
+
img /= self.std
|
196 |
+
|
197 |
+
return {'image': img,
|
198 |
+
'label': mask}
|
199 |
+
|
200 |
+
class Normalize_255(object):
|
201 |
+
"""Normalize a tensor image with mean and standard deviation. tf use 255.
|
202 |
+
Args:
|
203 |
+
mean (tuple): means for each channel.
|
204 |
+
std (tuple): standard deviations for each channel.
|
205 |
+
"""
|
206 |
+
def __init__(self, mean=(123.15, 115.90, 103.06), std=(1., 1., 1.)):
|
207 |
+
self.mean = mean
|
208 |
+
self.std = std
|
209 |
+
|
210 |
+
def __call__(self, sample):
|
211 |
+
img = np.array(sample['image']).astype(np.float32)
|
212 |
+
mask = np.array(sample['label']).astype(np.float32)
|
213 |
+
# img = 255.0
|
214 |
+
img -= self.mean
|
215 |
+
img /= self.std
|
216 |
+
img = img
|
217 |
+
img = img[[0,3,2,1],...]
|
218 |
+
return {'image': img,
|
219 |
+
'label': mask}
|
220 |
+
|
221 |
+
class Normalize_xception_tf(object):
|
222 |
+
# def __init__(self):
|
223 |
+
# self.rgb2bgr =
|
224 |
+
|
225 |
+
def __call__(self, sample):
|
226 |
+
img = np.array(sample['image']).astype(np.float32)
|
227 |
+
mask = np.array(sample['label']).astype(np.float32)
|
228 |
+
img = (img*2.0)/255.0 - 1
|
229 |
+
# print(img.shape)
|
230 |
+
# img = img[[0,3,2,1],...]
|
231 |
+
return {'image': img,
|
232 |
+
'label': mask}
|
233 |
+
|
234 |
+
class Normalize_xception_tf_only_img(object):
|
235 |
+
# def __init__(self):
|
236 |
+
# self.rgb2bgr =
|
237 |
+
|
238 |
+
def __call__(self, sample):
|
239 |
+
img = np.array(sample['image']).astype(np.float32)
|
240 |
+
# mask = np.array(sample['label']).astype(np.float32)
|
241 |
+
img = (img*2.0)/255.0 - 1
|
242 |
+
# print(img.shape)
|
243 |
+
# img = img[[0,3,2,1],...]
|
244 |
+
return {'image': img,
|
245 |
+
'label': sample['label']}
|
246 |
+
|
247 |
+
class Normalize_cityscapes(object):
|
248 |
+
"""Normalize a tensor image with mean and standard deviation.
|
249 |
+
Args:
|
250 |
+
mean (tuple): means for each channel.
|
251 |
+
std (tuple): standard deviations for each channel.
|
252 |
+
"""
|
253 |
+
def __init__(self, mean=(0., 0., 0.)):
|
254 |
+
self.mean = mean
|
255 |
+
|
256 |
+
def __call__(self, sample):
|
257 |
+
img = np.array(sample['image']).astype(np.float32)
|
258 |
+
mask = np.array(sample['label']).astype(np.float32)
|
259 |
+
img -= self.mean
|
260 |
+
img /= 255.0
|
261 |
+
|
262 |
+
return {'image': img,
|
263 |
+
'label': mask}
|
264 |
+
|
265 |
+
class ToTensor_(object):
|
266 |
+
"""Convert ndarrays in sample to Tensors."""
|
267 |
+
def __init__(self):
|
268 |
+
self.rgb2bgr = transforms.Lambda(lambda x:x[[2,1,0],...])
|
269 |
+
|
270 |
+
def __call__(self, sample):
|
271 |
+
# swap color axis because
|
272 |
+
# numpy image: H x W x C
|
273 |
+
# torch image: C X H X W
|
274 |
+
img = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1))
|
275 |
+
mask = np.expand_dims(np.array(sample['label']).astype(np.float32), -1).transpose((2, 0, 1))
|
276 |
+
# mask[mask == 255] = 0
|
277 |
+
|
278 |
+
img = torch.from_numpy(img).float()
|
279 |
+
img = self.rgb2bgr(img)
|
280 |
+
mask = torch.from_numpy(mask).float()
|
281 |
+
|
282 |
+
|
283 |
+
return {'image': img,
|
284 |
+
'label': mask}
|
285 |
+
|
286 |
+
class ToTensor_only_img(object):
|
287 |
+
"""Convert ndarrays in sample to Tensors."""
|
288 |
+
def __init__(self):
|
289 |
+
self.rgb2bgr = transforms.Lambda(lambda x:x[[2,1,0],...])
|
290 |
+
|
291 |
+
def __call__(self, sample):
|
292 |
+
# swap color axis because
|
293 |
+
# numpy image: H x W x C
|
294 |
+
# torch image: C X H X W
|
295 |
+
img = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1))
|
296 |
+
# mask = np.expand_dims(np.array(sample['label']).astype(np.float32), -1).transpose((2, 0, 1))
|
297 |
+
# mask[mask == 255] = 0
|
298 |
+
|
299 |
+
img = torch.from_numpy(img).float()
|
300 |
+
img = self.rgb2bgr(img)
|
301 |
+
# mask = torch.from_numpy(mask).float()
|
302 |
+
|
303 |
+
|
304 |
+
return {'image': img,
|
305 |
+
'label': sample['label']}
|
306 |
+
|
307 |
+
class FixedResize(object):
|
308 |
+
def __init__(self, size):
|
309 |
+
self.size = tuple(reversed(size)) # size: (h, w)
|
310 |
+
|
311 |
+
def __call__(self, sample):
|
312 |
+
img = sample['image']
|
313 |
+
mask = sample['label']
|
314 |
+
|
315 |
+
assert img.size == mask.size
|
316 |
+
|
317 |
+
img = img.resize(self.size, Image.BILINEAR)
|
318 |
+
mask = mask.resize(self.size, Image.NEAREST)
|
319 |
+
|
320 |
+
return {'image': img,
|
321 |
+
'label': mask}
|
322 |
+
|
323 |
+
class Keep_origin_size_Resize(object):
|
324 |
+
def __init__(self, max_size, scale=1.0):
|
325 |
+
self.size = tuple(reversed(max_size)) # size: (h, w)
|
326 |
+
self.scale = scale
|
327 |
+
self.paste = Paste(int(max_size[0]*scale))
|
328 |
+
|
329 |
+
def __call__(self, sample):
|
330 |
+
img = sample['image']
|
331 |
+
mask = sample['label']
|
332 |
+
|
333 |
+
assert img.size == mask.size
|
334 |
+
h, w = self.size
|
335 |
+
h = int(h*self.scale)
|
336 |
+
w = int(w*self.scale)
|
337 |
+
img = img.resize((h, w), Image.BILINEAR)
|
338 |
+
mask = mask.resize((h, w), Image.NEAREST)
|
339 |
+
|
340 |
+
return self.paste({'image': img,
|
341 |
+
'label': mask})
|
342 |
+
|
343 |
+
class Scale(object):
|
344 |
+
def __init__(self, size):
|
345 |
+
if isinstance(size, numbers.Number):
|
346 |
+
self.size = (int(size), int(size))
|
347 |
+
else:
|
348 |
+
self.size = size
|
349 |
+
|
350 |
+
def __call__(self, sample):
|
351 |
+
img = sample['image']
|
352 |
+
mask = sample['label']
|
353 |
+
assert img.size == mask.size
|
354 |
+
w, h = img.size
|
355 |
+
|
356 |
+
if (w >= h and w == self.size[1]) or (h >= w and h == self.size[0]):
|
357 |
+
return {'image': img,
|
358 |
+
'label': mask}
|
359 |
+
oh, ow = self.size
|
360 |
+
img = img.resize((ow, oh), Image.BILINEAR)
|
361 |
+
mask = mask.resize((ow, oh), Image.NEAREST)
|
362 |
+
|
363 |
+
return {'image': img,
|
364 |
+
'label': mask}
|
365 |
+
|
366 |
+
class Scale_(object):
|
367 |
+
def __init__(self, scale):
|
368 |
+
self.scale = scale
|
369 |
+
|
370 |
+
def __call__(self, sample):
|
371 |
+
img = sample['image']
|
372 |
+
mask = sample['label']
|
373 |
+
assert img.size == mask.size
|
374 |
+
w, h = img.size
|
375 |
+
ow = int(w*self.scale)
|
376 |
+
oh = int(h*self.scale)
|
377 |
+
img = img.resize((ow, oh), Image.BILINEAR)
|
378 |
+
mask = mask.resize((ow, oh), Image.NEAREST)
|
379 |
+
|
380 |
+
return {'image': img,
|
381 |
+
'label': mask}
|
382 |
+
|
383 |
+
class Scale_only_img(object):
|
384 |
+
def __init__(self, scale):
|
385 |
+
self.scale = scale
|
386 |
+
|
387 |
+
def __call__(self, sample):
|
388 |
+
img = sample['image']
|
389 |
+
mask = sample['label']
|
390 |
+
# assert img.size == mask.size
|
391 |
+
w, h = img.size
|
392 |
+
ow = int(w*self.scale)
|
393 |
+
oh = int(h*self.scale)
|
394 |
+
img = img.resize((ow, oh), Image.BILINEAR)
|
395 |
+
# mask = mask.resize((ow, oh), Image.NEAREST)
|
396 |
+
|
397 |
+
return {'image': img,
|
398 |
+
'label': mask}
|
399 |
+
|
400 |
+
class RandomSizedCrop(object):
|
401 |
+
def __init__(self, size):
|
402 |
+
self.size = size
|
403 |
+
|
404 |
+
def __call__(self, sample):
|
405 |
+
img = sample['image']
|
406 |
+
mask = sample['label']
|
407 |
+
assert img.size == mask.size
|
408 |
+
for attempt in range(10):
|
409 |
+
area = img.size[0] * img.size[1]
|
410 |
+
target_area = random.uniform(0.45, 1.0) * area
|
411 |
+
aspect_ratio = random.uniform(0.5, 2)
|
412 |
+
|
413 |
+
w = int(round(math.sqrt(target_area * aspect_ratio)))
|
414 |
+
h = int(round(math.sqrt(target_area / aspect_ratio)))
|
415 |
+
|
416 |
+
if random.random() < 0.5:
|
417 |
+
w, h = h, w
|
418 |
+
|
419 |
+
if w <= img.size[0] and h <= img.size[1]:
|
420 |
+
x1 = random.randint(0, img.size[0] - w)
|
421 |
+
y1 = random.randint(0, img.size[1] - h)
|
422 |
+
|
423 |
+
img = img.crop((x1, y1, x1 + w, y1 + h))
|
424 |
+
mask = mask.crop((x1, y1, x1 + w, y1 + h))
|
425 |
+
assert (img.size == (w, h))
|
426 |
+
|
427 |
+
img = img.resize((self.size, self.size), Image.BILINEAR)
|
428 |
+
mask = mask.resize((self.size, self.size), Image.NEAREST)
|
429 |
+
|
430 |
+
return {'image': img,
|
431 |
+
'label': mask}
|
432 |
+
|
433 |
+
# Fallback
|
434 |
+
scale = Scale(self.size)
|
435 |
+
crop = CenterCrop(self.size)
|
436 |
+
sample = crop(scale(sample))
|
437 |
+
return sample
|
438 |
+
|
439 |
+
class RandomRotate(object):
|
440 |
+
def __init__(self, degree):
|
441 |
+
self.degree = degree
|
442 |
+
|
443 |
+
def __call__(self, sample):
|
444 |
+
img = sample['image']
|
445 |
+
mask = sample['label']
|
446 |
+
rotate_degree = random.random() * 2 * self.degree - self.degree
|
447 |
+
img = img.rotate(rotate_degree, Image.BILINEAR)
|
448 |
+
mask = mask.rotate(rotate_degree, Image.NEAREST)
|
449 |
+
|
450 |
+
return {'image': img,
|
451 |
+
'label': mask}
|
452 |
+
|
453 |
+
class RandomSized_new(object):
|
454 |
+
'''what we use is this class to aug'''
|
455 |
+
def __init__(self, size,scale1=0.5,scale2=2):
|
456 |
+
self.size = size
|
457 |
+
# self.scale = Scale(self.size)
|
458 |
+
self.crop = RandomCrop_new(self.size)
|
459 |
+
self.small_scale = scale1
|
460 |
+
self.big_scale = scale2
|
461 |
+
|
462 |
+
def __call__(self, sample):
|
463 |
+
img = sample['image']
|
464 |
+
mask = sample['label']
|
465 |
+
assert img.size == mask.size
|
466 |
+
|
467 |
+
w = int(random.uniform(self.small_scale, self.big_scale) * img.size[0])
|
468 |
+
h = int(random.uniform(self.small_scale, self.big_scale) * img.size[1])
|
469 |
+
|
470 |
+
img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST)
|
471 |
+
sample = {'image': img, 'label': mask}
|
472 |
+
# finish resize
|
473 |
+
return self.crop(sample)
|
474 |
+
# class Random
|
475 |
+
|
476 |
+
class RandomScale(object):
|
477 |
+
def __init__(self, limit):
|
478 |
+
self.limit = limit
|
479 |
+
|
480 |
+
def __call__(self, sample):
|
481 |
+
img = sample['image']
|
482 |
+
mask = sample['label']
|
483 |
+
assert img.size == mask.size
|
484 |
+
|
485 |
+
scale = random.uniform(self.limit[0], self.limit[1])
|
486 |
+
w = int(scale * img.size[0])
|
487 |
+
h = int(scale * img.size[1])
|
488 |
+
|
489 |
+
img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST)
|
490 |
+
|
491 |
+
return {'image': img, 'label': mask}
|
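
All of these transforms take and return a {'image': ..., 'label': ...} sample dict, so they chain with torchvision's transforms.Compose exactly as the inference script further down does. A minimal sketch of that single-image pipeline, assuming only that an RGB image is available (the scale factor and the file name 'person.jpg' are illustrative placeholders, not taken from the repo):

from PIL import Image
from torchvision import transforms
from dataloaders import custom_transforms as tr

# hypothetical input; the label slot may be a dummy value for the *_only_img transforms
sample = {'image': Image.open('person.jpg').convert('RGB'), 'label': 0}

pipeline = transforms.Compose([
    tr.Scale_only_img(0.75),               # resize the image only, label left untouched
    tr.Normalize_xception_tf_only_img(),   # map pixel values to [-1, 1]
    tr.ToTensor_only_img(),                # HWC numpy -> CHW float tensor, RGB -> BGR
])
out = pipeline(sample)                     # out['image'] is a CHW float tensor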
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/mypath_atr.py
ADDED
@@ -0,0 +1,8 @@
class Path(object):
    @staticmethod
    def db_root_dir(database):
        if database == 'atr':
            return './data/datasets/ATR/'  # folder that contains atr/.
        else:
            print('Database {} not available.'.format(database))
            raise NotImplementedError
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/mypath_cihp.py
ADDED
@@ -0,0 +1,8 @@
class Path(object):
    @staticmethod
    def db_root_dir(database):
        if database == 'cihp':
            return './data/datasets/CIHP_4w/'
        else:
            print('Database {} not available.'.format(database))
            raise NotImplementedError
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/mypath_pascal.py
ADDED
@@ -0,0 +1,8 @@
class Path(object):
    @staticmethod
    def db_root_dir(database):
        if database == 'pascal':
            return './data/datasets/pascal/'  # folder that contains pascal/.
        else:
            print('Database {} not available.'.format(database))
            raise NotImplementedError
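
The three mypath_* modules expose the same one-method Path helper; the dataset classes call it to resolve their data root, and any name other than the configured one prints a message and raises NotImplementedError. A quick illustration of that behaviour (the 'coco' key is a deliberately unconfigured example):

from dataloaders.mypath_pascal import Path

print(Path.db_root_dir('pascal'))   # -> './data/datasets/pascal/'
try:
    Path.db_root_dir('coco')        # not wired up in this helper
except NotImplementedError:
    print('unknown dataset name')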
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/pascal.py
ADDED
@@ -0,0 +1,106 @@
from __future__ import print_function, division
import os
from PIL import Image
from torch.utils.data import Dataset
from .mypath_pascal import Path

class VOCSegmentation(Dataset):
    """
    Pascal dataset
    """

    def __init__(self,
                 base_dir=Path.db_root_dir('pascal'),
                 split='train',
                 transform=None
                 ):
        """
        :param base_dir: path to PASCAL dataset directory
        :param split: train/val
        :param transform: transform to apply
        """
        super(VOCSegmentation).__init__()
        self._base_dir = base_dir
        self._image_dir = os.path.join(self._base_dir, 'JPEGImages')
        self._cat_dir = os.path.join(self._base_dir, 'SegmentationPart')

        if isinstance(split, str):
            self.split = [split]
        else:
            split.sort()
            self.split = split

        self.transform = transform

        _splits_dir = os.path.join(self._base_dir, 'list')

        self.im_ids = []
        self.images = []
        self.categories = []

        for splt in self.split:
            with open(os.path.join(os.path.join(_splits_dir, splt + '_id.txt')), "r") as f:
                lines = f.read().splitlines()

            for ii, line in enumerate(lines):

                _image = os.path.join(self._image_dir, line+'.jpg' )
                _cat = os.path.join(self._cat_dir, line +'.png')
                # print(self._image_dir,_image)
                assert os.path.isfile(_image)
                # print(_cat)
                assert os.path.isfile(_cat)
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)

        assert (len(self.images) == len(self.categories))

        # Display stats
        print('Number of images in {}: {:d}'.format(split, len(self.images)))

    def __len__(self):
        return len(self.images)


    def __getitem__(self, index):
        _img, _target= self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}

        if self.transform is not None:
            sample = self.transform(sample)

        return sample

    def _make_img_gt_point_pair(self, index):
        # Read Image and Target
        # _img = np.array(Image.open(self.images[index]).convert('RGB')).astype(np.float32)
        # _target = np.array(Image.open(self.categories[index])).astype(np.float32)

        _img = Image.open(self.images[index]).convert('RGB')  # return is RGB pic
        _target = Image.open(self.categories[index])

        return _img, _target

    def __str__(self):
        return 'PASCAL(split=' + str(self.split) + ')'

class test_segmentation(VOCSegmentation):
    def __init__(self,base_dir=Path.db_root_dir('pascal'),
                 split='train',
                 transform=None,
                 flip=True):
        super(test_segmentation, self).__init__(base_dir=base_dir,split=split,transform=transform)
        self._flip_flag = flip

    def __getitem__(self, index):
        _img, _target= self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}

        if self.transform is not None:
            sample = self.transform(sample)

        return sample
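
Because VOCSegmentation yields the same {'image', 'label'} sample dicts as the transforms above, a composed pipeline plugs straight into a DataLoader. A minimal sketch, assuming the pascal data root already contains JPEGImages/, SegmentationPart/ and list/val_id.txt as the constructor expects:

from torch.utils.data import DataLoader
from torchvision import transforms
from dataloaders import pascal
from dataloaders import custom_transforms as tr

val_transform = transforms.Compose([
    tr.Scale_(1.0),               # keep original resolution
    tr.Normalize_xception_tf(),   # pixels to [-1, 1]
    tr.ToTensor_(),               # image -> CHW tensor, label -> 1xHxW tensor
])
voc_val = pascal.VOCSegmentation(split='val', transform=val_transform)
loader = DataLoader(voc_val, batch_size=1, shuffle=False, num_workers=2)
batch = next(iter(loader))        # {'image': 1xCxHxW, 'label': 1x1xHxW}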
TryYours-Virtual-Try-On/Graphonomy-master/eval_cihp.sh
ADDED
@@ -0,0 +1,5 @@
python ./exp/test/eval_show_pascal2cihp.py \
--batch 1 --gpus 1 --classes 20 \
--gt_path './data/datasets/CIHP_4w/Category_ids' \
--txt_file './data/datasets/CIHP_4w/lists/val_id.txt' \
--loadmodel './data/pretrained_model/inference.pth'
TryYours-Virtual-Try-On/Graphonomy-master/eval_pascal.sh
ADDED
@@ -0,0 +1,5 @@
python ./exp/test/eval_show_pascal.py \
--batch 1 --gpus 1 --classes 7 \
--gt_path './data/datasets/pascal/SegmentationPart/' \
--txt_file './data/datasets/pascal/list/val_id.txt' \
--loadmodel './cihp2pascal.pth'
TryYours-Virtual-Try-On/Graphonomy-master/exp/inference/inference.py
ADDED
@@ -0,0 +1,206 @@
import socket
import timeit
import numpy as np
from PIL import Image
from datetime import datetime
import os
import sys
from collections import OrderedDict
sys.path.append('./')
# PyTorch includes
import torch
from torch.autograd import Variable
from torchvision import transforms
import cv2


# Custom includes
from networks import deeplab_xception_transfer, graph
from dataloaders import custom_transforms as tr

#
import argparse
import torch.nn.functional as F

import warnings
warnings.filterwarnings("ignore")

label_colours = [(0,0,0)
                , (128,0,0), (255,0,0), (0,85,0), (170,0,51), (255,85,0), (0,0,85), (0,119,221), (85,85,0), (0,85,85), (85,51,0), (52,86,128), (0,128,0)
                , (0,0,255), (51,170,221), (0,255,255), (85,255,170), (170,255,85), (255,255,0), (255,170,0)]


def flip(x, dim):
    indices = [slice(None)] * x.dim()
    indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
                                dtype=torch.long, device=x.device)
    return x[tuple(indices)]

def flip_cihp(tail_list):
    '''

    :param tail_list: tail_list size is 1 x n_class x h x w
    :return:
    '''
    # tail_list = tail_list[0]
    tail_list_rev = [None] * 20
    for xx in range(14):
        tail_list_rev[xx] = tail_list[xx].unsqueeze(0)
    tail_list_rev[14] = tail_list[15].unsqueeze(0)
    tail_list_rev[15] = tail_list[14].unsqueeze(0)
    tail_list_rev[16] = tail_list[17].unsqueeze(0)
    tail_list_rev[17] = tail_list[16].unsqueeze(0)
    tail_list_rev[18] = tail_list[19].unsqueeze(0)
    tail_list_rev[19] = tail_list[18].unsqueeze(0)
    return torch.cat(tail_list_rev,dim=0)


def decode_labels(mask, num_images=1, num_classes=20):
    """Decode batch of segmentation masks.

    Args:
      mask: result of inference after taking argmax.
      num_images: number of images to decode from the batch.
      num_classes: number of classes to predict (including background).

    Returns:
      A batch with num_images RGB images of the same size as the input.
    """
    n, h, w = mask.shape
    assert (n >= num_images), 'Batch size %d should be greater or equal than number of images to save %d.' % (
        n, num_images)
    outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8)
    for i in range(num_images):
        img = Image.new('RGB', (len(mask[i, 0]), len(mask[i])))
        pixels = img.load()
        for j_, j in enumerate(mask[i, :, :]):
            for k_, k in enumerate(j):
                if k < num_classes:
                    pixels[k_, j_] = label_colours[k]
        outputs[i] = np.array(img)
    return outputs

def read_img(img_path):
    _img = Image.open(img_path).convert('RGB')  # return is RGB pic
    return _img

def img_transform(img, transform=None):
    sample = {'image': img, 'label': 0}

    sample = transform(sample)
    return sample

def inference(net, img_path='', output_path='./', output_name='f', use_gpu=True):
    '''

    :param net:
    :param img_path:
    :param output_path:
    :return:
    '''
    # adj
    adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float()
    adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).cuda().transpose(2, 3)

    adj1_ = Variable(torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float())
    adj3_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7).cuda()

    cihp_adj = graph.preprocess_adj(graph.cihp_graph)
    adj3_ = Variable(torch.from_numpy(cihp_adj).float())
    adj1_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20).cuda()

    # multi-scale
    scale_list = [1, 0.5, 0.75, 1.25, 1.5, 1.75]
    img = read_img(img_path)
    testloader_list = []
    testloader_flip_list = []
    for pv in scale_list:
        composed_transforms_ts = transforms.Compose([
            tr.Scale_only_img(pv),
            tr.Normalize_xception_tf_only_img(),
            tr.ToTensor_only_img()])

        composed_transforms_ts_flip = transforms.Compose([
            tr.Scale_only_img(pv),
            tr.HorizontalFlip_only_img(),
            tr.Normalize_xception_tf_only_img(),
            tr.ToTensor_only_img()])

        testloader_list.append(img_transform(img, composed_transforms_ts))
        # print(img_transform(img, composed_transforms_ts))
        testloader_flip_list.append(img_transform(img, composed_transforms_ts_flip))
    # print(testloader_list)
    start_time = timeit.default_timer()
    # One testing epoch
    net.eval()
    # 1 0.5 0.75 1.25 1.5 1.75 ; flip:

    for iii, sample_batched in enumerate(zip(testloader_list, testloader_flip_list)):
        inputs, labels = sample_batched[0]['image'], sample_batched[0]['label']
        inputs_f, _ = sample_batched[1]['image'], sample_batched[1]['label']
        inputs = inputs.unsqueeze(0)
        inputs_f = inputs_f.unsqueeze(0)
        inputs = torch.cat((inputs, inputs_f), dim=0)
        if iii == 0:
            _, _, h, w = inputs.size()
        # assert inputs.size() == inputs_f.size()

        # Forward pass of the mini-batch
        inputs = Variable(inputs, requires_grad=False)

        with torch.no_grad():
            if use_gpu >= 0:
                inputs = inputs.cuda()
            # outputs = net.forward(inputs)
            outputs = net.forward(inputs, adj1_test.cuda(), adj3_test.cuda(), adj2_test.cuda())
            outputs = (outputs[0] + flip(flip_cihp(outputs[1]), dim=-1)) / 2
            outputs = outputs.unsqueeze(0)

            if iii > 0:
                outputs = F.upsample(outputs, size=(h, w), mode='bilinear', align_corners=True)
                outputs_final = outputs_final + outputs
            else:
                outputs_final = outputs.clone()
    ################ plot pic
    predictions = torch.max(outputs_final, 1)[1]
    results = predictions.cpu().numpy()
    vis_res = decode_labels(results)

    parsing_im = Image.fromarray(vis_res[0])
    parsing_im.save(output_path+'/{}.png'.format(output_name))
    cv2.imwrite(output_path+'/{}_gray.png'.format(output_name), results[0, :, :])

    end_time = timeit.default_timer()
    print('time used for the multi-scale image inference' + ' is :' + str(end_time - start_time))

if __name__ == '__main__':
    '''argparse begin'''
    parser = argparse.ArgumentParser()
    # parser.add_argument('--loadmodel',default=None,type=str)
    parser.add_argument('--loadmodel', default='', type=str)
    parser.add_argument('--img_path', default='', type=str)
    parser.add_argument('--output_path', default='', type=str)
    parser.add_argument('--output_name', default='', type=str)
    parser.add_argument('--use_gpu', default=1, type=int)
    opts = parser.parse_args()

    net = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(n_classes=20,
                                                                                 hidden_layers=128,
                                                                                 source_classes=7, )
    if not opts.loadmodel == '':
        x = torch.load(opts.loadmodel)
        net.load_source_model(x)
        print('load model:', opts.loadmodel)
    else:
        print('no model load !!!!!!!!')
        raise RuntimeError('No model!!!!')

    if opts.use_gpu >0 :
        net.cuda()
        use_gpu = True
    else:
        use_gpu = False
        raise RuntimeError('must use the gpu!!!!')

    inference(net=net, img_path=opts.img_path,output_path=opts.output_path , output_name=opts.output_name, use_gpu=use_gpu)
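
The heart of the test-time augmentation above is the line `outputs = (outputs[0] + flip(flip_cihp(outputs[1]), dim=-1)) / 2`: the second batch element was predicted on a mirrored input, so its left/right channel pairs (14/15, 16/17 and 18/19 in the 20-class CIHP label set) are swapped back by flip_cihp and the map is mirrored along the width axis before being averaged with the prediction on the original image. A minimal self-contained sketch of that fusion on dummy logits (shapes only, no real model, and the channel order is the one assumed by flip_cihp above):

import torch

def swap_lr_channels(logits):                  # same channel remapping as flip_cihp above
    order = list(range(14)) + [15, 14, 17, 16, 19, 18]
    return logits[order]

plain  = torch.randn(20, 64, 48)               # logits predicted on the original image
mirror = torch.randn(20, 64, 48)               # logits predicted on the horizontally flipped copy
fused = (plain + torch.flip(swap_lr_channels(mirror), dims=[2])) / 2   # mirror back along width
print(fused.shape)                             # torch.Size([20, 64, 48])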
TryYours-Virtual-Try-On/Graphonomy-master/exp/test/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .test_from_disk import eval_

__all__ = ['eval_']
TryYours-Virtual-Try-On/Graphonomy-master/exp/test/eval_show_cihp2pascal.py
ADDED
@@ -0,0 +1,268 @@
import socket
import timeit
import numpy as np
from PIL import Image
from datetime import datetime
import os
import sys
import glob
from collections import OrderedDict
sys.path.append('../../')
# PyTorch includes
import torch
import pdb
from torch.autograd import Variable
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import cv2

# Tensorboard include
# from tensorboardX import SummaryWriter

# Custom includes
from dataloaders import pascal
from utils import util
from networks import deeplab_xception_transfer, graph
from dataloaders import custom_transforms as tr

#
import argparse
import copy
import torch.nn.functional as F
from test_from_disk import eval_


gpu_id = 1

label_colours = [(0,0,0)
                # 0=background
                ,(128,0,0), (0,128,0), (128,128,0), (0,0,128), (128,0,128), (0,128,128)]


def flip(x, dim):
    indices = [slice(None)] * x.dim()
    indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
                                dtype=torch.long, device=x.device)
    return x[tuple(indices)]

# def flip_cihp(tail_list):
#     '''
#
#     :param tail_list: tail_list size is 1 x n_class x h x w
#     :return:
#     '''
#     # tail_list = tail_list[0]
#     tail_list_rev = [None] * 20
#     for xx in range(14):
#         tail_list_rev[xx] = tail_list[xx].unsqueeze(0)
#     tail_list_rev[14] = tail_list[15].unsqueeze(0)
#     tail_list_rev[15] = tail_list[14].unsqueeze(0)
#     tail_list_rev[16] = tail_list[17].unsqueeze(0)
#     tail_list_rev[17] = tail_list[16].unsqueeze(0)
#     tail_list_rev[18] = tail_list[19].unsqueeze(0)
#     tail_list_rev[19] = tail_list[18].unsqueeze(0)
#     return torch.cat(tail_list_rev,dim=0)

def decode_labels(mask, num_images=1, num_classes=20):
    """Decode batch of segmentation masks.

    Args:
      mask: result of inference after taking argmax.
      num_images: number of images to decode from the batch.
      num_classes: number of classes to predict (including background).

    Returns:
      A batch with num_images RGB images of the same size as the input.
    """
    n, h, w = mask.shape
    assert(n >= num_images), 'Batch size %d should be greater or equal than number of images to save %d.' % (n, num_images)
    outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8)
    for i in range(num_images):
        img = Image.new('RGB', (len(mask[i, 0]), len(mask[i])))
        pixels = img.load()
        for j_, j in enumerate(mask[i, :, :]):
            for k_, k in enumerate(j):
                if k < num_classes:
                    pixels[k_,j_] = label_colours[k]
        outputs[i] = np.array(img)
    return outputs

def get_parser():
    '''argparse begin'''
    parser = argparse.ArgumentParser()
    LookupChoices = type('', (argparse.Action,), dict(__call__=lambda a, p, n, v, o: setattr(n, a.dest, a.choices[v])))

    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--batch', default=16, type=int)
    parser.add_argument('--lr', default=1e-7, type=float)
    parser.add_argument('--numworker', default=12, type=int)
    parser.add_argument('--step', default=30, type=int)
    # parser.add_argument('--loadmodel',default=None,type=str)
    parser.add_argument('--classes', default=7, type=int)
    parser.add_argument('--testepoch', default=10, type=int)
    parser.add_argument('--loadmodel', default='', type=str)
    parser.add_argument('--txt_file', default='', type=str)
    parser.add_argument('--hidden_layers', default=128, type=int)
    parser.add_argument('--gpus', default=4, type=int)
    parser.add_argument('--output_path', default='./results/', type=str)
    parser.add_argument('--gt_path', default='./results/', type=str)
    opts = parser.parse_args()
    return opts


def main(opts):
    adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float()
    adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).cuda()

    adj1_ = Variable(torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float())
    adj1_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7).cuda()

    cihp_adj = graph.preprocess_adj(graph.cihp_graph)
    adj3_ = Variable(torch.from_numpy(cihp_adj).float())
    adj3_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20).cuda()

    p = OrderedDict()  # Parameters to include in report
    p['trainBatch'] = opts.batch  # Training batch size
    p['nAveGrad'] = 1  # Average the gradient of several iterations
    p['lr'] = opts.lr  # Learning rate
    p['lrFtr'] = 1e-5
    p['lraspp'] = 1e-5
    p['lrpro'] = 1e-5
    p['lrdecoder'] = 1e-5
    p['lrother'] = 1e-5
    p['wd'] = 5e-4  # Weight decay
    p['momentum'] = 0.9  # Momentum
    p['epoch_size'] = 10  # How many epochs to change learning rate
    p['num_workers'] = opts.numworker
    backbone = 'xception'  # Use xception or resnet as feature extractor,

    with open(opts.txt_file, 'r') as f:
        img_list = f.readlines()

    max_id = 0
    save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
    exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]
    runs = glob.glob(os.path.join(save_dir_root, 'run', 'run_*'))
    for r in runs:
        run_id = int(r.split('_')[-1])
        if run_id >= max_id:
            max_id = run_id + 1
    # run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0

    # Network definition
    if backbone == 'xception':
        net = deeplab_xception_transfer.deeplab_xception_transfer_projection(n_classes=opts.classes, os=16,
                                                                             hidden_layers=opts.hidden_layers, source_classes=20,
                                                                             )
    elif backbone == 'resnet':
        # net = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True)
        raise NotImplementedError
    else:
        raise NotImplementedError

    if gpu_id >= 0:
        net.cuda()

    # net load weights
    if not opts.loadmodel =='':
        x = torch.load(opts.loadmodel)
        net.load_source_model(x)
        print('load model:' ,opts.loadmodel)
    else:
        print('no model load !!!!!!!!')

    ## multi scale
    scale_list=[1,0.5,0.75,1.25,1.5,1.75]
    testloader_list = []
    testloader_flip_list = []
    for pv in scale_list:
        composed_transforms_ts = transforms.Compose([
            tr.Scale_(pv),
            tr.Normalize_xception_tf(),
            tr.ToTensor_()])

        composed_transforms_ts_flip = transforms.Compose([
            tr.Scale_(pv),
            tr.HorizontalFlip(),
            tr.Normalize_xception_tf(),
            tr.ToTensor_()])

        voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts)
        voc_val_f = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts_flip)

        testloader = DataLoader(voc_val, batch_size=1, shuffle=False, num_workers=p['num_workers'])
        testloader_flip = DataLoader(voc_val_f, batch_size=1, shuffle=False, num_workers=p['num_workers'])

        testloader_list.append(copy.deepcopy(testloader))
        testloader_flip_list.append(copy.deepcopy(testloader_flip))

    print("Eval Network")

    if not os.path.exists(opts.output_path + 'pascal_output_vis/'):
        os.makedirs(opts.output_path + 'pascal_output_vis/')
    if not os.path.exists(opts.output_path + 'pascal_output/'):
        os.makedirs(opts.output_path + 'pascal_output/')

    start_time = timeit.default_timer()
    # One testing epoch
    total_iou = 0.0
    net.eval()
    for ii, large_sample_batched in enumerate(zip(*testloader_list, *testloader_flip_list)):
        print(ii)
        #1 0.5 0.75 1.25 1.5 1.75 ; flip:
        sample1 = large_sample_batched[:6]
        sample2 = large_sample_batched[6:]
        for iii,sample_batched in enumerate(zip(sample1,sample2)):
            inputs, labels = sample_batched[0]['image'], sample_batched[0]['label']
            inputs_f, _ = sample_batched[1]['image'], sample_batched[1]['label']
            inputs = torch.cat((inputs,inputs_f),dim=0)
            if iii == 0:
                _,_,h,w = inputs.size()
            # assert inputs.size() == inputs_f.size()

            # Forward pass of the mini-batch
            inputs, labels = Variable(inputs, requires_grad=False), Variable(labels)

            with torch.no_grad():
                if gpu_id >= 0:
                    inputs, labels = inputs.cuda(), labels.cuda()
                # outputs = net.forward(inputs)
                # pdb.set_trace()
                outputs = net.forward(inputs, adj1_test.cuda(), adj3_test.cuda(), adj2_test.cuda())
                outputs = (outputs[0] + flip(outputs[1], dim=-1)) / 2
                outputs = outputs.unsqueeze(0)

            if iii>0:
                outputs = F.upsample(outputs,size=(h,w),mode='bilinear',align_corners=True)
                outputs_final = outputs_final + outputs
            else:
                outputs_final = outputs.clone()
        ################ plot pic
        predictions = torch.max(outputs_final, 1)[1]
        prob_predictions = torch.max(outputs_final,1)[0]
        results = predictions.cpu().numpy()
        prob_results = prob_predictions.cpu().numpy()
        vis_res = decode_labels(results)

        parsing_im = Image.fromarray(vis_res[0])
        parsing_im.save(opts.output_path + 'pascal_output_vis/{}.png'.format(img_list[ii][:-1]))
        cv2.imwrite(opts.output_path + 'pascal_output/{}.png'.format(img_list[ii][:-1]), results[0,:,:])
        # np.save('../../cihp_prob_output/{}.npy'.format(img_list[ii][:-1]), prob_results[0, :, :])
        # pred_list.append(predictions.cpu())
        # label_list.append(labels.squeeze(1).cpu())
        # loss = criterion(outputs, labels, batch_average=True)
        # running_loss_ts += loss.item()

        # total_iou += utils.get_iou(predictions, labels)
    end_time = timeit.default_timer()
    print('time use for '+str(ii) + ' is :' + str(end_time - start_time))

    # Eval
    pred_path = opts.output_path + 'pascal_output/'
    eval_(pred_path=pred_path, gt_path=opts.gt_path,classes=opts.classes, txt_file=opts.txt_file)

if __name__ == '__main__':
    opts = get_parser()
    main(opts)
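
For reference, the evaluation loop above fuses six scales: it zips one DataLoader per scale with its horizontally flipped counterpart, upsamples every prediction back to the original resolution and sums them before the final argmax. A shape-level sketch of that accumulation with dummy tensors standing in for the network output (the resolution and the use of F.interpolate instead of the deprecated F.upsample alias are illustrative choices, not taken from the repo):

import torch
import torch.nn.functional as F

h, w = 512, 384                                   # original input resolution (placeholder)
scales = [1, 0.5, 0.75, 1.25, 1.5, 1.75]
outputs_final = None
for s in scales:
    logits = torch.randn(1, 7, int(h * s), int(w * s))        # stand-in for net output at this scale
    logits = F.interpolate(logits, size=(h, w), mode='bilinear', align_corners=True)
    outputs_final = logits if outputs_final is None else outputs_final + logits
prediction = outputs_final.argmax(dim=1)          # 1 x H x W label map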
TryYours-Virtual-Try-On/Graphonomy-master/exp/test/eval_show_pascal2cihp.py
ADDED
@@ -0,0 +1,268 @@
import socket
import timeit
import numpy as np
from PIL import Image
from datetime import datetime
import os
import sys
import glob
from collections import OrderedDict
sys.path.append('../../')
# PyTorch includes
import torch
import pdb
from torch.autograd import Variable
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import cv2

# Tensorboard include
# from tensorboardX import SummaryWriter

# Custom includes
from dataloaders import cihp
from utils import util
from networks import deeplab_xception_transfer, graph
from dataloaders import custom_transforms as tr

#
import argparse
import copy
import torch.nn.functional as F
from test_from_disk import eval_


gpu_id = 1

label_colours = [(0,0,0)
                , (128,0,0), (255,0,0), (0,85,0), (170,0,51), (255,85,0), (0,0,85), (0,119,221), (85,85,0), (0,85,85), (85,51,0), (52,86,128), (0,128,0)
                , (0,0,255), (51,170,221), (0,255,255), (85,255,170), (170,255,85), (255,255,0), (255,170,0)]


def flip(x, dim):
    indices = [slice(None)] * x.dim()
    indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
                                dtype=torch.long, device=x.device)
    return x[tuple(indices)]

def flip_cihp(tail_list):
    '''

    :param tail_list: tail_list size is 1 x n_class x h x w
    :return:
    '''
    # tail_list = tail_list[0]
    tail_list_rev = [None] * 20
    for xx in range(14):
        tail_list_rev[xx] = tail_list[xx].unsqueeze(0)
    tail_list_rev[14] = tail_list[15].unsqueeze(0)
    tail_list_rev[15] = tail_list[14].unsqueeze(0)
    tail_list_rev[16] = tail_list[17].unsqueeze(0)
    tail_list_rev[17] = tail_list[16].unsqueeze(0)
    tail_list_rev[18] = tail_list[19].unsqueeze(0)
    tail_list_rev[19] = tail_list[18].unsqueeze(0)
    return torch.cat(tail_list_rev,dim=0)

def decode_labels(mask, num_images=1, num_classes=20):
    """Decode batch of segmentation masks.

    Args:
      mask: result of inference after taking argmax.
      num_images: number of images to decode from the batch.
      num_classes: number of classes to predict (including background).

    Returns:
      A batch with num_images RGB images of the same size as the input.
    """
    n, h, w = mask.shape
    assert(n >= num_images), 'Batch size %d should be greater or equal than number of images to save %d.' % (n, num_images)
    outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8)
    for i in range(num_images):
        img = Image.new('RGB', (len(mask[i, 0]), len(mask[i])))
        pixels = img.load()
        for j_, j in enumerate(mask[i, :, :]):
            for k_, k in enumerate(j):
                if k < num_classes:
                    pixels[k_,j_] = label_colours[k]
        outputs[i] = np.array(img)
    return outputs

def get_parser():
    '''argparse begin'''
    parser = argparse.ArgumentParser()
    LookupChoices = type('', (argparse.Action,), dict(__call__=lambda a, p, n, v, o: setattr(n, a.dest, a.choices[v])))

    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--batch', default=16, type=int)
    parser.add_argument('--lr', default=1e-7, type=float)
    parser.add_argument('--numworker', default=12, type=int)
    parser.add_argument('--step', default=30, type=int)
    # parser.add_argument('--loadmodel',default=None,type=str)
    parser.add_argument('--classes', default=7, type=int)
    parser.add_argument('--testepoch', default=10, type=int)
    parser.add_argument('--loadmodel', default='', type=str)
    parser.add_argument('--txt_file', default='', type=str)
    parser.add_argument('--hidden_layers', default=128, type=int)
    parser.add_argument('--gpus', default=4, type=int)
    parser.add_argument('--output_path', default='./results/', type=str)
    parser.add_argument('--gt_path', default='./results/', type=str)
    opts = parser.parse_args()
    return opts


def main(opts):
    adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float()
    adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).cuda().transpose(2, 3)

    adj1_ = Variable(torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float())
    adj3_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7).cuda()

    cihp_adj = graph.preprocess_adj(graph.cihp_graph)
    adj3_ = Variable(torch.from_numpy(cihp_adj).float())
    adj1_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20).cuda()

    p = OrderedDict()  # Parameters to include in report
    p['trainBatch'] = opts.batch  # Training batch size
    p['nAveGrad'] = 1  # Average the gradient of several iterations
    p['lr'] = opts.lr  # Learning rate
    p['lrFtr'] = 1e-5
    p['lraspp'] = 1e-5
    p['lrpro'] = 1e-5
    p['lrdecoder'] = 1e-5
    p['lrother'] = 1e-5
    p['wd'] = 5e-4  # Weight decay
    p['momentum'] = 0.9  # Momentum
    p['epoch_size'] = 10  # How many epochs to change learning rate
    p['num_workers'] = opts.numworker
    backbone = 'xception'  # Use xception or resnet as feature extractor,

    with open(opts.txt_file, 'r') as f:
        img_list = f.readlines()

    max_id = 0
    save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
    exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]
    runs = glob.glob(os.path.join(save_dir_root, 'run', 'run_*'))
    for r in runs:
        run_id = int(r.split('_')[-1])
        if run_id >= max_id:
            max_id = run_id + 1
    # run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0

    # Network definition
    if backbone == 'xception':
        net = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(n_classes=opts.classes, os=16,
                                                                                     hidden_layers=opts.hidden_layers, source_classes=7,
                                                                                     )
    elif backbone == 'resnet':
        # net = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True)
        raise NotImplementedError
    else:
        raise NotImplementedError

    if gpu_id >= 0:
        net.cuda()

    # net load weights
    if not opts.loadmodel =='':
        x = torch.load(opts.loadmodel)
        net.load_source_model(x)
        print('load model:' ,opts.loadmodel)
    else:
        print('no model load !!!!!!!!')

    ## multi scale
    scale_list=[1,0.5,0.75,1.25,1.5,1.75]
    testloader_list = []
    testloader_flip_list = []
    for pv in scale_list:
        composed_transforms_ts = transforms.Compose([
            tr.Scale_(pv),
            tr.Normalize_xception_tf(),
            tr.ToTensor_()])

        composed_transforms_ts_flip = transforms.Compose([
            tr.Scale_(pv),
            tr.HorizontalFlip(),
            tr.Normalize_xception_tf(),
            tr.ToTensor_()])

        voc_val = cihp.VOCSegmentation(split='test', transform=composed_transforms_ts)
        voc_val_f = cihp.VOCSegmentation(split='test', transform=composed_transforms_ts_flip)

        testloader = DataLoader(voc_val, batch_size=1, shuffle=False, num_workers=p['num_workers'])
        testloader_flip = DataLoader(voc_val_f, batch_size=1, shuffle=False, num_workers=p['num_workers'])
|
197 |
+
|
198 |
+
testloader_list.append(copy.deepcopy(testloader))
|
199 |
+
testloader_flip_list.append(copy.deepcopy(testloader_flip))
|
200 |
+
|
201 |
+
print("Eval Network")
|
202 |
+
|
203 |
+
if not os.path.exists(opts.output_path + 'cihp_output_vis/'):
|
204 |
+
os.makedirs(opts.output_path + 'cihp_output_vis/')
|
205 |
+
if not os.path.exists(opts.output_path + 'cihp_output/'):
|
206 |
+
os.makedirs(opts.output_path + 'cihp_output/')
|
207 |
+
|
208 |
+
start_time = timeit.default_timer()
|
209 |
+
# One testing epoch
|
210 |
+
total_iou = 0.0
|
211 |
+
net.eval()
|
212 |
+
for ii, large_sample_batched in enumerate(zip(*testloader_list, *testloader_flip_list)):
|
213 |
+
print(ii)
|
214 |
+
#1 0.5 0.75 1.25 1.5 1.75 ; flip:
|
215 |
+
sample1 = large_sample_batched[:6]
|
216 |
+
sample2 = large_sample_batched[6:]
|
217 |
+
for iii,sample_batched in enumerate(zip(sample1,sample2)):
|
218 |
+
inputs, labels = sample_batched[0]['image'], sample_batched[0]['label']
|
219 |
+
inputs_f, _ = sample_batched[1]['image'], sample_batched[1]['label']
|
220 |
+
inputs = torch.cat((inputs,inputs_f),dim=0)
|
221 |
+
if iii == 0:
|
222 |
+
_,_,h,w = inputs.size()
|
223 |
+
# assert inputs.size() == inputs_f.size()
|
224 |
+
|
225 |
+
# Forward pass of the mini-batch
|
226 |
+
inputs, labels = Variable(inputs, requires_grad=False), Variable(labels)
|
227 |
+
|
228 |
+
with torch.no_grad():
|
229 |
+
if gpu_id >= 0:
|
230 |
+
inputs, labels = inputs.cuda(), labels.cuda()
|
231 |
+
# outputs = net.forward(inputs)
|
232 |
+
# pdb.set_trace()
|
233 |
+
outputs = net.forward(inputs, adj1_test.cuda(), adj3_test.cuda(), adj2_test.cuda())
|
234 |
+
outputs = (outputs[0] + flip(flip_cihp(outputs[1]), dim=-1)) / 2
|
235 |
+
outputs = outputs.unsqueeze(0)
|
236 |
+
|
237 |
+
if iii>0:
|
238 |
+
outputs = F.upsample(outputs,size=(h,w),mode='bilinear',align_corners=True)
|
239 |
+
outputs_final = outputs_final + outputs
|
240 |
+
else:
|
241 |
+
outputs_final = outputs.clone()
|
242 |
+
################ plot pic
|
243 |
+
predictions = torch.max(outputs_final, 1)[1]
|
244 |
+
prob_predictions = torch.max(outputs_final,1)[0]
|
245 |
+
results = predictions.cpu().numpy()
|
246 |
+
prob_results = prob_predictions.cpu().numpy()
|
247 |
+
vis_res = decode_labels(results)
|
248 |
+
|
249 |
+
parsing_im = Image.fromarray(vis_res[0])
|
250 |
+
parsing_im.save(opts.output_path + 'cihp_output_vis/{}.png'.format(img_list[ii][:-1]))
|
251 |
+
cv2.imwrite(opts.output_path + 'cihp_output/{}.png'.format(img_list[ii][:-1]), results[0,:,:])
|
252 |
+
# np.save('../../cihp_prob_output/{}.npy'.format(img_list[ii][:-1]), prob_results[0, :, :])
|
253 |
+
# pred_list.append(predictions.cpu())
|
254 |
+
# label_list.append(labels.squeeze(1).cpu())
|
255 |
+
# loss = criterion(outputs, labels, batch_average=True)
|
256 |
+
# running_loss_ts += loss.item()
|
257 |
+
|
258 |
+
# total_iou += utils.get_iou(predictions, labels)
|
259 |
+
end_time = timeit.default_timer()
|
260 |
+
print('time use for '+str(ii) + ' is :' + str(end_time - start_time))
|
261 |
+
|
262 |
+
# Eval
|
263 |
+
pred_path = opts.output_path + 'cihp_output/'
|
264 |
+
eval_(pred_path=pred_path, gt_path=opts.gt_path,classes=opts.classes, txt_file=opts.txt_file)
|
265 |
+
|
266 |
+
if __name__ == '__main__':
|
267 |
+
opts = get_parser()
|
268 |
+
main(opts)
|
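Note: the testing loop above averages predictions over six scales and a horizontally flipped copy of each scale. The following is a minimal sketch of that fusion step on its own, assuming the per-scale logits have already been computed; the function name and shapes are illustrative, and the left/right channel swap done by flip_cihp is deliberately omitted here for brevity.

import torch
import torch.nn.functional as F

def fuse_multiscale_flip(logit_pairs, out_size):
    # logit_pairs: list of (logits, logits_flipped) pairs, one per scale,
    # each of shape [n_class, h_s, w_s]; spatial sizes may differ per scale.
    fused = None
    for logits, logits_flipped in logit_pairs:
        # undo the horizontal flip on the flipped prediction, then average
        avg = (logits + torch.flip(logits_flipped, dims=[-1])) / 2
        # resize every scale to the reference resolution before summing
        avg = F.interpolate(avg.unsqueeze(0), size=out_size,
                            mode='bilinear', align_corners=True)
        fused = avg if fused is None else fused + avg
    return fused  # [1, n_class, H, W]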
TryYours-Virtual-Try-On/Graphonomy-master/exp/test/test_from_disk.py
ADDED
@@ -0,0 +1,65 @@
import sys
sys.path.append('./')
# PyTorch includes
import torch
import numpy as np

from utils import test_human
from PIL import Image

#
import argparse

def get_parser():
    '''argparse begin'''
    parser = argparse.ArgumentParser()
    LookupChoices = type('', (argparse.Action,), dict(__call__=lambda a, p, n, v, o: setattr(n, a.dest, a.choices[v])))

    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--batch', default=16, type=int)
    parser.add_argument('--lr', default=1e-7, type=float)
    parser.add_argument('--numworker', default=12, type=int)
    parser.add_argument('--freezeBN', choices=dict(true=True, false=False), default=True, action=LookupChoices)
    parser.add_argument('--step', default=30, type=int)
    parser.add_argument('--txt_file', default=None, type=str)
    parser.add_argument('--pred_path', default=None, type=str)
    parser.add_argument('--gt_path', default=None, type=str)
    parser.add_argument('--classes', default=7, type=int)
    parser.add_argument('--testepoch', default=10, type=int)
    opts = parser.parse_args()
    return opts

def eval_(pred_path, gt_path, classes, txt_file):
    pred_path = pred_path
    gt_path = gt_path

    with open(txt_file,) as f:
        lines = f.readlines()
        lines = [x.strip() for x in lines]

    output_list = []
    label_list = []
    for i, file in enumerate(lines):
        print(i)
        file_name = file + '.png'
        try:
            predict_pic = np.array(Image.open(pred_path + file_name))
            gt_pic = np.array(Image.open(gt_path + file_name))
            output_list.append(torch.from_numpy(predict_pic))
            label_list.append(torch.from_numpy(gt_pic))
        except:
            print(file_name, flush=True)
            raise RuntimeError('no predict/gt image.')
            # gt_pic = np.array(Image.open(gt_path + file_name))
            # output_list.append(torch.from_numpy(gt_pic))
            # label_list.append(torch.from_numpy(gt_pic))


    miou = test_human.get_iou_from_list(output_list, label_list, n_cls=classes)

    print('Validation:')
    print('MIoU: %f\n' % miou)

if __name__ == '__main__':
    opts = get_parser()
    eval_(pred_path=opts.pred_path, gt_path=opts.gt_path, classes=opts.classes, txt_file=opts.txt_file)
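The actual metric comes from utils.test_human.get_iou_from_list, which is not part of this diff. For reference only, a self-contained sketch of how a mean IoU over the collected prediction/label maps could be computed; this is an illustration under that assumption, not the repo's implementation.

import numpy as np

def miou_from_pairs(pred_list, gt_list, n_cls):
    # Accumulate per-class intersection and union over all image pairs,
    # then average the per-class IoU.
    inter = np.zeros(n_cls, dtype=np.float64)
    union = np.zeros(n_cls, dtype=np.float64)
    for pred, gt in zip(pred_list, gt_list):
        pred = np.asarray(pred)
        gt = np.asarray(gt)
        for c in range(n_cls):
            p, g = pred == c, gt == c
            inter[c] += np.logical_and(p, g).sum()
            union[c] += np.logical_or(p, g).sum()
    ious = inter / np.maximum(union, 1)
    return ious.mean()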
TryYours-Virtual-Try-On/Graphonomy-master/exp/transfer/train_cihp_from_pascal.py
ADDED
@@ -0,0 +1,331 @@
import socket
import timeit
from datetime import datetime
import os
import sys
import glob
import numpy as np
from collections import OrderedDict
sys.path.append('../../')
sys.path.append('../../networks/')
# PyTorch includes
import torch
from torch.autograd import Variable
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid


# Tensorboard include
from tensorboardX import SummaryWriter

# Custom includes
from dataloaders import cihp
from utils import util, get_iou_from_list
from networks import deeplab_xception_transfer, graph
from dataloaders import custom_transforms as tr

#
import argparse

gpu_id = 0

nEpochs = 100  # Number of epochs for training
resume_epoch = 0  # Default is 0, change if want to resume

def flip(x, dim):
    indices = [slice(None)] * x.dim()
    indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
                                dtype=torch.long, device=x.device)
    return x[tuple(indices)]

def flip_cihp(tail_list):
    '''

    :param tail_list: tail_list size is 1 x n_class x h x w
    :return:
    '''
    # tail_list = tail_list[0]
    tail_list_rev = [None] * 20
    for xx in range(14):
        tail_list_rev[xx] = tail_list[xx].unsqueeze(0)
    tail_list_rev[14] = tail_list[15].unsqueeze(0)
    tail_list_rev[15] = tail_list[14].unsqueeze(0)
    tail_list_rev[16] = tail_list[17].unsqueeze(0)
    tail_list_rev[17] = tail_list[16].unsqueeze(0)
    tail_list_rev[18] = tail_list[19].unsqueeze(0)
    tail_list_rev[19] = tail_list[18].unsqueeze(0)
    return torch.cat(tail_list_rev, dim=0)

def get_parser():
    '''argparse begin'''
    parser = argparse.ArgumentParser()
    LookupChoices = type('', (argparse.Action,), dict(__call__=lambda a, p, n, v, o: setattr(n, a.dest, a.choices[v])))

    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--batch', default=16, type=int)
    parser.add_argument('--lr', default=1e-7, type=float)
    parser.add_argument('--numworker', default=12, type=int)
    parser.add_argument('--freezeBN', choices=dict(true=True, false=False), default=True, action=LookupChoices)
    parser.add_argument('--step', default=10, type=int)
    parser.add_argument('--classes', default=20, type=int)
    parser.add_argument('--testInterval', default=10, type=int)
    parser.add_argument('--loadmodel', default='', type=str)
    parser.add_argument('--pretrainedModel', default='', type=str)
    parser.add_argument('--hidden_layers', default=128, type=int)
    parser.add_argument('--gpus', default=4, type=int)

    opts = parser.parse_args()
    return opts

def get_graphs(opts):
    adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float()
    adj2 = adj2_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 7, 20).transpose(2, 3).cuda()
    adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).transpose(2, 3)

    adj1_ = Variable(torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float())
    adj3 = adj1_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 7, 7).cuda()
    adj3_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7)

    # adj2 = torch.from_numpy(graph.cihp2pascal_adj).float()
    # adj2 = adj2.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 7, 20)
    cihp_adj = graph.preprocess_adj(graph.cihp_graph)
    adj3_ = Variable(torch.from_numpy(cihp_adj).float())
    adj1 = adj3_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 20, 20).cuda()
    adj1_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20)
    train_graph = [adj1, adj2, adj3]
    test_graph = [adj1_test, adj2_test, adj3_test]
    return train_graph, test_graph


def val_cihp(net_, testloader, testloader_flip, test_graph, epoch, writer, criterion, classes=20):
    adj1_test, adj2_test, adj3_test = test_graph
    num_img_ts = len(testloader)
    net_.eval()
    pred_list = []
    label_list = []
    running_loss_ts = 0.0
    miou = 0
    for ii, sample_batched in enumerate(zip(testloader, testloader_flip)):

        inputs, labels = sample_batched[0]['image'], sample_batched[0]['label']
        inputs_f, _ = sample_batched[1]['image'], sample_batched[1]['label']
        inputs = torch.cat((inputs, inputs_f), dim=0)
        # Forward pass of the mini-batch
        inputs, labels = Variable(inputs, requires_grad=False), Variable(labels)
        if gpu_id >= 0:
            inputs, labels = inputs.cuda(), labels.cuda()

        with torch.no_grad():
            outputs = net_.forward(inputs, adj1_test.cuda(), adj3_test.cuda(), adj2_test.cuda())
            # pdb.set_trace()
            outputs = (outputs[0] + flip(flip_cihp(outputs[1]), dim=-1)) / 2
            outputs = outputs.unsqueeze(0)
            predictions = torch.max(outputs, 1)[1]
            pred_list.append(predictions.cpu())
            label_list.append(labels.squeeze(1).cpu())
            loss = criterion(outputs, labels, batch_average=True)
            running_loss_ts += loss.item()
            # total_iou += utils.get_iou(predictions, labels)
            # Print stuff
            if ii % num_img_ts == num_img_ts - 1:
                # if ii == 10:
                miou = get_iou_from_list(pred_list, label_list, n_cls=classes)
                running_loss_ts = running_loss_ts / num_img_ts

                print('Validation:')
                print('[Epoch: %d, numImages: %5d]' % (epoch, ii * 1 + inputs.data.shape[0]))
                writer.add_scalar('data/test_loss_epoch', running_loss_ts, epoch)
                writer.add_scalar('data/test_miour', miou, epoch)
                print('Loss: %f' % running_loss_ts)
                print('MIoU: %f\n' % miou)


def main(opts):
    p = OrderedDict()  # Parameters to include in report
    p['trainBatch'] = opts.batch  # Training batch size
    testBatch = 1  # Testing batch size
    useTest = True  # See evolution of the test set when training
    nTestInterval = opts.testInterval  # Run on test set every nTestInterval epochs
    snapshot = 1  # Store a model every snapshot epochs
    p['nAveGrad'] = 1  # Average the gradient of several iterations
    p['lr'] = opts.lr  # Learning rate
    p['lrFtr'] = 1e-5
    p['lraspp'] = 1e-5
    p['lrpro'] = 1e-5
    p['lrdecoder'] = 1e-5
    p['lrother'] = 1e-5
    p['wd'] = 5e-4  # Weight decay
    p['momentum'] = 0.9  # Momentum
    p['epoch_size'] = opts.step  # How many epochs to change learning rate
    p['num_workers'] = opts.numworker
    model_path = opts.pretrainedModel
    backbone = 'xception'  # Use xception or resnet as feature extractor,
    nEpochs = opts.epochs

    max_id = 0
    save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
    exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]
    runs = glob.glob(os.path.join(save_dir_root, 'run_cihp', 'run_*'))
    for r in runs:
        run_id = int(r.split('_')[-1])
        if run_id >= max_id:
            max_id = run_id + 1
    save_dir = os.path.join(save_dir_root, 'run_cihp', 'run_' + str(max_id))

    # Network definition
    if backbone == 'xception':
        net_ = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(n_classes=opts.classes, os=16,
                                                                                      hidden_layers=opts.hidden_layers, source_classes=7, )
    elif backbone == 'resnet':
        # net_ = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True)
        raise NotImplementedError
    else:
        raise NotImplementedError

    modelName = 'deeplabv3plus-' + backbone + '-voc' + datetime.now().strftime('%b%d_%H-%M-%S')
    criterion = util.cross_entropy2d

    if gpu_id >= 0:
        # torch.cuda.set_device(device=gpu_id)
        net_.cuda()

    # net load weights
    if not model_path == '':
        x = torch.load(model_path)
        net_.load_state_dict_new(x)
        print('load pretrainedModel:', model_path)
    else:
        print('no pretrainedModel.')
    if not opts.loadmodel == '':
        x = torch.load(opts.loadmodel)
        net_.load_source_model(x)
        print('load model:', opts.loadmodel)
    else:
        print('no model load !!!!!!!!')

    log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    writer = SummaryWriter(log_dir=log_dir)
    writer.add_text('load model', opts.loadmodel, 1)
    writer.add_text('setting', sys.argv[0], 1)

    if opts.freezeBN:
        net_.freeze_bn()

    # Use the following optimizer
    optimizer = optim.SGD(net_.parameters(), lr=p['lr'], momentum=p['momentum'], weight_decay=p['wd'])

    composed_transforms_tr = transforms.Compose([
        tr.RandomSized_new(512),
        tr.Normalize_xception_tf(),
        tr.ToTensor_()])

    composed_transforms_ts = transforms.Compose([
        tr.Normalize_xception_tf(),
        tr.ToTensor_()])

    composed_transforms_ts_flip = transforms.Compose([
        tr.HorizontalFlip(),
        tr.Normalize_xception_tf(),
        tr.ToTensor_()])

    voc_train = cihp.VOCSegmentation(split='train', transform=composed_transforms_tr, flip=True)
    voc_val = cihp.VOCSegmentation(split='val', transform=composed_transforms_ts)
    voc_val_flip = cihp.VOCSegmentation(split='val', transform=composed_transforms_ts_flip)

    trainloader = DataLoader(voc_train, batch_size=p['trainBatch'], shuffle=True, num_workers=p['num_workers'], drop_last=True)
    testloader = DataLoader(voc_val, batch_size=testBatch, shuffle=False, num_workers=p['num_workers'])
    testloader_flip = DataLoader(voc_val_flip, batch_size=testBatch, shuffle=False, num_workers=p['num_workers'])

    num_img_tr = len(trainloader)
    num_img_ts = len(testloader)
    running_loss_tr = 0.0
    running_loss_ts = 0.0
    aveGrad = 0
    global_step = 0
    print("Training Network")

    net = torch.nn.DataParallel(net_)
    train_graph, test_graph = get_graphs(opts)
    adj1, adj2, adj3 = train_graph


    # Main Training and Testing Loop
    for epoch in range(resume_epoch, nEpochs):
        start_time = timeit.default_timer()

        if epoch % p['epoch_size'] == p['epoch_size'] - 1:
            lr_ = util.lr_poly(p['lr'], epoch, nEpochs, 0.9)
            optimizer = optim.SGD(net_.parameters(), lr=lr_, momentum=p['momentum'], weight_decay=p['wd'])
            writer.add_scalar('data/lr_', lr_, epoch)
            print('(poly lr policy) learning rate: ', lr_)

        net.train()
        for ii, sample_batched in enumerate(trainloader):

            inputs, labels = sample_batched['image'], sample_batched['label']
            # Forward-Backward of the mini-batch
            inputs, labels = Variable(inputs, requires_grad=True), Variable(labels)
            global_step += inputs.data.shape[0]

            if gpu_id >= 0:
                inputs, labels = inputs.cuda(), labels.cuda()

            outputs = net.forward(inputs, adj1, adj3, adj2)

            loss = criterion(outputs, labels, batch_average=True)
            running_loss_tr += loss.item()

            # Print stuff
            if ii % num_img_tr == (num_img_tr - 1):
                running_loss_tr = running_loss_tr / num_img_tr
                writer.add_scalar('data/total_loss_epoch', running_loss_tr, epoch)
                print('[Epoch: %d, numImages: %5d]' % (epoch, ii * p['trainBatch'] + inputs.data.shape[0]))
                print('Loss: %f' % running_loss_tr)
                running_loss_tr = 0
                stop_time = timeit.default_timer()
                print("Execution time: " + str(stop_time - start_time) + "\n")

            # Backward the averaged gradient
            loss /= p['nAveGrad']
            loss.backward()
            aveGrad += 1

            # Update the weights once in p['nAveGrad'] forward passes
            if aveGrad % p['nAveGrad'] == 0:
                writer.add_scalar('data/total_loss_iter', loss.item(), ii + num_img_tr * epoch)
                optimizer.step()
                optimizer.zero_grad()
                aveGrad = 0

            # Show 10 * 3 images results each epoch
            if ii % (num_img_tr // 10) == 0:
                grid_image = make_grid(inputs[:3].clone().cpu().data, 3, normalize=True)
                writer.add_image('Image', grid_image, global_step)
                grid_image = make_grid(util.decode_seg_map_sequence(torch.max(outputs[:3], 1)[1].detach().cpu().numpy()), 3, normalize=False,
                                       range=(0, 255))
                writer.add_image('Predicted label', grid_image, global_step)
                grid_image = make_grid(util.decode_seg_map_sequence(torch.squeeze(labels[:3], 1).detach().cpu().numpy()), 3, normalize=False, range=(0, 255))
                writer.add_image('Groundtruth label', grid_image, global_step)
                print('loss is ', loss.cpu().item(), flush=True)

        # Save the model
        if (epoch % snapshot) == snapshot - 1:
            torch.save(net_.state_dict(), os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth'))
            print("Save model at {}\n".format(os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth')))

        torch.cuda.empty_cache()

        # One testing epoch
        if useTest and epoch % nTestInterval == (nTestInterval - 1):
            val_cihp(net_, testloader=testloader, testloader_flip=testloader_flip, test_graph=test_graph,
                     epoch=epoch, writer=writer, criterion=criterion, classes=opts.classes)
            torch.cuda.empty_cache()




if __name__ == '__main__':
    opts = get_parser()
    main(opts)
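The training loop above rebuilds the SGD optimizer every p['epoch_size'] epochs with a learning rate from util.lr_poly. That helper lives in utils/util.py and is not part of this diff; the sketch below shows the standard polynomial-decay form such a schedule usually takes, as an assumption for illustration only.

def lr_poly(base_lr, iter_, max_iter, power=0.9):
    # Polynomial decay: lr = base_lr * (1 - iter/max_iter) ** power
    return base_lr * ((1 - float(iter_) / max_iter) ** power)

# Example: base_lr = 1e-7, epoch 50 of 100, power 0.9
# lr_poly(1e-7, 50, 100) ~= 5.36e-8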
TryYours-Virtual-Try-On/Graphonomy-master/exp/universal/pascal_atr_cihp_uni.py
ADDED
@@ -0,0 +1,493 @@
import socket
import timeit
from datetime import datetime
import os
import sys
import glob
import numpy as np
from collections import OrderedDict
sys.path.append('./')
sys.path.append('./networks/')
# PyTorch includes
import torch
from torch.autograd import Variable
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import random

# Tensorboard include
from tensorboardX import SummaryWriter

# Custom includes
from dataloaders import pascal, cihp_pascal_atr
from utils import get_iou_from_list
from utils import util as ut
from networks import deeplab_xception_universal, graph
from dataloaders import custom_transforms as tr
from utils import sampler as sam
#
import argparse

'''
source is cihp
target is pascal
'''

gpu_id = 1
# print('Using GPU: {} '.format(gpu_id))

# nEpochs = 100  # Number of epochs for training
resume_epoch = 0  # Default is 0, change if want to resume

def flip(x, dim):
    indices = [slice(None)] * x.dim()
    indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
                                dtype=torch.long, device=x.device)
    return x[tuple(indices)]

def flip_cihp(tail_list):
    '''

    :param tail_list: tail_list size is 1 x n_class x h x w
    :return:
    '''
    # tail_list = tail_list[0]
    tail_list_rev = [None] * 20
    for xx in range(14):
        tail_list_rev[xx] = tail_list[xx].unsqueeze(0)
    tail_list_rev[14] = tail_list[15].unsqueeze(0)
    tail_list_rev[15] = tail_list[14].unsqueeze(0)
    tail_list_rev[16] = tail_list[17].unsqueeze(0)
    tail_list_rev[17] = tail_list[16].unsqueeze(0)
    tail_list_rev[18] = tail_list[19].unsqueeze(0)
    tail_list_rev[19] = tail_list[18].unsqueeze(0)
    return torch.cat(tail_list_rev, dim=0)

def get_parser():
    '''argparse begin'''
    parser = argparse.ArgumentParser()
    LookupChoices = type('', (argparse.Action,), dict(__call__=lambda a, p, n, v, o: setattr(n, a.dest, a.choices[v])))

    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--batch', default=16, type=int)
    parser.add_argument('--lr', default=1e-7, type=float)
    parser.add_argument('--numworker', default=12, type=int)
    # parser.add_argument('--freezeBN', choices=dict(true=True, false=False), default=True, action=LookupChoices)
    parser.add_argument('--step', default=10, type=int)
    # parser.add_argument('--loadmodel',default=None,type=str)
    parser.add_argument('--classes', default=7, type=int)
    parser.add_argument('--testepoch', default=10, type=int)
    parser.add_argument('--loadmodel', default='', type=str)
    parser.add_argument('--pretrainedModel', default='', type=str)
    parser.add_argument('--hidden_layers', default=128, type=int)
    parser.add_argument('--gpus', default=4, type=int)
    parser.add_argument('--testInterval', default=5, type=int)
    opts = parser.parse_args()
    return opts

def get_graphs(opts):
    '''source is pascal; target is cihp; middle is atr'''
    # target 1
    cihp_adj = graph.preprocess_adj(graph.cihp_graph)
    adj1_ = Variable(torch.from_numpy(cihp_adj).float())
    adj1 = adj1_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 20, 20).cuda()
    adj1_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20)
    # source 2
    adj2_ = Variable(torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float())
    adj2 = adj2_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 7, 7).cuda()
    adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7)
    # s to target 3
    adj3_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float()
    adj3 = adj3_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 7, 20).transpose(2, 3).cuda()
    adj3_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).transpose(2, 3)
    # middle 4
    atr_adj = graph.preprocess_adj(graph.atr_graph)
    adj4_ = Variable(torch.from_numpy(atr_adj).float())
    adj4 = adj4_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 18, 18).cuda()
    adj4_test = adj4_.unsqueeze(0).unsqueeze(0).expand(1, 1, 18, 18)
    # source to middle 5
    adj5_ = torch.from_numpy(graph.pascal2atr_nlp_adj).float()
    adj5 = adj5_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 7, 18).cuda()
    adj5_test = adj5_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 18)
    # target to middle 6
    adj6_ = torch.from_numpy(graph.cihp2atr_nlp_adj).float()
    adj6 = adj6_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 20, 18).cuda()
    adj6_test = adj6_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 18)
    train_graph = [adj1, adj2, adj3, adj4, adj5, adj6]
    test_graph = [adj1_test, adj2_test, adj3_test, adj4_test, adj5_test, adj6_test]
    return train_graph, test_graph


def main(opts):
    # Set parameters
    p = OrderedDict()  # Parameters to include in report
    p['trainBatch'] = opts.batch  # Training batch size
    testBatch = 1  # Testing batch size
    useTest = True  # See evolution of the test set when training
    nTestInterval = opts.testInterval  # Run on test set every nTestInterval epochs
    snapshot = 1  # Store a model every snapshot epochs
    p['nAveGrad'] = 1  # Average the gradient of several iterations
    p['lr'] = opts.lr  # Learning rate
    p['wd'] = 5e-4  # Weight decay
    p['momentum'] = 0.9  # Momentum
    p['epoch_size'] = opts.step  # How many epochs to change learning rate
    p['num_workers'] = opts.numworker
    model_path = opts.pretrainedModel
    backbone = 'xception'  # Use xception or resnet as feature extractor
    nEpochs = opts.epochs

    max_id = 0
    save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
    exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]
    runs = glob.glob(os.path.join(save_dir_root, 'run', 'run_*'))
    for r in runs:
        run_id = int(r.split('_')[-1])
        if run_id >= max_id:
            max_id = run_id + 1
    # run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0
    save_dir = os.path.join(save_dir_root, 'run', 'run_' + str(max_id))

    # Network definition
    if backbone == 'xception':
        net_ = deeplab_xception_universal.deeplab_xception_end2end_3d(n_classes=20, os=16,
                                                                      hidden_layers=opts.hidden_layers,
                                                                      source_classes=7,
                                                                      middle_classes=18, )
    elif backbone == 'resnet':
        # net_ = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True)
        raise NotImplementedError
    else:
        raise NotImplementedError

    modelName = 'deeplabv3plus-' + backbone + '-voc' + datetime.now().strftime('%b%d_%H-%M-%S')
    criterion = ut.cross_entropy2d

    if gpu_id >= 0:
        # torch.cuda.set_device(device=gpu_id)
        net_.cuda()

    # net load weights
    if not model_path == '':
        x = torch.load(model_path)
        net_.load_state_dict_new(x)
        print('load pretrainedModel.')
    else:
        print('no pretrainedModel.')

    if not opts.loadmodel == '':
        x = torch.load(opts.loadmodel)
        net_.load_source_model(x)
        print('load model:', opts.loadmodel)
    else:
        print('no trained model load !!!!!!!!')

    log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    writer = SummaryWriter(log_dir=log_dir)
    writer.add_text('load model', opts.loadmodel, 1)
    writer.add_text('setting', sys.argv[0], 1)

    # Use the following optimizer
    optimizer = optim.SGD(net_.parameters(), lr=p['lr'], momentum=p['momentum'], weight_decay=p['wd'])

    composed_transforms_tr = transforms.Compose([
        tr.RandomSized_new(512),
        tr.Normalize_xception_tf(),
        tr.ToTensor_()])

    composed_transforms_ts = transforms.Compose([
        tr.Normalize_xception_tf(),
        tr.ToTensor_()])

    composed_transforms_ts_flip = transforms.Compose([
        tr.HorizontalFlip(),
        tr.Normalize_xception_tf(),
        tr.ToTensor_()])

    all_train = cihp_pascal_atr.VOCSegmentation(split='train', transform=composed_transforms_tr, flip=True)
    voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts)
    voc_val_flip = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts_flip)

    num_cihp, num_pascal, num_atr = all_train.get_class_num()
    ss = sam.Sampler_uni(num_cihp, num_pascal, num_atr, opts.batch)
    # balance datasets based pascal
    ss_balanced = sam.Sampler_uni(num_cihp, num_pascal, num_atr, opts.batch, balance_id=1)

    trainloader = DataLoader(all_train, batch_size=p['trainBatch'], shuffle=False, num_workers=p['num_workers'],
                             sampler=ss, drop_last=True)
    trainloader_balanced = DataLoader(all_train, batch_size=p['trainBatch'], shuffle=False, num_workers=p['num_workers'],
                                      sampler=ss_balanced, drop_last=True)
    testloader = DataLoader(voc_val, batch_size=testBatch, shuffle=False, num_workers=p['num_workers'])
    testloader_flip = DataLoader(voc_val_flip, batch_size=testBatch, shuffle=False, num_workers=p['num_workers'])

    num_img_tr = len(trainloader)
    num_img_balanced = len(trainloader_balanced)
    num_img_ts = len(testloader)
    running_loss_tr = 0.0
    running_loss_tr_atr = 0.0
    running_loss_ts = 0.0
    aveGrad = 0
    global_step = 0
    print("Training Network")
    net = torch.nn.DataParallel(net_)

    id_list = torch.LongTensor(range(opts.batch))
    pascal_iter = int(num_img_tr // opts.batch)

    # Get graphs
    train_graph, test_graph = get_graphs(opts)
    adj1, adj2, adj3, adj4, adj5, adj6 = train_graph
    adj1_test, adj2_test, adj3_test, adj4_test, adj5_test, adj6_test = test_graph

    # Main Training and Testing Loop
    for epoch in range(resume_epoch, int(1.5 * nEpochs)):
        start_time = timeit.default_timer()

        if epoch % p['epoch_size'] == p['epoch_size'] - 1 and epoch < nEpochs:
            lr_ = ut.lr_poly(p['lr'], epoch, nEpochs, 0.9)
            optimizer = optim.SGD(net_.parameters(), lr=lr_, momentum=p['momentum'], weight_decay=p['wd'])
            print('(poly lr policy) learning rate: ', lr_)
            writer.add_scalar('data/lr_', lr_, epoch)
        elif epoch % p['epoch_size'] == p['epoch_size'] - 1 and epoch > nEpochs:
            lr_ = ut.lr_poly(p['lr'], epoch - nEpochs, int(0.5 * nEpochs), 0.9)
            optimizer = optim.SGD(net_.parameters(), lr=lr_, momentum=p['momentum'], weight_decay=p['wd'])
            print('(poly lr policy) learning rate: ', lr_)
            writer.add_scalar('data/lr_', lr_, epoch)

        net_.train()
        if epoch < nEpochs:
            for ii, sample_batched in enumerate(trainloader):
                inputs, labels = sample_batched['image'], sample_batched['label']
                dataset_lbl = sample_batched['pascal'][0].item()
                # Forward-Backward of the mini-batch
                inputs, labels = Variable(inputs, requires_grad=True), Variable(labels)
                global_step += 1

                if gpu_id >= 0:
                    inputs, labels = inputs.cuda(), labels.cuda()

                if dataset_lbl == 0:
                    # 0 is cihp -- target
                    _, outputs, _ = net.forward(None, input_target=inputs, input_middle=None, adj1_target=adj1, adj2_source=adj2,
                                                adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2, 3), adj4_middle=adj4, adj5_transfer_s2m=adj5.transpose(2, 3),
                                                adj6_transfer_t2m=adj6.transpose(2, 3), adj5_transfer_m2s=adj5, adj6_transfer_m2t=adj6,)
                elif dataset_lbl == 1:
                    # pascal is source
                    outputs, _, _ = net.forward(inputs, input_target=None, input_middle=None, adj1_target=adj1,
                                                adj2_source=adj2,
                                                adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2, 3),
                                                adj4_middle=adj4, adj5_transfer_s2m=adj5.transpose(2, 3),
                                                adj6_transfer_t2m=adj6.transpose(2, 3), adj5_transfer_m2s=adj5,
                                                adj6_transfer_m2t=adj6, )
                else:
                    # atr
                    _, _, outputs = net.forward(None, input_target=None, input_middle=inputs, adj1_target=adj1,
                                                adj2_source=adj2,
                                                adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2, 3),
                                                adj4_middle=adj4, adj5_transfer_s2m=adj5.transpose(2, 3),
                                                adj6_transfer_t2m=adj6.transpose(2, 3), adj5_transfer_m2s=adj5,
                                                adj6_transfer_m2t=adj6, )
                # print(sample_batched['pascal'])
                # print(outputs.size(),)
                # print(labels)
                loss = criterion(outputs, labels, batch_average=True)
                running_loss_tr += loss.item()

                # Print stuff
                if ii % num_img_tr == (num_img_tr - 1):
                    running_loss_tr = running_loss_tr / num_img_tr
                    writer.add_scalar('data/total_loss_epoch', running_loss_tr, epoch)
                    print('[Epoch: %d, numImages: %5d]' % (epoch, epoch))
                    print('Loss: %f' % running_loss_tr)
                    running_loss_tr = 0
                    stop_time = timeit.default_timer()
                    print("Execution time: " + str(stop_time - start_time) + "\n")

                # Backward the averaged gradient
                loss /= p['nAveGrad']
                loss.backward()
                aveGrad += 1

                # Update the weights once in p['nAveGrad'] forward passes
                if aveGrad % p['nAveGrad'] == 0:
                    writer.add_scalar('data/total_loss_iter', loss.item(), global_step)
                    if dataset_lbl == 0:
                        writer.add_scalar('data/total_loss_iter_cihp', loss.item(), global_step)
                    if dataset_lbl == 1:
                        writer.add_scalar('data/total_loss_iter_pascal', loss.item(), global_step)
                    if dataset_lbl == 2:
                        writer.add_scalar('data/total_loss_iter_atr', loss.item(), global_step)
                    optimizer.step()
                    optimizer.zero_grad()
                    # optimizer_gcn.step()
                    # optimizer_gcn.zero_grad()
                    aveGrad = 0

                # Show 10 * 3 images results each epoch
                if ii % (num_img_tr // 10) == 0:
                    grid_image = make_grid(inputs[:3].clone().cpu().data, 3, normalize=True)
                    writer.add_image('Image', grid_image, global_step)
                    grid_image = make_grid(ut.decode_seg_map_sequence(torch.max(outputs[:3], 1)[1].detach().cpu().numpy()), 3, normalize=False,
                                           range=(0, 255))
                    writer.add_image('Predicted label', grid_image, global_step)
                    grid_image = make_grid(ut.decode_seg_map_sequence(torch.squeeze(labels[:3], 1).detach().cpu().numpy()), 3, normalize=False, range=(0, 255))
                    writer.add_image('Groundtruth label', grid_image, global_step)

                print('loss is ', loss.cpu().item(), flush=True)
        else:
            # Balanced the number of datasets
            for ii, sample_batched in enumerate(trainloader_balanced):
                inputs, labels = sample_batched['image'], sample_batched['label']
                dataset_lbl = sample_batched['pascal'][0].item()
                # Forward-Backward of the mini-batch
                inputs, labels = Variable(inputs, requires_grad=True), Variable(labels)
                global_step += 1

                if gpu_id >= 0:
                    inputs, labels = inputs.cuda(), labels.cuda()

                if dataset_lbl == 0:
                    # 0 is cihp -- target
                    _, outputs, _ = net.forward(None, input_target=inputs, input_middle=None, adj1_target=adj1,
                                                adj2_source=adj2,
                                                adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2, 3),
                                                adj4_middle=adj4, adj5_transfer_s2m=adj5.transpose(2, 3),
                                                adj6_transfer_t2m=adj6.transpose(2, 3), adj5_transfer_m2s=adj5,
                                                adj6_transfer_m2t=adj6, )
                elif dataset_lbl == 1:
                    # pascal is source
                    outputs, _, _ = net.forward(inputs, input_target=None, input_middle=None, adj1_target=adj1,
                                                adj2_source=adj2,
                                                adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2, 3),
                                                adj4_middle=adj4, adj5_transfer_s2m=adj5.transpose(2, 3),
                                                adj6_transfer_t2m=adj6.transpose(2, 3), adj5_transfer_m2s=adj5,
                                                adj6_transfer_m2t=adj6, )
                else:
                    # atr
                    _, _, outputs = net.forward(None, input_target=None, input_middle=inputs, adj1_target=adj1,
                                                adj2_source=adj2,
                                                adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2, 3),
                                                adj4_middle=adj4, adj5_transfer_s2m=adj5.transpose(2, 3),
                                                adj6_transfer_t2m=adj6.transpose(2, 3), adj5_transfer_m2s=adj5,
                                                adj6_transfer_m2t=adj6, )
                # print(sample_batched['pascal'])
                # print(outputs.size(),)
                # print(labels)
                loss = criterion(outputs, labels, batch_average=True)
                running_loss_tr += loss.item()

                # Print stuff
                if ii % num_img_balanced == (num_img_balanced - 1):
                    running_loss_tr = running_loss_tr / num_img_balanced
                    writer.add_scalar('data/total_loss_epoch', running_loss_tr, epoch)
                    print('[Epoch: %d, numImages: %5d]' % (epoch, epoch))
                    print('Loss: %f' % running_loss_tr)
                    running_loss_tr = 0
                    stop_time = timeit.default_timer()
                    print("Execution time: " + str(stop_time - start_time) + "\n")

                # Backward the averaged gradient
                loss /= p['nAveGrad']
                loss.backward()
                aveGrad += 1

                # Update the weights once in p['nAveGrad'] forward passes
                if aveGrad % p['nAveGrad'] == 0:
                    writer.add_scalar('data/total_loss_iter', loss.item(), global_step)
                    if dataset_lbl == 0:
                        writer.add_scalar('data/total_loss_iter_cihp', loss.item(), global_step)
                    if dataset_lbl == 1:
                        writer.add_scalar('data/total_loss_iter_pascal', loss.item(), global_step)
                    if dataset_lbl == 2:
                        writer.add_scalar('data/total_loss_iter_atr', loss.item(), global_step)
                    optimizer.step()
                    optimizer.zero_grad()

                    aveGrad = 0

                # Show 10 * 3 images results each epoch
                if ii % (num_img_balanced // 10) == 0:
                    grid_image = make_grid(inputs[:3].clone().cpu().data, 3, normalize=True)
                    writer.add_image('Image', grid_image, global_step)
                    grid_image = make_grid(
                        ut.decode_seg_map_sequence(torch.max(outputs[:3], 1)[1].detach().cpu().numpy()), 3,
                        normalize=False,
                        range=(0, 255))
                    writer.add_image('Predicted label', grid_image, global_step)
                    grid_image = make_grid(
                        ut.decode_seg_map_sequence(torch.squeeze(labels[:3], 1).detach().cpu().numpy()), 3,
                        normalize=False, range=(0, 255))
                    writer.add_image('Groundtruth label', grid_image, global_step)

                print('loss is ', loss.cpu().item(), flush=True)

        # Save the model
        if (epoch % snapshot) == snapshot - 1:
            torch.save(net_.state_dict(), os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth'))
            print("Save model at {}\n".format(os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth')))

        # One testing epoch
        if useTest and epoch % nTestInterval == (nTestInterval - 1):
            val_pascal(net_=net_, testloader=testloader, testloader_flip=testloader_flip, test_graph=test_graph,
                       criterion=criterion, epoch=epoch, writer=writer)


def val_pascal(net_, testloader, testloader_flip, test_graph, criterion, epoch, writer, classes=7):
    running_loss_ts = 0.0
    miou = 0
    adj1_test, adj2_test, adj3_test, adj4_test, adj5_test, adj6_test = test_graph
    num_img_ts = len(testloader)
    net_.eval()
    pred_list = []
    label_list = []
    for ii, sample_batched in enumerate(zip(testloader, testloader_flip)):
        # print(ii)
        inputs, labels = sample_batched[0]['image'], sample_batched[0]['label']
        inputs_f, _ = sample_batched[1]['image'], sample_batched[1]['label']
        inputs = torch.cat((inputs, inputs_f), dim=0)
        # Forward pass of the mini-batch
        inputs, labels = Variable(inputs, requires_grad=False), Variable(labels)

        with torch.no_grad():
            if gpu_id >= 0:
                inputs, labels = inputs.cuda(), labels.cuda()
            outputs, _, _ = net_.forward(inputs, input_target=None, input_middle=None,
                                         adj1_target=adj1_test.cuda(),
                                         adj2_source=adj2_test.cuda(),
                                         adj3_transfer_s2t=adj3_test.cuda(),
                                         adj3_transfer_t2s=adj3_test.transpose(2, 3).cuda(),
                                         adj4_middle=adj4_test.cuda(),
                                         adj5_transfer_s2m=adj5_test.transpose(2, 3).cuda(),
                                         adj6_transfer_t2m=adj6_test.transpose(2, 3).cuda(),
                                         adj5_transfer_m2s=adj5_test.cuda(),
                                         adj6_transfer_m2t=adj6_test.cuda(), )
            # pdb.set_trace()
            outputs = (outputs[0] + flip(outputs[1], dim=-1)) / 2
            outputs = outputs.unsqueeze(0)
            predictions = torch.max(outputs, 1)[1]
            pred_list.append(predictions.cpu())
            label_list.append(labels.squeeze(1).cpu())
            loss = criterion(outputs, labels, batch_average=True)
            running_loss_ts += loss.item()

            # total_iou += utils.get_iou(predictions, labels)

            # Print stuff
            if ii % num_img_ts == num_img_ts - 1:
                # if ii == 10:
                miou = get_iou_from_list(pred_list, label_list, n_cls=classes)
                running_loss_ts = running_loss_ts / num_img_ts

                print('Validation:')
                print('[Epoch: %d, numImages: %5d]' % (epoch, ii * 1 + inputs.data.shape[0]))
                writer.add_scalar('data/test_loss_epoch', running_loss_ts, epoch)
                writer.add_scalar('data/test_miour', miou, epoch)
                print('Loss: %f' % running_loss_ts)
                print('MIoU: %f\n' % miou)
            # return miou


if __name__ == '__main__':
    opts = get_parser()
    main(opts)
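Both the training scripts and the evaluation code above rely on flip_cihp before spatially flipping the logits of a horizontally flipped input. The index pairs 14/15, 16/17 and 18/19 suggest left/right part channels (for example left vs. right arm) that must be exchanged so they line up with the original orientation; that semantic reading is an assumption based on the swap pattern. A compact equivalent sketch:

import torch

def swap_left_right(logits):
    # logits: [n_class, h, w] prediction for a horizontally flipped input.
    # Channels 0..13 are symmetric; 14/15, 16/17 and 18/19 are left/right
    # counterparts and are exchanged.
    order = list(range(14)) + [15, 14, 17, 16, 19, 18]
    return logits[order]

# Combined with a spatial flip, this recovers predictions aligned with the
# original (un-flipped) image:
# aligned = torch.flip(swap_left_right(logits_flipped), dims=[-1])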
TryYours-Virtual-Try-On/Graphonomy-master/inference.sh
ADDED
@@ -0,0 +1 @@
python exp/inference/inference.py --loadmodel ./data/pretrained_model/inference.pth --img_path ./img/messi.jpg --output_path ./img/ --output_name /output_file_name
TryYours-Virtual-Try-On/Graphonomy-master/networks/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .deeplab_xception import *
from .deeplab_xception_transfer import *
from .deeplab_xception_universal import *
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (282 Bytes).
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (258 Bytes).
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception.cpython-310.pyc
ADDED
Binary file (17 kB).
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception.cpython-39.pyc
ADDED
Binary file (17.1 kB).
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_synBN.cpython-310.pyc
ADDED
Binary file (15.3 kB).
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_synBN.cpython-39.pyc
ADDED
Binary file (15.4 kB).
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_transfer.cpython-310.pyc
ADDED
Binary file (12.8 kB).
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_transfer.cpython-39.pyc
ADDED
Binary file (19.1 kB).
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_universal.cpython-310.pyc
ADDED
Binary file (15.5 kB).
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_universal.cpython-39.pyc
ADDED
Binary file (20.3 kB).
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/gcn.cpython-310.pyc
ADDED
Binary file (8.3 kB).
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/gcn.cpython-39.pyc
ADDED
Binary file (8.46 kB).
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/graph.cpython-310.pyc
ADDED
Binary file (9.09 kB).
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/graph.cpython-39.pyc
ADDED
Binary file (9 kB).
TryYours-Virtual-Try-On/Graphonomy-master/networks/deeplab_xception.py
ADDED
@@ -0,0 +1,684 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.nn.parameter import Parameter
from collections import OrderedDict

class SeparableConv2d(nn.Module):
    def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=0, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()

        self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation,
                               groups=inplanes, bias=bias)
        self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pointwise(x)
        return x


def fixed_padding(inputs, kernel_size, rate):
    kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
    pad_total = kernel_size_effective - 1
    pad_beg = pad_total // 2
    pad_end = pad_total - pad_beg
    padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
    return padded_inputs


class SeparableConv2d_aspp(nn.Module):
    def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, padding=0):
        super(SeparableConv2d_aspp, self).__init__()

        self.depthwise = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation,
                                   groups=inplanes, bias=bias)
        self.depthwise_bn = nn.BatchNorm2d(inplanes)
        self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
        self.pointwise_bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()

    def forward(self, x):
        # x = fixed_padding(x, self.depthwise.kernel_size[0], rate=self.depthwise.dilation[0])
        x = self.depthwise(x)
        x = self.depthwise_bn(x)
        x = self.relu(x)
        x = self.pointwise(x)
        x = self.pointwise_bn(x)
        x = self.relu(x)
        return x

class Decoder_module(nn.Module):
    def __init__(self, inplanes, planes, rate=1):
        super(Decoder_module, self).__init__()
        self.atrous_convolution = SeparableConv2d_aspp(inplanes, planes, 3, stride=1, dilation=rate, padding=1)

    def forward(self, x):
        x = self.atrous_convolution(x)
        return x

class ASPP_module(nn.Module):
    def __init__(self, inplanes, planes, rate):
        super(ASPP_module, self).__init__()
        if rate == 1:
            raise RuntimeError()
        else:
            kernel_size = 3
            padding = rate
        self.atrous_convolution = SeparableConv2d_aspp(inplanes, planes, 3, stride=1, dilation=rate,
                                                       padding=padding)

    def forward(self, x):
        x = self.atrous_convolution(x)
        return x

class ASPP_module_rate0(nn.Module):
    def __init__(self, inplanes, planes, rate=1):
        super(ASPP_module_rate0, self).__init__()
        if rate == 1:
            kernel_size = 1
            padding = 0
            self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                                stride=1, padding=padding, dilation=rate, bias=False)
            self.bn = nn.BatchNorm2d(planes, eps=1e-5, affine=True)
            self.relu = nn.ReLU()
        else:
            raise RuntimeError()

    def forward(self, x):
        x = self.atrous_convolution(x)
        x = self.bn(x)
        return self.relu(x)

class SeparableConv2d_same(nn.Module):
96 |
+
def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, padding=0):
|
97 |
+
super(SeparableConv2d_same, self).__init__()
|
98 |
+
|
99 |
+
self.depthwise = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation,
|
100 |
+
groups=inplanes, bias=bias)
|
101 |
+
self.depthwise_bn = nn.BatchNorm2d(inplanes)
|
102 |
+
self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
|
103 |
+
self.pointwise_bn = nn.BatchNorm2d(planes)
|
104 |
+
|
105 |
+
def forward(self, x):
|
106 |
+
x = fixed_padding(x, self.depthwise.kernel_size[0], rate=self.depthwise.dilation[0])
|
107 |
+
x = self.depthwise(x)
|
108 |
+
x = self.depthwise_bn(x)
|
109 |
+
x = self.pointwise(x)
|
110 |
+
x = self.pointwise_bn(x)
|
111 |
+
return x
|
112 |
+
|
113 |
+
class Block(nn.Module):
|
114 |
+
def __init__(self, inplanes, planes, reps, stride=1, dilation=1, start_with_relu=True, grow_first=True, is_last=False):
|
115 |
+
super(Block, self).__init__()
|
116 |
+
|
117 |
+
if planes != inplanes or stride != 1:
|
118 |
+
self.skip = nn.Conv2d(inplanes, planes, 1, stride=2, bias=False)
|
119 |
+
if is_last:
|
120 |
+
self.skip = nn.Conv2d(inplanes, planes, 1, stride=1, bias=False)
|
121 |
+
self.skipbn = nn.BatchNorm2d(planes)
|
122 |
+
else:
|
123 |
+
self.skip = None
|
124 |
+
|
125 |
+
self.relu = nn.ReLU(inplace=True)
|
126 |
+
rep = []
|
127 |
+
|
128 |
+
filters = inplanes
|
129 |
+
if grow_first:
|
130 |
+
rep.append(self.relu)
|
131 |
+
rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
|
132 |
+
# rep.append(nn.BatchNorm2d(planes))
|
133 |
+
filters = planes
|
134 |
+
|
135 |
+
for i in range(reps - 1):
|
136 |
+
rep.append(self.relu)
|
137 |
+
rep.append(SeparableConv2d_same(filters, filters, 3, stride=1, dilation=dilation))
|
138 |
+
# rep.append(nn.BatchNorm2d(filters))
|
139 |
+
|
140 |
+
if not grow_first:
|
141 |
+
rep.append(self.relu)
|
142 |
+
rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
|
143 |
+
# rep.append(nn.BatchNorm2d(planes))
|
144 |
+
|
145 |
+
if not start_with_relu:
|
146 |
+
rep = rep[1:]
|
147 |
+
|
148 |
+
if stride != 1:
|
149 |
+
rep.append(self.relu)
|
150 |
+
rep.append(SeparableConv2d_same(planes, planes, 3, stride=2,dilation=dilation))
|
151 |
+
|
152 |
+
if is_last:
|
153 |
+
rep.append(self.relu)
|
154 |
+
rep.append(SeparableConv2d_same(planes, planes, 3, stride=1,dilation=dilation))
|
155 |
+
|
156 |
+
|
157 |
+
self.rep = nn.Sequential(*rep)
|
158 |
+
|
159 |
+
def forward(self, inp):
|
160 |
+
x = self.rep(inp)
|
161 |
+
|
162 |
+
if self.skip is not None:
|
163 |
+
skip = self.skip(inp)
|
164 |
+
skip = self.skipbn(skip)
|
165 |
+
else:
|
166 |
+
skip = inp
|
167 |
+
# print(x.size(),skip.size())
|
168 |
+
x += skip
|
169 |
+
|
170 |
+
return x
|
171 |
+
|
172 |
+
class Block2(nn.Module):
|
173 |
+
def __init__(self, inplanes, planes, reps, stride=1, dilation=1, start_with_relu=True, grow_first=True, is_last=False):
|
174 |
+
super(Block2, self).__init__()
|
175 |
+
|
176 |
+
if planes != inplanes or stride != 1:
|
177 |
+
self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False)
|
178 |
+
self.skipbn = nn.BatchNorm2d(planes)
|
179 |
+
else:
|
180 |
+
self.skip = None
|
181 |
+
|
182 |
+
self.relu = nn.ReLU(inplace=True)
|
183 |
+
rep = []
|
184 |
+
|
185 |
+
filters = inplanes
|
186 |
+
if grow_first:
|
187 |
+
rep.append(self.relu)
|
188 |
+
rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
|
189 |
+
# rep.append(nn.BatchNorm2d(planes))
|
190 |
+
filters = planes
|
191 |
+
|
192 |
+
for i in range(reps - 1):
|
193 |
+
rep.append(self.relu)
|
194 |
+
rep.append(SeparableConv2d_same(filters, filters, 3, stride=1, dilation=dilation))
|
195 |
+
# rep.append(nn.BatchNorm2d(filters))
|
196 |
+
|
197 |
+
if not grow_first:
|
198 |
+
rep.append(self.relu)
|
199 |
+
rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
|
200 |
+
# rep.append(nn.BatchNorm2d(planes))
|
201 |
+
|
202 |
+
if not start_with_relu:
|
203 |
+
rep = rep[1:]
|
204 |
+
|
205 |
+
if stride != 1:
|
206 |
+
self.block2_lastconv = nn.Sequential(*[self.relu,SeparableConv2d_same(planes, planes, 3, stride=2,dilation=dilation)])
|
207 |
+
|
208 |
+
if is_last:
|
209 |
+
rep.append(SeparableConv2d_same(planes, planes, 3, stride=1))
|
210 |
+
|
211 |
+
|
212 |
+
self.rep = nn.Sequential(*rep)
|
213 |
+
|
214 |
+
def forward(self, inp):
|
215 |
+
x = self.rep(inp)
|
216 |
+
low_middle = x.clone()
|
217 |
+
x1 = x
|
218 |
+
x1 = self.block2_lastconv(x1)
|
219 |
+
if self.skip is not None:
|
220 |
+
skip = self.skip(inp)
|
221 |
+
skip = self.skipbn(skip)
|
222 |
+
else:
|
223 |
+
skip = inp
|
224 |
+
|
225 |
+
x1 += skip
|
226 |
+
|
227 |
+
return x1,low_middle
|
228 |
+
|
229 |
+
class Xception(nn.Module):
|
230 |
+
"""
|
231 |
+
Modified Alighed Xception
|
232 |
+
"""
|
233 |
+
def __init__(self, inplanes=3, os=16, pretrained=False):
|
234 |
+
super(Xception, self).__init__()
|
235 |
+
|
236 |
+
if os == 16:
|
237 |
+
entry_block3_stride = 2
|
238 |
+
middle_block_rate = 1
|
239 |
+
exit_block_rates = (1, 2)
|
240 |
+
elif os == 8:
|
241 |
+
entry_block3_stride = 1
|
242 |
+
middle_block_rate = 2
|
243 |
+
exit_block_rates = (2, 4)
|
244 |
+
else:
|
245 |
+
raise NotImplementedError
|
246 |
+
|
247 |
+
|
248 |
+
# Entry flow
|
249 |
+
self.conv1 = nn.Conv2d(inplanes, 32, 3, stride=2, padding=1, bias=False)
|
250 |
+
self.bn1 = nn.BatchNorm2d(32)
|
251 |
+
self.relu = nn.ReLU(inplace=True)
|
252 |
+
|
253 |
+
self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
|
254 |
+
self.bn2 = nn.BatchNorm2d(64)
|
255 |
+
|
256 |
+
self.block1 = Block(64, 128, reps=2, stride=2, start_with_relu=False)
|
257 |
+
self.block2 = Block2(128, 256, reps=2, stride=2, start_with_relu=True, grow_first=True)
|
258 |
+
self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, start_with_relu=True, grow_first=True)
|
259 |
+
|
260 |
+
# Middle flow
|
261 |
+
self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
262 |
+
self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
263 |
+
self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
264 |
+
self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
265 |
+
self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
266 |
+
self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
267 |
+
self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
268 |
+
self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
269 |
+
self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
270 |
+
self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
271 |
+
self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
272 |
+
self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
273 |
+
self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
274 |
+
self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
275 |
+
self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
276 |
+
self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
277 |
+
|
278 |
+
# Exit flow
|
279 |
+
self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_rates[0],
|
280 |
+
start_with_relu=True, grow_first=False, is_last=True)
|
281 |
+
|
282 |
+
self.conv3 = SeparableConv2d_aspp(1024, 1536, 3, stride=1, dilation=exit_block_rates[1],padding=exit_block_rates[1])
|
283 |
+
# self.bn3 = nn.BatchNorm2d(1536)
|
284 |
+
|
285 |
+
self.conv4 = SeparableConv2d_aspp(1536, 1536, 3, stride=1, dilation=exit_block_rates[1],padding=exit_block_rates[1])
|
286 |
+
# self.bn4 = nn.BatchNorm2d(1536)
|
287 |
+
|
288 |
+
self.conv5 = SeparableConv2d_aspp(1536, 2048, 3, stride=1, dilation=exit_block_rates[1],padding=exit_block_rates[1])
|
289 |
+
# self.bn5 = nn.BatchNorm2d(2048)
|
290 |
+
|
291 |
+
# Init weights
|
292 |
+
# self.__init_weight()
|
293 |
+
|
294 |
+
# Load pretrained model
|
295 |
+
if pretrained:
|
296 |
+
self.__load_xception_pretrained()
|
297 |
+
|
298 |
+
def forward(self, x):
|
299 |
+
# Entry flow
|
300 |
+
x = self.conv1(x)
|
301 |
+
x = self.bn1(x)
|
302 |
+
x = self.relu(x)
|
303 |
+
# print('conv1 ',x.size())
|
304 |
+
x = self.conv2(x)
|
305 |
+
x = self.bn2(x)
|
306 |
+
x = self.relu(x)
|
307 |
+
|
308 |
+
x = self.block1(x)
|
309 |
+
# print('block1',x.size())
|
310 |
+
# low_level_feat = x
|
311 |
+
x,low_level_feat = self.block2(x)
|
312 |
+
# print('block2',x.size())
|
313 |
+
x = self.block3(x)
|
314 |
+
# print('xception block3 ',x.size())
|
315 |
+
|
316 |
+
# Middle flow
|
317 |
+
x = self.block4(x)
|
318 |
+
x = self.block5(x)
|
319 |
+
x = self.block6(x)
|
320 |
+
x = self.block7(x)
|
321 |
+
x = self.block8(x)
|
322 |
+
x = self.block9(x)
|
323 |
+
x = self.block10(x)
|
324 |
+
x = self.block11(x)
|
325 |
+
x = self.block12(x)
|
326 |
+
x = self.block13(x)
|
327 |
+
x = self.block14(x)
|
328 |
+
x = self.block15(x)
|
329 |
+
x = self.block16(x)
|
330 |
+
x = self.block17(x)
|
331 |
+
x = self.block18(x)
|
332 |
+
x = self.block19(x)
|
333 |
+
|
334 |
+
# Exit flow
|
335 |
+
x = self.block20(x)
|
336 |
+
x = self.conv3(x)
|
337 |
+
# x = self.bn3(x)
|
338 |
+
x = self.relu(x)
|
339 |
+
|
340 |
+
x = self.conv4(x)
|
341 |
+
# x = self.bn4(x)
|
342 |
+
x = self.relu(x)
|
343 |
+
|
344 |
+
x = self.conv5(x)
|
345 |
+
# x = self.bn5(x)
|
346 |
+
x = self.relu(x)
|
347 |
+
|
348 |
+
return x, low_level_feat
|
349 |
+
|
350 |
+
def __init_weight(self):
|
351 |
+
for m in self.modules():
|
352 |
+
if isinstance(m, nn.Conv2d):
|
353 |
+
# n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
354 |
+
# m.weight.data.normal_(0, math.sqrt(2. / n))
|
355 |
+
torch.nn.init.kaiming_normal_(m.weight)
|
356 |
+
elif isinstance(m, nn.BatchNorm2d):
|
357 |
+
m.weight.data.fill_(1)
|
358 |
+
m.bias.data.zero_()
|
359 |
+
|
360 |
+
def __load_xception_pretrained(self):
|
361 |
+
pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
|
362 |
+
model_dict = {}
|
363 |
+
state_dict = self.state_dict()
|
364 |
+
|
365 |
+
for k, v in pretrain_dict.items():
|
366 |
+
if k in state_dict:
|
367 |
+
if 'pointwise' in k:
|
368 |
+
v = v.unsqueeze(-1).unsqueeze(-1)
|
369 |
+
if k.startswith('block12'):
|
370 |
+
model_dict[k.replace('block12', 'block20')] = v
|
371 |
+
elif k.startswith('block11'):
|
372 |
+
model_dict[k.replace('block11', 'block12')] = v
|
373 |
+
model_dict[k.replace('block11', 'block13')] = v
|
374 |
+
model_dict[k.replace('block11', 'block14')] = v
|
375 |
+
model_dict[k.replace('block11', 'block15')] = v
|
376 |
+
model_dict[k.replace('block11', 'block16')] = v
|
377 |
+
model_dict[k.replace('block11', 'block17')] = v
|
378 |
+
model_dict[k.replace('block11', 'block18')] = v
|
379 |
+
model_dict[k.replace('block11', 'block19')] = v
|
380 |
+
elif k.startswith('conv3'):
|
381 |
+
model_dict[k] = v
|
382 |
+
elif k.startswith('bn3'):
|
383 |
+
model_dict[k] = v
|
384 |
+
model_dict[k.replace('bn3', 'bn4')] = v
|
385 |
+
elif k.startswith('conv4'):
|
386 |
+
model_dict[k.replace('conv4', 'conv5')] = v
|
387 |
+
elif k.startswith('bn4'):
|
388 |
+
model_dict[k.replace('bn4', 'bn5')] = v
|
389 |
+
else:
|
390 |
+
model_dict[k] = v
|
391 |
+
state_dict.update(model_dict)
|
392 |
+
self.load_state_dict(state_dict)
|
393 |
+
|
394 |
+
class DeepLabv3_plus(nn.Module):
|
395 |
+
def __init__(self, nInputChannels=3, n_classes=21, os=16, pretrained=False, _print=True):
|
396 |
+
if _print:
|
397 |
+
print("Constructing DeepLabv3+ model...")
|
398 |
+
print("Number of classes: {}".format(n_classes))
|
399 |
+
print("Output stride: {}".format(os))
|
400 |
+
print("Number of Input Channels: {}".format(nInputChannels))
|
401 |
+
super(DeepLabv3_plus, self).__init__()
|
402 |
+
|
403 |
+
# Atrous Conv
|
404 |
+
self.xception_features = Xception(nInputChannels, os, pretrained)
|
405 |
+
|
406 |
+
# ASPP
|
407 |
+
if os == 16:
|
408 |
+
rates = [1, 6, 12, 18]
|
409 |
+
elif os == 8:
|
410 |
+
rates = [1, 12, 24, 36]
|
411 |
+
raise NotImplementedError
|
412 |
+
else:
|
413 |
+
raise NotImplementedError
|
414 |
+
|
415 |
+
self.aspp1 = ASPP_module_rate0(2048, 256, rate=rates[0])
|
416 |
+
self.aspp2 = ASPP_module(2048, 256, rate=rates[1])
|
417 |
+
self.aspp3 = ASPP_module(2048, 256, rate=rates[2])
|
418 |
+
self.aspp4 = ASPP_module(2048, 256, rate=rates[3])
|
419 |
+
|
420 |
+
self.relu = nn.ReLU()
|
421 |
+
|
422 |
+
self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
|
423 |
+
nn.Conv2d(2048, 256, 1, stride=1, bias=False),
|
424 |
+
nn.BatchNorm2d(256),
|
425 |
+
nn.ReLU()
|
426 |
+
)
|
427 |
+
|
428 |
+
self.concat_projection_conv1 = nn.Conv2d(1280, 256, 1, bias=False)
|
429 |
+
self.concat_projection_bn1 = nn.BatchNorm2d(256)
|
430 |
+
|
431 |
+
# adopt [1x1, 48] for channel reduction.
|
432 |
+
self.feature_projection_conv1 = nn.Conv2d(256, 48, 1, bias=False)
|
433 |
+
self.feature_projection_bn1 = nn.BatchNorm2d(48)
|
434 |
+
|
435 |
+
self.decoder = nn.Sequential(Decoder_module(304, 256),
|
436 |
+
Decoder_module(256, 256)
|
437 |
+
)
|
438 |
+
self.semantic = nn.Conv2d(256, n_classes, kernel_size=1, stride=1)
|
439 |
+
|
440 |
+
def forward(self, input):
|
441 |
+
x, low_level_features = self.xception_features(input)
|
442 |
+
# print(x.size())
|
443 |
+
x1 = self.aspp1(x)
|
444 |
+
x2 = self.aspp2(x)
|
445 |
+
x3 = self.aspp3(x)
|
446 |
+
x4 = self.aspp4(x)
|
447 |
+
x5 = self.global_avg_pool(x)
|
448 |
+
x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
|
449 |
+
|
450 |
+
x = torch.cat((x1, x2, x3, x4, x5), dim=1)
|
451 |
+
|
452 |
+
x = self.concat_projection_conv1(x)
|
453 |
+
x = self.concat_projection_bn1(x)
|
454 |
+
x = self.relu(x)
|
455 |
+
# print(x.size())
|
456 |
+
|
457 |
+
low_level_features = self.feature_projection_conv1(low_level_features)
|
458 |
+
low_level_features = self.feature_projection_bn1(low_level_features)
|
459 |
+
low_level_features = self.relu(low_level_features)
|
460 |
+
|
461 |
+
x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
|
462 |
+
# print(low_level_features.size())
|
463 |
+
# print(x.size())
|
464 |
+
x = torch.cat((x, low_level_features), dim=1)
|
465 |
+
x = self.decoder(x)
|
466 |
+
x = self.semantic(x)
|
467 |
+
x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)
|
468 |
+
|
469 |
+
return x
|
470 |
+
|
471 |
+
def freeze_bn(self):
|
472 |
+
for m in self.xception_features.modules():
|
473 |
+
if isinstance(m, nn.BatchNorm2d):
|
474 |
+
m.eval()
|
475 |
+
|
476 |
+
def freeze_totally_bn(self):
|
477 |
+
for m in self.modules():
|
478 |
+
if isinstance(m, nn.BatchNorm2d):
|
479 |
+
m.eval()
|
480 |
+
|
481 |
+
def freeze_aspp_bn(self):
|
482 |
+
for m in self.aspp1.modules():
|
483 |
+
if isinstance(m, nn.BatchNorm2d):
|
484 |
+
m.eval()
|
485 |
+
for m in self.aspp2.modules():
|
486 |
+
if isinstance(m, nn.BatchNorm2d):
|
487 |
+
m.eval()
|
488 |
+
for m in self.aspp3.modules():
|
489 |
+
if isinstance(m, nn.BatchNorm2d):
|
490 |
+
m.eval()
|
491 |
+
for m in self.aspp4.modules():
|
492 |
+
if isinstance(m, nn.BatchNorm2d):
|
493 |
+
m.eval()
|
494 |
+
|
495 |
+
def learnable_parameters(self):
|
496 |
+
layer_features_BN = []
|
497 |
+
layer_features = []
|
498 |
+
layer_aspp = []
|
499 |
+
layer_projection =[]
|
500 |
+
layer_decoder = []
|
501 |
+
layer_other = []
|
502 |
+
model_para = list(self.named_parameters())
|
503 |
+
for name,para in model_para:
|
504 |
+
if 'xception' in name:
|
505 |
+
if 'bn' in name or 'downsample.1.weight' in name or 'downsample.1.bias' in name:
|
506 |
+
layer_features_BN.append(para)
|
507 |
+
else:
|
508 |
+
layer_features.append(para)
|
509 |
+
# print (name)
|
510 |
+
elif 'aspp' in name:
|
511 |
+
layer_aspp.append(para)
|
512 |
+
elif 'projection' in name:
|
513 |
+
layer_projection.append(para)
|
514 |
+
elif 'decode' in name:
|
515 |
+
layer_decoder.append(para)
|
516 |
+
elif 'global' not in name:
|
517 |
+
layer_other.append(para)
|
518 |
+
return layer_features_BN,layer_features,layer_aspp,layer_projection,layer_decoder,layer_other
|
519 |
+
|
520 |
+
def get_backbone_para(self):
|
521 |
+
layer_features = []
|
522 |
+
other_features = []
|
523 |
+
model_para = list(self.named_parameters())
|
524 |
+
for name, para in model_para:
|
525 |
+
if 'xception' in name:
|
526 |
+
layer_features.append(para)
|
527 |
+
else:
|
528 |
+
other_features.append(para)
|
529 |
+
|
530 |
+
return layer_features, other_features
|
531 |
+
|
532 |
+
def train_fixbn(self, mode=True, freeze_bn=True, freeze_bn_affine=False):
|
533 |
+
r"""Sets the module in training mode.
|
534 |
+
|
535 |
+
This has any effect only on certain modules. See documentations of
|
536 |
+
particular modules for details of their behaviors in training/evaluation
|
537 |
+
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
|
538 |
+
etc.
|
539 |
+
|
540 |
+
Returns:
|
541 |
+
Module: self
|
542 |
+
"""
|
543 |
+
super(DeepLabv3_plus, self).train(mode)
|
544 |
+
if freeze_bn:
|
545 |
+
print("Freezing Mean/Var of BatchNorm2D.")
|
546 |
+
if freeze_bn_affine:
|
547 |
+
print("Freezing Weight/Bias of BatchNorm2D.")
|
548 |
+
if freeze_bn:
|
549 |
+
for m in self.xception_features.modules():
|
550 |
+
if isinstance(m, nn.BatchNorm2d):
|
551 |
+
m.eval()
|
552 |
+
if freeze_bn_affine:
|
553 |
+
m.weight.requires_grad = False
|
554 |
+
m.bias.requires_grad = False
|
555 |
+
# for m in self.aspp1.modules():
|
556 |
+
# if isinstance(m, nn.BatchNorm2d):
|
557 |
+
# m.eval()
|
558 |
+
# if freeze_bn_affine:
|
559 |
+
# m.weight.requires_grad = False
|
560 |
+
# m.bias.requires_grad = False
|
561 |
+
# for m in self.aspp2.modules():
|
562 |
+
# if isinstance(m, nn.BatchNorm2d):
|
563 |
+
# m.eval()
|
564 |
+
# if freeze_bn_affine:
|
565 |
+
# m.weight.requires_grad = False
|
566 |
+
# m.bias.requires_grad = False
|
567 |
+
# for m in self.aspp3.modules():
|
568 |
+
# if isinstance(m, nn.BatchNorm2d):
|
569 |
+
# m.eval()
|
570 |
+
# if freeze_bn_affine:
|
571 |
+
# m.weight.requires_grad = False
|
572 |
+
# m.bias.requires_grad = False
|
573 |
+
# for m in self.aspp4.modules():
|
574 |
+
# if isinstance(m, nn.BatchNorm2d):
|
575 |
+
# m.eval()
|
576 |
+
# if freeze_bn_affine:
|
577 |
+
# m.weight.requires_grad = False
|
578 |
+
# m.bias.requires_grad = False
|
579 |
+
# for m in self.global_avg_pool.modules():
|
580 |
+
# if isinstance(m, nn.BatchNorm2d):
|
581 |
+
# m.eval()
|
582 |
+
# if freeze_bn_affine:
|
583 |
+
# m.weight.requires_grad = False
|
584 |
+
# m.bias.requires_grad = False
|
585 |
+
# for m in self.concat_projection_bn1.modules():
|
586 |
+
# if isinstance(m, nn.BatchNorm2d):
|
587 |
+
# m.eval()
|
588 |
+
# if freeze_bn_affine:
|
589 |
+
# m.weight.requires_grad = False
|
590 |
+
# m.bias.requires_grad = False
|
591 |
+
# for m in self.feature_projection_bn1.modules():
|
592 |
+
# if isinstance(m, nn.BatchNorm2d):
|
593 |
+
# m.eval()
|
594 |
+
# if freeze_bn_affine:
|
595 |
+
# m.weight.requires_grad = False
|
596 |
+
# m.bias.requires_grad = False
|
597 |
+
|
598 |
+
def __init_weight(self):
|
599 |
+
for m in self.modules():
|
600 |
+
if isinstance(m, nn.Conv2d):
|
601 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
602 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
603 |
+
# torch.nn.init.kaiming_normal_(m.weight)
|
604 |
+
elif isinstance(m, nn.BatchNorm2d):
|
605 |
+
m.weight.data.fill_(1)
|
606 |
+
m.bias.data.zero_()
|
607 |
+
|
608 |
+
def load_state_dict_new(self, state_dict):
|
609 |
+
own_state = self.state_dict()
|
610 |
+
#for name inshop_cos own_state:
|
611 |
+
# print name
|
612 |
+
new_state_dict = OrderedDict()
|
613 |
+
for name, param in state_dict.items():
|
614 |
+
name = name.replace('module.','')
|
615 |
+
new_state_dict[name] = 0
|
616 |
+
if name not in own_state:
|
617 |
+
if 'num_batch' in name:
|
618 |
+
continue
|
619 |
+
print ('unexpected key "{}" in state_dict'
|
620 |
+
.format(name))
|
621 |
+
continue
|
622 |
+
# if isinstance(param, own_state):
|
623 |
+
if isinstance(param, Parameter):
|
624 |
+
# backwards compatibility for serialized parameters
|
625 |
+
param = param.data
|
626 |
+
try:
|
627 |
+
own_state[name].copy_(param)
|
628 |
+
except:
|
629 |
+
print('While copying the parameter named {}, whose dimensions in the model are'
|
630 |
+
' {} and whose dimensions in the checkpoint are {}, ...'.format(
|
631 |
+
name, own_state[name].size(), param.size()))
|
632 |
+
continue # i add inshop_cos 2018/02/01
|
633 |
+
# raise
|
634 |
+
# print 'copying %s' %name
|
635 |
+
# if isinstance(param, own_state):
|
636 |
+
# backwards compatibility for serialized parameters
|
637 |
+
own_state[name].copy_(param)
|
638 |
+
# print 'copying %s' %name
|
639 |
+
|
640 |
+
missing = set(own_state.keys()) - set(new_state_dict.keys())
|
641 |
+
if len(missing) > 0:
|
642 |
+
print('missing keys in state_dict: "{}"'.format(missing))
|
643 |
+
|
644 |
+
|
645 |
+
def get_1x_lr_params(model):
|
646 |
+
"""
|
647 |
+
This generator returns all the parameters of the net except for
|
648 |
+
the last classification layer. Note that for each batchnorm layer,
|
649 |
+
requires_grad is set to False in deeplab_resnet.py, therefore this function does not return
|
650 |
+
any batchnorm parameter
|
651 |
+
"""
|
652 |
+
b = [model.xception_features]
|
653 |
+
for i in range(len(b)):
|
654 |
+
for k in b[i].parameters():
|
655 |
+
if k.requires_grad:
|
656 |
+
yield k
|
657 |
+
|
658 |
+
|
659 |
+
def get_10x_lr_params(model):
|
660 |
+
"""
|
661 |
+
This generator returns all the parameters for the last layer of the net,
|
662 |
+
which does the classification of pixel into classes
|
663 |
+
"""
|
664 |
+
b = [model.aspp1, model.aspp2, model.aspp3, model.aspp4, model.conv1, model.conv2, model.last_conv]
|
665 |
+
for j in range(len(b)):
|
666 |
+
for k in b[j].parameters():
|
667 |
+
if k.requires_grad:
|
668 |
+
yield k
|
669 |
+
|
670 |
+
|
671 |
+
if __name__ == "__main__":
|
672 |
+
model = DeepLabv3_plus(nInputChannels=3, n_classes=21, os=16, pretrained=False, _print=True)
|
673 |
+
model.eval()
|
674 |
+
image = torch.randn(1, 3, 512, 512)*255
|
675 |
+
with torch.no_grad():
|
676 |
+
output = model.forward(image)
|
677 |
+
print(output.size())
|
678 |
+
# print(output)
|
679 |
+
|
680 |
+
|
681 |
+
|
682 |
+
|
683 |
+
|
684 |
+
|
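For readers wiring this file into the try-on pipeline, the sketch below shows one way the DeepLabv3_plus defined above could be restored from a downloaded Graphonomy checkpoint and run on a single image tensor. The checkpoint path and the 20-class setting are illustrative assumptions, not part of this upload.

import torch
from networks.deeplab_xception import DeepLabv3_plus

# Hypothetical checkpoint path; the Graphonomy parsing weights are distributed separately.
CKPT = 'data/pretrained_model/inference.pth'

net = DeepLabv3_plus(nInputChannels=3, n_classes=20, os=16, pretrained=False, _print=False)
state_dict = torch.load(CKPT, map_location='cpu')
# load_state_dict_new() strips any 'module.' prefix left over from nn.DataParallel training
# and reports unexpected or missing keys instead of raising.
net.load_state_dict_new(state_dict)
net.eval()

with torch.no_grad():
    logits = net(torch.randn(1, 3, 512, 512))   # (1, n_classes, 512, 512), upsampled to input size
    parsing = logits.argmax(dim=1)              # per-pixel class labels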
TryYours-Virtual-Try-On/Graphonomy-master/networks/deeplab_xception_synBN.py
ADDED
@@ -0,0 +1,596 @@
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torch.utils.model_zoo as model_zoo
|
6 |
+
from torch.nn.parameter import Parameter
|
7 |
+
from collections import OrderedDict
|
8 |
+
from sync_batchnorm import SynchronizedBatchNorm1d, DataParallelWithCallback, SynchronizedBatchNorm2d
|
9 |
+
|
10 |
+
|
11 |
+
def fixed_padding(inputs, kernel_size, rate):
|
12 |
+
kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
|
13 |
+
pad_total = kernel_size_effective - 1
|
14 |
+
pad_beg = pad_total // 2
|
15 |
+
pad_end = pad_total - pad_beg
|
16 |
+
padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
|
17 |
+
return padded_inputs
|
18 |
+
|
19 |
+
class SeparableConv2d_aspp(nn.Module):
|
20 |
+
def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, padding=0):
|
21 |
+
super(SeparableConv2d_aspp, self).__init__()
|
22 |
+
|
23 |
+
self.depthwise = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation,
|
24 |
+
groups=inplanes, bias=bias)
|
25 |
+
self.depthwise_bn = SynchronizedBatchNorm2d(inplanes)
|
26 |
+
self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
|
27 |
+
self.pointwise_bn = SynchronizedBatchNorm2d(planes)
|
28 |
+
self.relu = nn.ReLU()
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
# x = fixed_padding(x, self.depthwise.kernel_size[0], rate=self.depthwise.dilation[0])
|
32 |
+
x = self.depthwise(x)
|
33 |
+
x = self.depthwise_bn(x)
|
34 |
+
x = self.relu(x)
|
35 |
+
x = self.pointwise(x)
|
36 |
+
x = self.pointwise_bn(x)
|
37 |
+
x = self.relu(x)
|
38 |
+
return x
|
39 |
+
|
40 |
+
class Decoder_module(nn.Module):
|
41 |
+
def __init__(self, inplanes, planes, rate=1):
|
42 |
+
super(Decoder_module, self).__init__()
|
43 |
+
self.atrous_convolution = SeparableConv2d_aspp(inplanes, planes, 3, stride=1, dilation=rate,padding=1)
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
x = self.atrous_convolution(x)
|
47 |
+
return x
|
48 |
+
|
49 |
+
class ASPP_module(nn.Module):
|
50 |
+
def __init__(self, inplanes, planes, rate):
|
51 |
+
super(ASPP_module, self).__init__()
|
52 |
+
if rate == 1:
|
53 |
+
raise RuntimeError()
|
54 |
+
else:
|
55 |
+
kernel_size = 3
|
56 |
+
padding = rate
|
57 |
+
self.atrous_convolution = SeparableConv2d_aspp(inplanes, planes, 3, stride=1, dilation=rate,
|
58 |
+
padding=padding)
|
59 |
+
|
60 |
+
def forward(self, x):
|
61 |
+
x = self.atrous_convolution(x)
|
62 |
+
return x
|
63 |
+
|
64 |
+
|
65 |
+
class ASPP_module_rate0(nn.Module):
|
66 |
+
def __init__(self, inplanes, planes, rate=1):
|
67 |
+
super(ASPP_module_rate0, self).__init__()
|
68 |
+
if rate == 1:
|
69 |
+
kernel_size = 1
|
70 |
+
padding = 0
|
71 |
+
self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
|
72 |
+
stride=1, padding=padding, dilation=rate, bias=False)
|
73 |
+
self.bn = SynchronizedBatchNorm2d(planes, eps=1e-5, affine=True)
|
74 |
+
self.relu = nn.ReLU()
|
75 |
+
else:
|
76 |
+
raise RuntimeError()
|
77 |
+
|
78 |
+
def forward(self, x):
|
79 |
+
x = self.atrous_convolution(x)
|
80 |
+
x = self.bn(x)
|
81 |
+
return self.relu(x)
|
82 |
+
|
83 |
+
|
84 |
+
class SeparableConv2d_same(nn.Module):
|
85 |
+
def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, padding=0):
|
86 |
+
super(SeparableConv2d_same, self).__init__()
|
87 |
+
|
88 |
+
self.depthwise = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation,
|
89 |
+
groups=inplanes, bias=bias)
|
90 |
+
self.depthwise_bn = SynchronizedBatchNorm2d(inplanes)
|
91 |
+
self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
|
92 |
+
self.pointwise_bn = SynchronizedBatchNorm2d(planes)
|
93 |
+
|
94 |
+
def forward(self, x):
|
95 |
+
x = fixed_padding(x, self.depthwise.kernel_size[0], rate=self.depthwise.dilation[0])
|
96 |
+
x = self.depthwise(x)
|
97 |
+
x = self.depthwise_bn(x)
|
98 |
+
x = self.pointwise(x)
|
99 |
+
x = self.pointwise_bn(x)
|
100 |
+
return x
|
101 |
+
|
102 |
+
|
103 |
+
class Block(nn.Module):
|
104 |
+
def __init__(self, inplanes, planes, reps, stride=1, dilation=1, start_with_relu=True, grow_first=True, is_last=False):
|
105 |
+
super(Block, self).__init__()
|
106 |
+
|
107 |
+
if planes != inplanes or stride != 1:
|
108 |
+
self.skip = nn.Conv2d(inplanes, planes, 1, stride=2, bias=False)
|
109 |
+
if is_last:
|
110 |
+
self.skip = nn.Conv2d(inplanes, planes, 1, stride=1, bias=False)
|
111 |
+
self.skipbn = SynchronizedBatchNorm2d(planes)
|
112 |
+
else:
|
113 |
+
self.skip = None
|
114 |
+
|
115 |
+
self.relu = nn.ReLU(inplace=True)
|
116 |
+
rep = []
|
117 |
+
|
118 |
+
filters = inplanes
|
119 |
+
if grow_first:
|
120 |
+
rep.append(self.relu)
|
121 |
+
rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
|
122 |
+
# rep.append(nn.BatchNorm2d(planes))
|
123 |
+
filters = planes
|
124 |
+
|
125 |
+
for i in range(reps - 1):
|
126 |
+
rep.append(self.relu)
|
127 |
+
rep.append(SeparableConv2d_same(filters, filters, 3, stride=1, dilation=dilation))
|
128 |
+
# rep.append(nn.BatchNorm2d(filters))
|
129 |
+
|
130 |
+
if not grow_first:
|
131 |
+
rep.append(self.relu)
|
132 |
+
rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
|
133 |
+
# rep.append(nn.BatchNorm2d(planes))
|
134 |
+
|
135 |
+
if not start_with_relu:
|
136 |
+
rep = rep[1:]
|
137 |
+
|
138 |
+
if stride != 1:
|
139 |
+
rep.append(self.relu)
|
140 |
+
rep.append(SeparableConv2d_same(planes, planes, 3, stride=2,dilation=dilation))
|
141 |
+
|
142 |
+
if is_last:
|
143 |
+
rep.append(self.relu)
|
144 |
+
rep.append(SeparableConv2d_same(planes, planes, 3, stride=1,dilation=dilation))
|
145 |
+
|
146 |
+
|
147 |
+
self.rep = nn.Sequential(*rep)
|
148 |
+
|
149 |
+
def forward(self, inp):
|
150 |
+
x = self.rep(inp)
|
151 |
+
|
152 |
+
if self.skip is not None:
|
153 |
+
skip = self.skip(inp)
|
154 |
+
skip = self.skipbn(skip)
|
155 |
+
else:
|
156 |
+
skip = inp
|
157 |
+
# print(x.size(),skip.size())
|
158 |
+
x += skip
|
159 |
+
|
160 |
+
return x
|
161 |
+
|
162 |
+
class Block2(nn.Module):
|
163 |
+
def __init__(self, inplanes, planes, reps, stride=1, dilation=1, start_with_relu=True, grow_first=True, is_last=False):
|
164 |
+
super(Block2, self).__init__()
|
165 |
+
|
166 |
+
if planes != inplanes or stride != 1:
|
167 |
+
self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False)
|
168 |
+
self.skipbn = SynchronizedBatchNorm2d(planes)
|
169 |
+
else:
|
170 |
+
self.skip = None
|
171 |
+
|
172 |
+
self.relu = nn.ReLU(inplace=True)
|
173 |
+
rep = []
|
174 |
+
|
175 |
+
filters = inplanes
|
176 |
+
if grow_first:
|
177 |
+
rep.append(self.relu)
|
178 |
+
rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
|
179 |
+
# rep.append(nn.BatchNorm2d(planes))
|
180 |
+
filters = planes
|
181 |
+
|
182 |
+
for i in range(reps - 1):
|
183 |
+
rep.append(self.relu)
|
184 |
+
rep.append(SeparableConv2d_same(filters, filters, 3, stride=1, dilation=dilation))
|
185 |
+
# rep.append(nn.BatchNorm2d(filters))
|
186 |
+
|
187 |
+
if not grow_first:
|
188 |
+
rep.append(self.relu)
|
189 |
+
rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
|
190 |
+
# rep.append(nn.BatchNorm2d(planes))
|
191 |
+
|
192 |
+
if not start_with_relu:
|
193 |
+
rep = rep[1:]
|
194 |
+
|
195 |
+
if stride != 1:
|
196 |
+
self.block2_lastconv = nn.Sequential(*[self.relu,SeparableConv2d_same(planes, planes, 3, stride=2,dilation=dilation)])
|
197 |
+
|
198 |
+
if is_last:
|
199 |
+
rep.append(SeparableConv2d_same(planes, planes, 3, stride=1))
|
200 |
+
|
201 |
+
|
202 |
+
self.rep = nn.Sequential(*rep)
|
203 |
+
|
204 |
+
def forward(self, inp):
|
205 |
+
x = self.rep(inp)
|
206 |
+
low_middle = x.clone()
|
207 |
+
x1 = x
|
208 |
+
x1 = self.block2_lastconv(x1)
|
209 |
+
if self.skip is not None:
|
210 |
+
skip = self.skip(inp)
|
211 |
+
skip = self.skipbn(skip)
|
212 |
+
else:
|
213 |
+
skip = inp
|
214 |
+
|
215 |
+
x1 += skip
|
216 |
+
|
217 |
+
return x1,low_middle
|
218 |
+
|
219 |
+
class Xception(nn.Module):
|
220 |
+
"""
|
221 |
+
Modified Aligned Xception
|
222 |
+
"""
|
223 |
+
def __init__(self, inplanes=3, os=16, pretrained=False):
|
224 |
+
super(Xception, self).__init__()
|
225 |
+
|
226 |
+
if os == 16:
|
227 |
+
entry_block3_stride = 2
|
228 |
+
middle_block_rate = 1
|
229 |
+
exit_block_rates = (1, 2)
|
230 |
+
elif os == 8:
|
231 |
+
entry_block3_stride = 1
|
232 |
+
middle_block_rate = 2
|
233 |
+
exit_block_rates = (2, 4)
|
234 |
+
else:
|
235 |
+
raise NotImplementedError
|
236 |
+
|
237 |
+
|
238 |
+
# Entry flow
|
239 |
+
self.conv1 = nn.Conv2d(inplanes, 32, 3, stride=2, padding=1, bias=False)
|
240 |
+
self.bn1 = SynchronizedBatchNorm2d(32)
|
241 |
+
self.relu = nn.ReLU(inplace=True)
|
242 |
+
|
243 |
+
self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
|
244 |
+
self.bn2 = SynchronizedBatchNorm2d(64)
|
245 |
+
|
246 |
+
self.block1 = Block(64, 128, reps=2, stride=2, start_with_relu=False)
|
247 |
+
self.block2 = Block2(128, 256, reps=2, stride=2, start_with_relu=True, grow_first=True)
|
248 |
+
self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, start_with_relu=True, grow_first=True)
|
249 |
+
|
250 |
+
# Middle flow
|
251 |
+
self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
252 |
+
self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
253 |
+
self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
254 |
+
self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
255 |
+
self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
256 |
+
self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
257 |
+
self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
258 |
+
self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
259 |
+
self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
260 |
+
self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
261 |
+
self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
262 |
+
self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
263 |
+
self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
264 |
+
self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
265 |
+
self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
266 |
+
self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
|
267 |
+
|
268 |
+
# Exit flow
|
269 |
+
self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_rates[0],
|
270 |
+
start_with_relu=True, grow_first=False, is_last=True)
|
271 |
+
|
272 |
+
self.conv3 = SeparableConv2d_aspp(1024, 1536, 3, stride=1, dilation=exit_block_rates[1],padding=exit_block_rates[1])
|
273 |
+
# self.bn3 = nn.BatchNorm2d(1536)
|
274 |
+
|
275 |
+
self.conv4 = SeparableConv2d_aspp(1536, 1536, 3, stride=1, dilation=exit_block_rates[1],padding=exit_block_rates[1])
|
276 |
+
# self.bn4 = nn.BatchNorm2d(1536)
|
277 |
+
|
278 |
+
self.conv5 = SeparableConv2d_aspp(1536, 2048, 3, stride=1, dilation=exit_block_rates[1],padding=exit_block_rates[1])
|
279 |
+
# self.bn5 = nn.BatchNorm2d(2048)
|
280 |
+
|
281 |
+
# Init weights
|
282 |
+
# self.__init_weight()
|
283 |
+
|
284 |
+
# Load pretrained model
|
285 |
+
if pretrained:
|
286 |
+
self.__load_xception_pretrained()
|
287 |
+
|
288 |
+
def forward(self, x):
|
289 |
+
# Entry flow
|
290 |
+
x = self.conv1(x)
|
291 |
+
x = self.bn1(x)
|
292 |
+
x = self.relu(x)
|
293 |
+
# print('conv1 ',x.size())
|
294 |
+
x = self.conv2(x)
|
295 |
+
x = self.bn2(x)
|
296 |
+
x = self.relu(x)
|
297 |
+
|
298 |
+
x = self.block1(x)
|
299 |
+
# print('block1',x.size())
|
300 |
+
# low_level_feat = x
|
301 |
+
x,low_level_feat = self.block2(x)
|
302 |
+
# print('block2',x.size())
|
303 |
+
x = self.block3(x)
|
304 |
+
# print('xception block3 ',x.size())
|
305 |
+
|
306 |
+
# Middle flow
|
307 |
+
x = self.block4(x)
|
308 |
+
x = self.block5(x)
|
309 |
+
x = self.block6(x)
|
310 |
+
x = self.block7(x)
|
311 |
+
x = self.block8(x)
|
312 |
+
x = self.block9(x)
|
313 |
+
x = self.block10(x)
|
314 |
+
x = self.block11(x)
|
315 |
+
x = self.block12(x)
|
316 |
+
x = self.block13(x)
|
317 |
+
x = self.block14(x)
|
318 |
+
x = self.block15(x)
|
319 |
+
x = self.block16(x)
|
320 |
+
x = self.block17(x)
|
321 |
+
x = self.block18(x)
|
322 |
+
x = self.block19(x)
|
323 |
+
|
324 |
+
# Exit flow
|
325 |
+
x = self.block20(x)
|
326 |
+
x = self.conv3(x)
|
327 |
+
# x = self.bn3(x)
|
328 |
+
x = self.relu(x)
|
329 |
+
|
330 |
+
x = self.conv4(x)
|
331 |
+
# x = self.bn4(x)
|
332 |
+
x = self.relu(x)
|
333 |
+
|
334 |
+
x = self.conv5(x)
|
335 |
+
# x = self.bn5(x)
|
336 |
+
x = self.relu(x)
|
337 |
+
|
338 |
+
return x, low_level_feat
|
339 |
+
|
340 |
+
def __init_weight(self):
|
341 |
+
for m in self.modules():
|
342 |
+
if isinstance(m, nn.Conv2d):
|
343 |
+
# n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
344 |
+
# m.weight.data.normal_(0, math.sqrt(2. / n))
|
345 |
+
torch.nn.init.kaiming_normal_(m.weight)
|
346 |
+
elif isinstance(m, nn.BatchNorm2d):
|
347 |
+
m.weight.data.fill_(1)
|
348 |
+
m.bias.data.zero_()
|
349 |
+
|
350 |
+
def __load_xception_pretrained(self):
|
351 |
+
pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
|
352 |
+
model_dict = {}
|
353 |
+
state_dict = self.state_dict()
|
354 |
+
|
355 |
+
for k, v in pretrain_dict.items():
|
356 |
+
if k in state_dict:
|
357 |
+
if 'pointwise' in k:
|
358 |
+
v = v.unsqueeze(-1).unsqueeze(-1)
|
359 |
+
if k.startswith('block12'):
|
360 |
+
model_dict[k.replace('block12', 'block20')] = v
|
361 |
+
elif k.startswith('block11'):
|
362 |
+
model_dict[k.replace('block11', 'block12')] = v
|
363 |
+
model_dict[k.replace('block11', 'block13')] = v
|
364 |
+
model_dict[k.replace('block11', 'block14')] = v
|
365 |
+
model_dict[k.replace('block11', 'block15')] = v
|
366 |
+
model_dict[k.replace('block11', 'block16')] = v
|
367 |
+
model_dict[k.replace('block11', 'block17')] = v
|
368 |
+
model_dict[k.replace('block11', 'block18')] = v
|
369 |
+
model_dict[k.replace('block11', 'block19')] = v
|
370 |
+
elif k.startswith('conv3'):
|
371 |
+
model_dict[k] = v
|
372 |
+
elif k.startswith('bn3'):
|
373 |
+
model_dict[k] = v
|
374 |
+
model_dict[k.replace('bn3', 'bn4')] = v
|
375 |
+
elif k.startswith('conv4'):
|
376 |
+
model_dict[k.replace('conv4', 'conv5')] = v
|
377 |
+
elif k.startswith('bn4'):
|
378 |
+
model_dict[k.replace('bn4', 'bn5')] = v
|
379 |
+
else:
|
380 |
+
model_dict[k] = v
|
381 |
+
state_dict.update(model_dict)
|
382 |
+
self.load_state_dict(state_dict)
|
383 |
+
|
384 |
+
class DeepLabv3_plus(nn.Module):
|
385 |
+
def __init__(self, nInputChannels=3, n_classes=21, os=16, pretrained=False, _print=True):
|
386 |
+
if _print:
|
387 |
+
print("Constructing DeepLabv3+ model...")
|
388 |
+
print("Number of classes: {}".format(n_classes))
|
389 |
+
print("Output stride: {}".format(os))
|
390 |
+
print("Number of Input Channels: {}".format(nInputChannels))
|
391 |
+
super(DeepLabv3_plus, self).__init__()
|
392 |
+
|
393 |
+
# Atrous Conv
|
394 |
+
self.xception_features = Xception(nInputChannels, os, pretrained)
|
395 |
+
|
396 |
+
# ASPP
|
397 |
+
if os == 16:
|
398 |
+
rates = [1, 6, 12, 18]
|
399 |
+
elif os == 8:
|
400 |
+
rates = [1, 12, 24, 36]
|
401 |
+
else:
|
402 |
+
raise NotImplementedError
|
403 |
+
|
404 |
+
self.aspp1 = ASPP_module_rate0(2048, 256, rate=rates[0])
|
405 |
+
self.aspp2 = ASPP_module(2048, 256, rate=rates[1])
|
406 |
+
self.aspp3 = ASPP_module(2048, 256, rate=rates[2])
|
407 |
+
self.aspp4 = ASPP_module(2048, 256, rate=rates[3])
|
408 |
+
|
409 |
+
self.relu = nn.ReLU()
|
410 |
+
|
411 |
+
self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
|
412 |
+
nn.Conv2d(2048, 256, 1, stride=1, bias=False),
|
413 |
+
SynchronizedBatchNorm2d(256),
|
414 |
+
nn.ReLU()
|
415 |
+
)
|
416 |
+
|
417 |
+
self.concat_projection_conv1 = nn.Conv2d(1280, 256, 1, bias=False)
|
418 |
+
self.concat_projection_bn1 = SynchronizedBatchNorm2d(256)
|
419 |
+
|
420 |
+
# adopt [1x1, 48] for channel reduction.
|
421 |
+
self.feature_projection_conv1 = nn.Conv2d(256, 48, 1, bias=False)
|
422 |
+
self.feature_projection_bn1 = SynchronizedBatchNorm2d(48)
|
423 |
+
|
424 |
+
self.decoder = nn.Sequential(Decoder_module(304, 256),
|
425 |
+
Decoder_module(256, 256)
|
426 |
+
)
|
427 |
+
self.semantic = nn.Conv2d(256, n_classes, kernel_size=1, stride=1)
|
428 |
+
|
429 |
+
def forward(self, input):
|
430 |
+
x, low_level_features = self.xception_features(input)
|
431 |
+
# print(x.size())
|
432 |
+
x1 = self.aspp1(x)
|
433 |
+
x2 = self.aspp2(x)
|
434 |
+
x3 = self.aspp3(x)
|
435 |
+
x4 = self.aspp4(x)
|
436 |
+
x5 = self.global_avg_pool(x)
|
437 |
+
x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
|
438 |
+
|
439 |
+
x = torch.cat((x1, x2, x3, x4, x5), dim=1)
|
440 |
+
|
441 |
+
x = self.concat_projection_conv1(x)
|
442 |
+
x = self.concat_projection_bn1(x)
|
443 |
+
x = self.relu(x)
|
444 |
+
# print(x.size())
|
445 |
+
x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
|
446 |
+
|
447 |
+
low_level_features = self.feature_projection_conv1(low_level_features)
|
448 |
+
low_level_features = self.feature_projection_bn1(low_level_features)
|
449 |
+
low_level_features = self.relu(low_level_features)
|
450 |
+
# print(low_level_features.size())
|
451 |
+
# print(x.size())
|
452 |
+
x = torch.cat((x, low_level_features), dim=1)
|
453 |
+
x = self.decoder(x)
|
454 |
+
x = self.semantic(x)
|
455 |
+
x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)
|
456 |
+
|
457 |
+
return x
|
458 |
+
|
459 |
+
def freeze_bn(self):
|
460 |
+
for m in self.xception_features.modules():
|
461 |
+
if isinstance(m, nn.BatchNorm2d) or isinstance(m,SynchronizedBatchNorm2d):
|
462 |
+
m.eval()
|
463 |
+
|
464 |
+
def freeze_aspp_bn(self):
|
465 |
+
for m in self.aspp1.modules():
|
466 |
+
if isinstance(m, nn.BatchNorm2d):
|
467 |
+
m.eval()
|
468 |
+
for m in self.aspp2.modules():
|
469 |
+
if isinstance(m, nn.BatchNorm2d):
|
470 |
+
m.eval()
|
471 |
+
for m in self.aspp3.modules():
|
472 |
+
if isinstance(m, nn.BatchNorm2d):
|
473 |
+
m.eval()
|
474 |
+
for m in self.aspp4.modules():
|
475 |
+
if isinstance(m, nn.BatchNorm2d):
|
476 |
+
m.eval()
|
477 |
+
|
478 |
+
def learnable_parameters(self):
|
479 |
+
layer_features_BN = []
|
480 |
+
layer_features = []
|
481 |
+
layer_aspp = []
|
482 |
+
layer_projection =[]
|
483 |
+
layer_decoder = []
|
484 |
+
layer_other = []
|
485 |
+
model_para = list(self.named_parameters())
|
486 |
+
for name,para in model_para:
|
487 |
+
if 'xception' in name:
|
488 |
+
if 'bn' in name or 'downsample.1.weight' in name or 'downsample.1.bias' in name:
|
489 |
+
layer_features_BN.append(para)
|
490 |
+
else:
|
491 |
+
layer_features.append(para)
|
492 |
+
# print (name)
|
493 |
+
elif 'aspp' in name:
|
494 |
+
layer_aspp.append(para)
|
495 |
+
elif 'projection' in name:
|
496 |
+
layer_projection.append(para)
|
497 |
+
elif 'decode' in name:
|
498 |
+
layer_decoder.append(para)
|
499 |
+
else:
|
500 |
+
layer_other.append(para)
|
501 |
+
return layer_features_BN,layer_features,layer_aspp,layer_projection,layer_decoder,layer_other
|
502 |
+
|
503 |
+
|
504 |
+
def __init_weight(self):
|
505 |
+
for m in self.modules():
|
506 |
+
if isinstance(m, nn.Conv2d):
|
507 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
508 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
509 |
+
# torch.nn.init.kaiming_normal_(m.weight)
|
510 |
+
elif isinstance(m, nn.BatchNorm2d):
|
511 |
+
m.weight.data.fill_(1)
|
512 |
+
m.bias.data.zero_()
|
513 |
+
|
514 |
+
def load_state_dict_new(self, state_dict):
|
515 |
+
own_state = self.state_dict()
|
516 |
+
#for name inshop_cos own_state:
|
517 |
+
# print name
|
518 |
+
new_state_dict = OrderedDict()
|
519 |
+
for name, param in state_dict.items():
|
520 |
+
name = name.replace('module.','')
|
521 |
+
new_state_dict[name] = 0
|
522 |
+
if name not in own_state:
|
523 |
+
if 'num_batch' in name:
|
524 |
+
continue
|
525 |
+
print ('unexpected key "{}" in state_dict'
|
526 |
+
.format(name))
|
527 |
+
continue
|
528 |
+
# if isinstance(param, own_state):
|
529 |
+
if isinstance(param, Parameter):
|
530 |
+
# backwards compatibility for serialized parameters
|
531 |
+
param = param.data
|
532 |
+
try:
|
533 |
+
own_state[name].copy_(param)
|
534 |
+
except:
|
535 |
+
print('While copying the parameter named {}, whose dimensions in the model are'
|
536 |
+
' {} and whose dimensions in the checkpoint are {}, ...'.format(
|
537 |
+
name, own_state[name].size(), param.size()))
|
538 |
+
continue # i add inshop_cos 2018/02/01
|
539 |
+
# raise
|
540 |
+
# print 'copying %s' %name
|
541 |
+
# if isinstance(param, own_state):
|
542 |
+
# backwards compatibility for serialized parameters
|
543 |
+
own_state[name].copy_(param)
|
544 |
+
# print 'copying %s' %name
|
545 |
+
|
546 |
+
missing = set(own_state.keys()) - set(new_state_dict.keys())
|
547 |
+
if len(missing) > 0:
|
548 |
+
print('missing keys in state_dict: "{}"'.format(missing))
|
549 |
+
|
550 |
+
|
551 |
+
|
552 |
+
|
553 |
+
def get_1x_lr_params(model):
|
554 |
+
"""
|
555 |
+
This generator returns all the parameters of the net except for
|
556 |
+
the last classification layer. Note that for each batchnorm layer,
|
557 |
+
requires_grad is set to False in deeplab_resnet.py, therefore this function does not return
|
558 |
+
any batchnorm parameter
|
559 |
+
"""
|
560 |
+
b = [model.xception_features]
|
561 |
+
for i in range(len(b)):
|
562 |
+
for k in b[i].parameters():
|
563 |
+
if k.requires_grad:
|
564 |
+
yield k
|
565 |
+
|
566 |
+
|
567 |
+
def get_10x_lr_params(model):
|
568 |
+
"""
|
569 |
+
This generator returns all the parameters for the last layer of the net,
|
570 |
+
which does the classification of pixel into classes
|
571 |
+
"""
|
572 |
+
b = [model.aspp1, model.aspp2, model.aspp3, model.aspp4, model.conv1, model.conv2, model.last_conv]
|
573 |
+
for j in range(len(b)):
|
574 |
+
for k in b[j].parameters():
|
575 |
+
if k.requires_grad:
|
576 |
+
yield k
|
577 |
+
|
578 |
+
|
579 |
+
if __name__ == "__main__":
|
580 |
+
model = DeepLabv3_plus(nInputChannels=3, n_classes=21, os=16, pretrained=False, _print=True)
|
581 |
+
model.eval()
|
582 |
+
# ckt = torch.load('C:\\Users\gaoyi\code_python\deeplab_v3plus.pth')
|
583 |
+
# model.load_state_dict_new(ckt)
|
584 |
+
|
585 |
+
|
586 |
+
image = torch.randn(1, 3, 512, 512)*255
|
587 |
+
with torch.no_grad():
|
588 |
+
output = model.forward(image)
|
589 |
+
print(output.size())
|
590 |
+
# print(output)
|
591 |
+
|
592 |
+
|
593 |
+
|
594 |
+
|
595 |
+
|
596 |
+
|
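The _synBN variant above exists so that BatchNorm statistics are shared across GPUs during training. A rough sketch, under the assumption that the repo's sync_batchnorm package is importable, of how it might be wrapped with DataParallelWithCallback for one training step; device ids, optimizer settings, and the dummy batch are placeholders.

import torch
from networks.deeplab_xception_synBN import DeepLabv3_plus
from sync_batchnorm import DataParallelWithCallback

net = DeepLabv3_plus(nInputChannels=3, n_classes=20, os=16, pretrained=False, _print=False)
# DataParallelWithCallback replicates the model and lets the SynchronizedBatchNorm2d
# layers exchange batch statistics across the replicas.
net = DataParallelWithCallback(net.cuda(), device_ids=[0, 1])
net.train()

optimizer = torch.optim.SGD(net.parameters(), lr=0.007, momentum=0.9, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

images = torch.randn(4, 3, 512, 512).cuda()           # stand-in batch
labels = torch.randint(0, 20, (4, 512, 512)).cuda()   # stand-in parsing labels

optimizer.zero_grad()
loss = criterion(net(images), labels)
loss.backward()
optimizer.step()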
TryYours-Virtual-Try-On/Graphonomy-master/networks/deeplab_xception_transfer.py
ADDED
@@ -0,0 +1,1003 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.nn.parameter import Parameter
import numpy as np
from collections import OrderedDict
from torch.nn import Parameter
from networks import deeplab_xception,gcn, deeplab_xception_synBN
import pdb

#######################
# base model
#######################

class deeplab_xception_transfer_basemodel(deeplab_xception.DeepLabv3_plus):
    def __init__(self,nInputChannels=3, n_classes=7, os=16,input_channels=256,hidden_layers=128,out_channels=256):
        super(deeplab_xception_transfer_basemodel, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes,
                                                                  os=os,)
        ### source graph
        # self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
        #                                                            nodes=n_classes)
        # self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        # self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        # self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        #
        # self.source_graph_2_fea = gcn.Graph_to_Featuremaps(input_channels=input_channels, output_channels=out_channels,
        #                                                    hidden_layers=hidden_layers, nodes=n_classes
        #                                                    )
        # self.source_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
        #                                         nn.ReLU(True)])

        ### target graph
        self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
                                                                   nodes=n_classes)
        self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)

        self.target_graph_2_fea = gcn.Graph_to_Featuremaps(input_channels=input_channels, output_channels=out_channels,
                                                           hidden_layers=hidden_layers, nodes=n_classes
                                                           )
        self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
                                                nn.ReLU(True)])

    def load_source_model(self,state_dict):
        own_state = self.state_dict()
        # for name inshop_cos own_state:
        #     print name
        new_state_dict = OrderedDict()
        for name, param in state_dict.items():
            name = name.replace('module.', '')
            if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_graph' not in name and 'transpose_graph' not in name:
                if 'featuremap_2_graph' in name:
                    name = name.replace('featuremap_2_graph','source_featuremap_2_graph')
                else:
                    name = name.replace('graph','source_graph')
            new_state_dict[name] = 0
            if name not in own_state:
                if 'num_batch' in name:
                    continue
                print('unexpected key "{}" in state_dict'
                      .format(name))
                continue
            # if isinstance(param, own_state):
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except:
                print('While copying the parameter named {}, whose dimensions in the model are'
                      ' {} and whose dimensions in the checkpoint are {}, ...'.format(
                    name, own_state[name].size(), param.size()))
                continue # i add inshop_cos 2018/02/01
            own_state[name].copy_(param)
            # print 'copying %s' %name

        missing = set(own_state.keys()) - set(new_state_dict.keys())
        if len(missing) > 0:
            print('missing keys in state_dict: "{}"'.format(missing))

    def get_target_parameter(self):
        l = []
        other = []
        for name, k in self.named_parameters():
            if 'target' in name or 'semantic' in name:
                l.append(k)
            else:
                other.append(k)
        return l, other

    def get_semantic_parameter(self):
        l = []
        for name, k in self.named_parameters():
            if 'semantic' in name:
                l.append(k)
        return l

    def get_source_parameter(self):
        l = []
        for name, k in self.named_parameters():
            if 'source' in name:
                l.append(k)
        return l

    def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ):
        x, low_level_features = self.xception_features(input)
        # print(x.size())
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)
        # print(x.size())
        x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)

        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)
        # print(low_level_features.size())
        # print(x.size())
        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        ### add graph


        # target graph
        # print('x size',x.size(),adj1.size())
        graph = self.target_featuremap_2_graph(x)

        # graph combine
        # print(graph.size(),source_2_target_graph.size())
        # graph = self.fc_graph.forward(graph,relu=True)
        # print(graph.size())

        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
        # print(graph.size(),x.size())
        # graph = self.gcn_encode.forward(graph,relu=True)
        # graph = self.graph_conv2.forward(graph,adj=adj2,relu=True)
        # graph = self.gcn_decode.forward(graph,relu=True)
        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        ###
        x = self.semantic(x)
        x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)

        return x

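The class above runs the standard DeepLabv3+ pipeline (Xception backbone, ASPP, decoder) and then refines the decoded features with three graph convolutions over the n_classes label nodes before the semantic head. The following is a minimal smoke-test sketch of that forward pass; it assumes the module is importable from the `networks` package as `deeplab_xception_transfer`, and it borrows its tensor shapes from the commented-out example at the end of this file, so the dense (1, 1, 7, 7) adjacency is an assumption rather than a documented interface.

# Minimal sketch (assumed import path and adjacency shape, as noted above).
import torch
from networks.deeplab_xception_transfer import deeplab_xception_transfer_basemodel

net = deeplab_xception_transfer_basemodel(n_classes=7)
net.eval()
img = torch.rand(2, 3, 128, 128)      # random RGB batch
adj_target = torch.rand(1, 1, 7, 7)   # assumed dense adjacency over the 7 target labels
with torch.no_grad():
    out = net(img, adj1_target=adj_target)
# The semantic head upsamples back to the input resolution, so `out` should be
# (batch, n_classes, H, W) if the assumed shapes are accepted by the gcn module.
print(out.shape)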
class deeplab_xception_transfer_basemodel_savememory(deeplab_xception.DeepLabv3_plus):
    def __init__(self,nInputChannels=3, n_classes=7, os=16,input_channels=256,hidden_layers=128,out_channels=256):
        super(deeplab_xception_transfer_basemodel_savememory, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes,
                                                                             os=os,)
        ### source graph

        ### target graph
        self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
                                                                   nodes=n_classes)
        self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)

        self.target_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels, output_channels=out_channels,
                                                                    hidden_layers=hidden_layers, nodes=n_classes
                                                                    )
        self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
                                                nn.ReLU(True)])

    def load_source_model(self,state_dict):
        own_state = self.state_dict()
        # for name inshop_cos own_state:
        #     print name
        new_state_dict = OrderedDict()
        for name, param in state_dict.items():
            name = name.replace('module.', '')
            if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_graph' not in name and 'transpose_graph' not in name:
                if 'featuremap_2_graph' in name:
                    name = name.replace('featuremap_2_graph','source_featuremap_2_graph')
                else:
                    name = name.replace('graph','source_graph')
            new_state_dict[name] = 0
            if name not in own_state:
                if 'num_batch' in name:
                    continue
                print('unexpected key "{}" in state_dict'
                      .format(name))
                continue
            # if isinstance(param, own_state):
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except:
                print('While copying the parameter named {}, whose dimensions in the model are'
                      ' {} and whose dimensions in the checkpoint are {}, ...'.format(
                    name, own_state[name].size(), param.size()))
                continue # i add inshop_cos 2018/02/01
            own_state[name].copy_(param)
            # print 'copying %s' %name

        missing = set(own_state.keys()) - set(new_state_dict.keys())
        if len(missing) > 0:
            print('missing keys in state_dict: "{}"'.format(missing))

    def get_target_parameter(self):
        l = []
        other = []
        for name, k in self.named_parameters():
            if 'target' in name or 'semantic' in name:
                l.append(k)
            else:
                other.append(k)
        return l, other

    def get_semantic_parameter(self):
        l = []
        for name, k in self.named_parameters():
            if 'semantic' in name:
                l.append(k)
        return l

    def get_source_parameter(self):
        l = []
        for name, k in self.named_parameters():
            if 'source' in name:
                l.append(k)
        return l

    def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ):
        x, low_level_features = self.xception_features(input)
        # print(x.size())
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)
        # print(x.size())
        x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)

        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)
        # print(low_level_features.size())
        # print(x.size())
        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        ### add graph


        # target graph
        # print('x size',x.size(),adj1.size())
        graph = self.target_featuremap_2_graph(x)

        # graph combine
        # print(graph.size(),source_2_target_graph.size())
        # graph = self.fc_graph.forward(graph,relu=True)
        # print(graph.size())

        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
        # print(graph.size(),x.size())
        # graph = self.gcn_encode.forward(graph,relu=True)
        # graph = self.graph_conv2.forward(graph,adj=adj2,relu=True)
        # graph = self.gcn_decode.forward(graph,relu=True)
        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        ###
        x = self.semantic(x)
        x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)

        return x

class deeplab_xception_transfer_basemodel_synBN(deeplab_xception_synBN.DeepLabv3_plus):
    def __init__(self,nInputChannels=3, n_classes=7, os=16,input_channels=256,hidden_layers=128,out_channels=256):
        super(deeplab_xception_transfer_basemodel_synBN, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes,
                                                                        os=os,)
        ### source graph
        # self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
        #                                                            nodes=n_classes)
        # self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        # self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        # self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        #
        # self.source_graph_2_fea = gcn.Graph_to_Featuremaps(input_channels=input_channels, output_channels=out_channels,
        #                                                    hidden_layers=hidden_layers, nodes=n_classes
        #                                                    )
        # self.source_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
        #                                         nn.ReLU(True)])

        ### target graph
        self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
                                                                   nodes=n_classes)
        self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)

        self.target_graph_2_fea = gcn.Graph_to_Featuremaps(input_channels=input_channels, output_channels=out_channels,
                                                           hidden_layers=hidden_layers, nodes=n_classes
                                                           )
        self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
                                                nn.ReLU(True)])

    def load_source_model(self,state_dict):
        own_state = self.state_dict()
        # for name inshop_cos own_state:
        #     print name
        new_state_dict = OrderedDict()
        for name, param in state_dict.items():
            name = name.replace('module.', '')

            if 'graph' in name and 'source' not in name and 'target' not in name:
                if 'featuremap_2_graph' in name:
                    name = name.replace('featuremap_2_graph','source_featuremap_2_graph')
                else:
                    name = name.replace('graph','source_graph')
            new_state_dict[name] = 0
            if name not in own_state:
                if 'num_batch' in name:
                    continue
                print('unexpected key "{}" in state_dict'
                      .format(name))
                continue
            # if isinstance(param, own_state):
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except:
                print('While copying the parameter named {}, whose dimensions in the model are'
                      ' {} and whose dimensions in the checkpoint are {}, ...'.format(
                    name, own_state[name].size(), param.size()))
                continue # i add inshop_cos 2018/02/01
            own_state[name].copy_(param)
            # print 'copying %s' %name

        missing = set(own_state.keys()) - set(new_state_dict.keys())
        if len(missing) > 0:
            print('missing keys in state_dict: "{}"'.format(missing))

    def get_target_parameter(self):
        l = []
        other = []
        for name, k in self.named_parameters():
            if 'target' in name or 'semantic' in name:
                l.append(k)
            else:
                other.append(k)
        return l, other

    def get_semantic_parameter(self):
        l = []
        for name, k in self.named_parameters():
            if 'semantic' in name:
                l.append(k)
        return l

    def get_source_parameter(self):
        l = []
        for name, k in self.named_parameters():
            if 'source' in name:
                l.append(k)
        return l

    def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ):
        x, low_level_features = self.xception_features(input)
        # print(x.size())
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)
        # print(x.size())
        x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)

        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)
        # print(low_level_features.size())
        # print(x.size())
        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        ### add graph


        # target graph
        # print('x size',x.size(),adj1.size())
        graph = self.target_featuremap_2_graph(x)

        # graph combine
        # print(graph.size(),source_2_target_graph.size())
        # graph = self.fc_graph.forward(graph,relu=True)
        # print(graph.size())

        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
        # print(graph.size(),x.size())
        # graph = self.gcn_encode.forward(graph,relu=True)
        # graph = self.graph_conv2.forward(graph,adj=adj2,relu=True)
        # graph = self.gcn_decode.forward(graph,relu=True)
        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        ###
        x = self.semantic(x)
        x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)

        return x

class deeplab_xception_transfer_basemodel_synBN_savememory(deeplab_xception_synBN.DeepLabv3_plus):
    def __init__(self,nInputChannels=3, n_classes=7, os=16,input_channels=256,hidden_layers=128,out_channels=256):
        super(deeplab_xception_transfer_basemodel_synBN_savememory, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes,
                                                                                   os=os, )
        ### source graph
        # self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
        #                                                            nodes=n_classes)
        # self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        # self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        # self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        #
        # self.source_graph_2_fea = gcn.Graph_to_Featuremaps(input_channels=input_channels, output_channels=out_channels,
        #                                                    hidden_layers=hidden_layers, nodes=n_classes
        #                                                    )
        # self.source_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
        #                                         nn.ReLU(True)])

        ### target graph
        self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
                                                                   nodes=n_classes)
        self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)

        self.target_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels, output_channels=out_channels,
                                                                    hidden_layers=hidden_layers, nodes=n_classes
                                                                    )
        self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
                                                nn.BatchNorm2d(input_channels),
                                                nn.ReLU(True)])

    def load_source_model(self,state_dict):
        own_state = self.state_dict()
        # for name inshop_cos own_state:
        #     print name
        new_state_dict = OrderedDict()
        for name, param in state_dict.items():
            name = name.replace('module.', '')

            if 'graph' in name and 'source' not in name and 'target' not in name:
                if 'featuremap_2_graph' in name:
                    name = name.replace('featuremap_2_graph','source_featuremap_2_graph')
                else:
                    name = name.replace('graph','source_graph')
            new_state_dict[name] = 0
            if name not in own_state:
                if 'num_batch' in name:
                    continue
                print('unexpected key "{}" in state_dict'
                      .format(name))
                continue
            # if isinstance(param, own_state):
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except:
                print('While copying the parameter named {}, whose dimensions in the model are'
                      ' {} and whose dimensions in the checkpoint are {}, ...'.format(
                    name, own_state[name].size(), param.size()))
                continue # i add inshop_cos 2018/02/01
            own_state[name].copy_(param)
            # print 'copying %s' %name

        missing = set(own_state.keys()) - set(new_state_dict.keys())
        if len(missing) > 0:
            print('missing keys in state_dict: "{}"'.format(missing))

    def get_target_parameter(self):
        l = []
        other = []
        for name, k in self.named_parameters():
            if 'target' in name or 'semantic' in name:
                l.append(k)
            else:
                other.append(k)
        return l, other

    def get_semantic_parameter(self):
        l = []
        for name, k in self.named_parameters():
            if 'semantic' in name:
                l.append(k)
        return l

    def get_source_parameter(self):
        l = []
        for name, k in self.named_parameters():
            if 'source' in name:
                l.append(k)
        return l

    def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ):
        x, low_level_features = self.xception_features(input)
        # print(x.size())
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)
        # print(x.size())
        x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)

        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)
        # print(low_level_features.size())
        # print(x.size())
        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        ### add graph


        # target graph
        # print('x size',x.size(),adj1.size())
        graph = self.target_featuremap_2_graph(x)

        # graph combine
        # print(graph.size(),source_2_target_graph.size())
        # graph = self.fc_graph.forward(graph,relu=True)
        # print(graph.size())

        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
        # print(graph.size(),x.size())
        # graph = self.gcn_encode.forward(graph,relu=True)
        # graph = self.graph_conv2.forward(graph,adj=adj2,relu=True)
        # graph = self.gcn_decode.forward(graph,relu=True)
        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        ###
        x = self.semantic(x)
        x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)

        return x

#######################
# transfer model
#######################

class deeplab_xception_transfer_projection(deeplab_xception_transfer_basemodel):
    def __init__(self, nInputChannels=3, n_classes=7, os=16,input_channels=256,hidden_layers=128,out_channels=256,
                 transfer_graph=None, source_classes=20):
        super(deeplab_xception_transfer_projection, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes,
                                                                   os=os, input_channels=input_channels,
                                                                   hidden_layers=hidden_layers, out_channels=out_channels, )
        self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
                                                                   nodes=source_classes)
        self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.transpose_graph = gcn.Graph_trans(in_features=hidden_layers,out_features=hidden_layers,adj=transfer_graph,
                                               begin_nodes=source_classes,end_nodes=n_classes)
        self.fc_graph = gcn.GraphConvolution(hidden_layers*3, hidden_layers)

    def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ):
        x, low_level_features = self.xception_features(input)
        # print(x.size())
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)
        # print(x.size())
        x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)

        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)
        # print(low_level_features.size())
        # print(x.size())
        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        ### add graph
        # source graph
        source_graph = self.source_featuremap_2_graph(x)
        source_graph1 = self.source_graph_conv1.forward(source_graph,adj=adj2_source, relu=True)
        source_graph2 = self.source_graph_conv2.forward(source_graph1, adj=adj2_source, relu=True)
        source_graph3 = self.source_graph_conv2.forward(source_graph2, adj=adj2_source, relu=True)

        source_2_target_graph1_v5 = self.transpose_graph.forward(source_graph1, adj=adj3_transfer, relu=True)
        source_2_target_graph2_v5 = self.transpose_graph.forward(source_graph2, adj=adj3_transfer, relu=True)
        source_2_target_graph3_v5 = self.transpose_graph.forward(source_graph3, adj=adj3_transfer, relu=True)

        # target graph
        # print('x size',x.size(),adj1.size())
        graph = self.target_featuremap_2_graph(x)

        source_2_target_graph1 = self.similarity_trans(source_graph1, graph)
        # graph combine 1
        # print(graph.size())
        # print(source_2_target_graph1.size())
        # print(source_2_target_graph1_v5.size())
        graph = torch.cat((graph,source_2_target_graph1.squeeze(0), source_2_target_graph1_v5.squeeze(0)),dim=-1)
        graph = self.fc_graph.forward(graph,relu=True)

        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)

        source_2_target_graph2 = self.similarity_trans(source_graph2, graph)
        # graph combine 2
        graph = torch.cat((graph, source_2_target_graph2, source_2_target_graph2_v5), dim=-1)
        graph = self.fc_graph.forward(graph, relu=True)

        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)

        source_2_target_graph3 = self.similarity_trans(source_graph3, graph)
        # graph combine 3
        graph = torch.cat((graph, source_2_target_graph3, source_2_target_graph3_v5), dim=-1)
        graph = self.fc_graph.forward(graph, relu=True)

        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)

        # print(graph.size(),x.size())

        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        ###
        x = self.semantic(x)
        x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)

        return x

    def similarity_trans(self,source,target):
        sim = torch.matmul(F.normalize(target, p=2, dim=-1), F.normalize(source, p=2, dim=-1).transpose(-1, -2))
        sim = F.softmax(sim, dim=-1)
        return torch.matmul(sim, source)

    def load_source_model(self,state_dict):
        own_state = self.state_dict()
        # for name inshop_cos own_state:
        #     print name
        new_state_dict = OrderedDict()
        for name, param in state_dict.items():
            name = name.replace('module.', '')

            if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_' not in name and 'transpose_graph' not in name:
                if 'featuremap_2_graph' in name:
                    name = name.replace('featuremap_2_graph','source_featuremap_2_graph')
                else:
                    name = name.replace('graph','source_graph')
            new_state_dict[name] = 0
            if name not in own_state:
                if 'num_batch' in name:
                    continue
                print('unexpected key "{}" in state_dict'
                      .format(name))
                continue
            # if isinstance(param, own_state):
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except:
                print('While copying the parameter named {}, whose dimensions in the model are'
                      ' {} and whose dimensions in the checkpoint are {}, ...'.format(
                    name, own_state[name].size(), param.size()))
                continue # i add inshop_cos 2018/02/01
            own_state[name].copy_(param)
            # print 'copying %s' %name

        missing = set(own_state.keys()) - set(new_state_dict.keys())
        if len(missing) > 0:
            print('missing keys in state_dict: "{}"'.format(missing))

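The `similarity_trans` helper in the class above projects source-graph node features onto the target graph with cosine-similarity attention: each target node takes a softmax-weighted average of the source nodes most similar to it. Below is a small self-contained illustration of the same computation on random tensors; the node counts (20 source, 7 target) mirror the defaults `source_classes=20` and `n_classes=7`, the hidden size mirrors `hidden_layers=128`, and everything else is illustrative.

# Standalone illustration of the cosine-similarity projection used by similarity_trans().
import torch
import torch.nn.functional as F

hidden = 128
source = torch.rand(1, 20, hidden)    # source-graph node features (e.g. 20 source labels)
target = torch.rand(1, 7, hidden)     # target-graph node features (e.g. 7 target labels)

# cosine similarity between every target node and every source node: (1, 7, 20)
sim = torch.matmul(F.normalize(target, p=2, dim=-1),
                   F.normalize(source, p=2, dim=-1).transpose(-1, -2))
sim = F.softmax(sim, dim=-1)           # each target node attends over all source nodes
projected = torch.matmul(sim, source)  # (1, 7, hidden): source features re-expressed per target node
print(projected.shape)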
class deeplab_xception_transfer_projection_savemem(deeplab_xception_transfer_basemodel_savememory):
    def __init__(self, nInputChannels=3, n_classes=7, os=16,input_channels=256,hidden_layers=128,out_channels=256,
                 transfer_graph=None, source_classes=20):
        super(deeplab_xception_transfer_projection_savemem, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes,
                                                                           os=os, input_channels=input_channels,
                                                                           hidden_layers=hidden_layers, out_channels=out_channels, )
        self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
                                                                   nodes=source_classes)
        self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.transpose_graph = gcn.Graph_trans(in_features=hidden_layers,out_features=hidden_layers,adj=transfer_graph,
                                               begin_nodes=source_classes,end_nodes=n_classes)
        self.fc_graph = gcn.GraphConvolution(hidden_layers*3, hidden_layers)

    def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ):
        x, low_level_features = self.xception_features(input)
        # print(x.size())
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)
        # print(x.size())
        x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)

        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)
        # print(low_level_features.size())
        # print(x.size())
        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        ### add graph
        # source graph
        source_graph = self.source_featuremap_2_graph(x)
        source_graph1 = self.source_graph_conv1.forward(source_graph,adj=adj2_source, relu=True)
        source_graph2 = self.source_graph_conv2.forward(source_graph1, adj=adj2_source, relu=True)
        source_graph3 = self.source_graph_conv2.forward(source_graph2, adj=adj2_source, relu=True)

        source_2_target_graph1_v5 = self.transpose_graph.forward(source_graph1, adj=adj3_transfer, relu=True)
        source_2_target_graph2_v5 = self.transpose_graph.forward(source_graph2, adj=adj3_transfer, relu=True)
        source_2_target_graph3_v5 = self.transpose_graph.forward(source_graph3, adj=adj3_transfer, relu=True)

        # target graph
        # print('x size',x.size(),adj1.size())
        graph = self.target_featuremap_2_graph(x)

        source_2_target_graph1 = self.similarity_trans(source_graph1, graph)
        # graph combine 1
        graph = torch.cat((graph,source_2_target_graph1.squeeze(0), source_2_target_graph1_v5.squeeze(0)),dim=-1)
        graph = self.fc_graph.forward(graph,relu=True)

        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)

        source_2_target_graph2 = self.similarity_trans(source_graph2, graph)
        # graph combine 2
        graph = torch.cat((graph, source_2_target_graph2, source_2_target_graph2_v5), dim=-1)
        graph = self.fc_graph.forward(graph, relu=True)

        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)

        source_2_target_graph3 = self.similarity_trans(source_graph3, graph)
        # graph combine 3
        graph = torch.cat((graph, source_2_target_graph3, source_2_target_graph3_v5), dim=-1)
        graph = self.fc_graph.forward(graph, relu=True)

        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)

        # print(graph.size(),x.size())

        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        ###
        x = self.semantic(x)
        x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)

        return x

    def similarity_trans(self,source,target):
        sim = torch.matmul(F.normalize(target, p=2, dim=-1), F.normalize(source, p=2, dim=-1).transpose(-1, -2))
        sim = F.softmax(sim, dim=-1)
        return torch.matmul(sim, source)

    def load_source_model(self,state_dict):
        own_state = self.state_dict()
        # for name inshop_cos own_state:
        #     print name
        new_state_dict = OrderedDict()
        for name, param in state_dict.items():
            name = name.replace('module.', '')

            if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_' not in name and 'transpose_graph' not in name:
                if 'featuremap_2_graph' in name:
                    name = name.replace('featuremap_2_graph','source_featuremap_2_graph')
                else:
                    name = name.replace('graph','source_graph')
            new_state_dict[name] = 0
            if name not in own_state:
                if 'num_batch' in name:
                    continue
                print('unexpected key "{}" in state_dict'
                      .format(name))
                continue
            # if isinstance(param, own_state):
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except:
                print('While copying the parameter named {}, whose dimensions in the model are'
                      ' {} and whose dimensions in the checkpoint are {}, ...'.format(
                    name, own_state[name].size(), param.size()))
                continue # i add inshop_cos 2018/02/01
            own_state[name].copy_(param)
            # print 'copying %s' %name

        missing = set(own_state.keys()) - set(new_state_dict.keys())
        if len(missing) > 0:
            print('missing keys in state_dict: "{}"'.format(missing))


class deeplab_xception_transfer_projection_synBN_savemem(deeplab_xception_transfer_basemodel_synBN_savememory):
    def __init__(self, nInputChannels=3, n_classes=7, os=16,input_channels=256,hidden_layers=128,out_channels=256,
                 transfer_graph=None, source_classes=20):
        super(deeplab_xception_transfer_projection_synBN_savemem, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes,
                                                                                 os=os, input_channels=input_channels,
                                                                                 hidden_layers=hidden_layers, out_channels=out_channels, )
        self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
                                                                   nodes=source_classes)
        self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.transpose_graph = gcn.Graph_trans(in_features=hidden_layers,out_features=hidden_layers,adj=transfer_graph,
                                               begin_nodes=source_classes,end_nodes=n_classes)
        self.fc_graph = gcn.GraphConvolution(hidden_layers*3 ,hidden_layers)

    def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ):
        x, low_level_features = self.xception_features(input)
        # print(x.size())
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)
        # print(x.size())
        x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)

        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)
        # print(low_level_features.size())
        # print(x.size())
        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        ### add graph
        # source graph
        source_graph = self.source_featuremap_2_graph(x)
        source_graph1 = self.source_graph_conv1.forward(source_graph,adj=adj2_source, relu=True)
        source_graph2 = self.source_graph_conv2.forward(source_graph1, adj=adj2_source, relu=True)
        source_graph3 = self.source_graph_conv2.forward(source_graph2, adj=adj2_source, relu=True)

        source_2_target_graph1_v5 = self.transpose_graph.forward(source_graph1, adj=adj3_transfer, relu=True)
        source_2_target_graph2_v5 = self.transpose_graph.forward(source_graph2, adj=adj3_transfer, relu=True)
        source_2_target_graph3_v5 = self.transpose_graph.forward(source_graph3, adj=adj3_transfer, relu=True)

        # target graph
        # print('x size',x.size(),adj1.size())
        graph = self.target_featuremap_2_graph(x)

        source_2_target_graph1 = self.similarity_trans(source_graph1, graph)
        # graph combine 1
        graph = torch.cat((graph,source_2_target_graph1.squeeze(0), source_2_target_graph1_v5.squeeze(0)),dim=-1)
        graph = self.fc_graph.forward(graph,relu=True)

        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)

        source_2_target_graph2 = self.similarity_trans(source_graph2, graph)
        # graph combine 2
        graph = torch.cat((graph, source_2_target_graph2, source_2_target_graph2_v5), dim=-1)
        graph = self.fc_graph.forward(graph, relu=True)

        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)

        source_2_target_graph3 = self.similarity_trans(source_graph3, graph)
        # graph combine 3
        graph = torch.cat((graph, source_2_target_graph3, source_2_target_graph3_v5), dim=-1)
        graph = self.fc_graph.forward(graph, relu=True)

        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)

        # print(graph.size(),x.size())

        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        ###
        x = self.semantic(x)
        x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)

        return x

    def similarity_trans(self,source,target):
        sim = torch.matmul(F.normalize(target, p=2, dim=-1), F.normalize(source, p=2, dim=-1).transpose(-1, -2))
        sim = F.softmax(sim, dim=-1)
        return torch.matmul(sim, source)

    def load_source_model(self,state_dict):
        own_state = self.state_dict()
        # for name inshop_cos own_state:
        #     print name
        new_state_dict = OrderedDict()
        for name, param in state_dict.items():
            name = name.replace('module.', '')

            if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_' not in name and 'transpose_graph' not in name:
                if 'featuremap_2_graph' in name:
                    name = name.replace('featuremap_2_graph','source_featuremap_2_graph')
                else:
                    name = name.replace('graph','source_graph')
            new_state_dict[name] = 0
            if name not in own_state:
                if 'num_batch' in name:
                    continue
                print('unexpected key "{}" in state_dict'
                      .format(name))
                continue
            # if isinstance(param, own_state):
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except:
                print('While copying the parameter named {}, whose dimensions in the model are'
                      ' {} and whose dimensions in the checkpoint are {}, ...'.format(
                    name, own_state[name].size(), param.size()))
                continue # i add inshop_cos 2018/02/01
            own_state[name].copy_(param)
            # print 'copying %s' %name

        missing = set(own_state.keys()) - set(new_state_dict.keys())
        if len(missing) > 0:
            print('missing keys in state_dict: "{}"'.format(missing))


# if __name__ == '__main__':
#     net = deeplab_xception_transfer_projection_v3v5_more_savemem()
#     img = torch.rand((2,3,128,128))
#     net.eval()
#     a = torch.rand((1,1,7,7))
#     net.forward(img, adj1_target=a)
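The commented-out `__main__` block above refers to `deeplab_xception_transfer_projection_v3v5_more_savemem`, which is not defined in this file. A hedged, runnable variant using `deeplab_xception_transfer_projection_savemem` (defined above) is sketched below; the import path and all adjacency shapes are assumptions extrapolated from that commented example and from the default class counts, not a documented interface.

# Hedged smoke test (assumed import path and tensor shapes).
import torch
from networks.deeplab_xception_transfer import deeplab_xception_transfer_projection_savemem

net = deeplab_xception_transfer_projection_savemem(n_classes=7, source_classes=20)
net.eval()
img = torch.rand(2, 3, 128, 128)
adj_target = torch.rand(1, 1, 7, 7)     # target-label adjacency (assumed dense)
adj_source = torch.rand(1, 1, 20, 20)   # source-label adjacency (assumed dense)
adj_transfer = torch.rand(1, 1, 20, 7)  # source-to-target transfer adjacency (assumed dense)
with torch.no_grad():
    out = net(img, adj1_target=adj_target, adj2_source=adj_source, adj3_transfer=adj_transfer)
print(out.shape)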
TryYours-Virtual-Try-On/Graphonomy-master/networks/deeplab_xception_universal.py
ADDED
@@ -0,0 +1,1077 @@
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from collections import OrderedDict
|
7 |
+
from torch.nn import Parameter
|
8 |
+
from networks import deeplab_xception, gcn, deeplab_xception_synBN
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
class deeplab_xception_transfer_basemodel_savememory(deeplab_xception.DeepLabv3_plus):
|
13 |
+
def __init__(self, nInputChannels=3, n_classes=7, os=16, input_channels=256, hidden_layers=128, out_channels=256,
|
14 |
+
source_classes=20, transfer_graph=None):
|
15 |
+
super(deeplab_xception_transfer_basemodel_savememory, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes,
|
16 |
+
os=os,)
|
17 |
+
|
18 |
+
def load_source_model(self,state_dict):
|
19 |
+
own_state = self.state_dict()
|
20 |
+
# for name inshop_cos own_state:
|
21 |
+
# print name
|
22 |
+
new_state_dict = OrderedDict()
|
23 |
+
for name, param in state_dict.items():
|
24 |
+
name = name.replace('module.', '')
|
25 |
+
if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_graph' not in name \
|
26 |
+
and 'transpose_graph' not in name and 'middle' not in name:
|
27 |
+
if 'featuremap_2_graph' in name:
|
28 |
+
name = name.replace('featuremap_2_graph','source_featuremap_2_graph')
|
29 |
+
else:
|
30 |
+
name = name.replace('graph','source_graph')
|
31 |
+
new_state_dict[name] = 0
|
32 |
+
if name not in own_state:
|
33 |
+
if 'num_batch' in name:
|
34 |
+
continue
|
35 |
+
print('unexpected key "{}" in state_dict'
|
36 |
+
.format(name))
|
37 |
+
continue
|
38 |
+
# if isinstance(param, own_state):
|
39 |
+
if isinstance(param, Parameter):
|
40 |
+
# backwards compatibility for serialized parameters
|
41 |
+
param = param.data
|
42 |
+
try:
|
43 |
+
own_state[name].copy_(param)
|
44 |
+
except:
|
45 |
+
print('While copying the parameter named {}, whose dimensions in the model are'
|
46 |
+
' {} and whose dimensions in the checkpoint are {}, ...'.format(
|
47 |
+
name, own_state[name].size(), param.size()))
|
48 |
+
continue # i add inshop_cos 2018/02/01
|
49 |
+
own_state[name].copy_(param)
|
50 |
+
# print 'copying %s' %name
|
51 |
+
|
52 |
+
missing = set(own_state.keys()) - set(new_state_dict.keys())
|
53 |
+
if len(missing) > 0:
|
54 |
+
print('missing keys in state_dict: "{}"'.format(missing))
|
55 |
+
|
56 |
+
def get_target_parameter(self):
|
57 |
+
l = []
|
58 |
+
other = []
|
59 |
+
for name, k in self.named_parameters():
|
60 |
+
if 'target' in name or 'semantic' in name:
|
61 |
+
l.append(k)
|
62 |
+
else:
|
63 |
+
other.append(k)
|
64 |
+
return l, other
|
65 |
+
|
66 |
+
def get_semantic_parameter(self):
|
67 |
+
l = []
|
68 |
+
for name, k in self.named_parameters():
|
69 |
+
if 'semantic' in name:
|
70 |
+
l.append(k)
|
71 |
+
return l
|
72 |
+
|
73 |
+
def get_source_parameter(self):
|
74 |
+
l = []
|
75 |
+
for name, k in self.named_parameters():
|
76 |
+
if 'source' in name:
|
77 |
+
l.append(k)
|
78 |
+
return l
|
79 |
+
|
80 |
+
def top_forward(self, input, adj1_target=None, adj2_source=None,adj3_transfer=None ):
|
81 |
+
x, low_level_features = self.xception_features(input)
|
82 |
+
# print(x.size())
|
83 |
+
x1 = self.aspp1(x)
|
84 |
+
x2 = self.aspp2(x)
|
85 |
+
x3 = self.aspp3(x)
|
86 |
+
x4 = self.aspp4(x)
|
87 |
+
x5 = self.global_avg_pool(x)
|
88 |
+
x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
|
89 |
+
|
90 |
+
x = torch.cat((x1, x2, x3, x4, x5), dim=1)
|
91 |
+
|
92 |
+
x = self.concat_projection_conv1(x)
|
93 |
+
x = self.concat_projection_bn1(x)
|
94 |
+
x = self.relu(x)
|
95 |
+
# print(x.size())
|
96 |
+
x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
|
97 |
+
|
98 |
+
low_level_features = self.feature_projection_conv1(low_level_features)
|
99 |
+
low_level_features = self.feature_projection_bn1(low_level_features)
|
100 |
+
low_level_features = self.relu(low_level_features)
|
101 |
+
# print(low_level_features.size())
|
102 |
+
# print(x.size())
|
103 |
+
x = torch.cat((x, low_level_features), dim=1)
|
104 |
+
x = self.decoder(x)
|
105 |
+
|
106 |
+
### source graph
|
107 |
+
source_graph = self.source_featuremap_2_graph(x)
|
108 |
+
|
109 |
+
source_graph1 = self.source_graph_conv1.forward(source_graph, adj=adj2_source, relu=True)
|
110 |
+
source_graph2 = self.source_graph_conv2.forward(source_graph1, adj=adj2_source, relu=True)
|
111 |
+
source_graph3 = self.source_graph_conv2.forward(source_graph2, adj=adj2_source, relu=True)
|
112 |
+
|
113 |
+
### target source
|
114 |
+
graph = self.target_featuremap_2_graph(x)
|
115 |
+
|
116 |
+
# graph combine
|
117 |
+
# print(graph.size(),source_2_target_graph.size())
|
118 |
+
# graph = self.fc_graph.forward(graph,relu=True)
|
119 |
+
# print(graph.size())
|
120 |
+
|
121 |
+
graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
|
122 |
+
graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
|
123 |
+
graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
|
124 |
+
|
125 |
+
|
126 |
+
def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ):
|
127 |
+
x, low_level_features = self.xception_features(input)
|
128 |
+
# print(x.size())
|
129 |
+
x1 = self.aspp1(x)
|
130 |
+
x2 = self.aspp2(x)
|
131 |
+
x3 = self.aspp3(x)
|
132 |
+
x4 = self.aspp4(x)
|
133 |
+
x5 = self.global_avg_pool(x)
|
134 |
+
x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
|
135 |
+
|
136 |
+
x = torch.cat((x1, x2, x3, x4, x5), dim=1)
|
137 |
+
|
138 |
+
x = self.concat_projection_conv1(x)
|
139 |
+
x = self.concat_projection_bn1(x)
|
140 |
+
x = self.relu(x)
|
141 |
+
# print(x.size())
|
142 |
+
x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
|
143 |
+
|
144 |
+
low_level_features = self.feature_projection_conv1(low_level_features)
|
145 |
+
low_level_features = self.feature_projection_bn1(low_level_features)
|
146 |
+
low_level_features = self.relu(low_level_features)
|
147 |
+
# print(low_level_features.size())
|
148 |
+
# print(x.size())
|
149 |
+
x = torch.cat((x, low_level_features), dim=1)
|
150 |
+
x = self.decoder(x)
|
151 |
+
|
152 |
+
### add graph
|
153 |
+
|
154 |
+
|
155 |
+
# target graph
|
156 |
+
# print('x size',x.size(),adj1.size())
|
157 |
+
graph = self.target_featuremap_2_graph(x)
|
158 |
+
|
159 |
+
# graph combine
|
160 |
+
# print(graph.size(),source_2_target_graph.size())
|
161 |
+
# graph = self.fc_graph.forward(graph,relu=True)
|
162 |
+
# print(graph.size())
|
163 |
+
|
164 |
+
graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
|
165 |
+
graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
|
166 |
+
graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
|
167 |
+
# print(graph.size(),x.size())
|
168 |
+
# graph = self.gcn_encode.forward(graph,relu=True)
|
169 |
+
# graph = self.graph_conv2.forward(graph,adj=adj2,relu=True)
|
170 |
+
# graph = self.gcn_decode.forward(graph,relu=True)
|
171 |
+
graph = self.target_graph_2_fea.forward(graph, x)
|
172 |
+
x = self.target_skip_conv(x)
|
173 |
+
x = x + graph
|
174 |
+
|
175 |
+
###
|
176 |
+
x = self.semantic(x)
|
177 |
+
x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)
|
178 |
+
|
179 |
+
return x
|
180 |
+
|
181 |
+
|
182 |
+
class deeplab_xception_transfer_basemodel_savememory_synbn(deeplab_xception_synBN.DeepLabv3_plus):
|
183 |
+
def __init__(self, nInputChannels=3, n_classes=7, os=16, input_channels=256, hidden_layers=128, out_channels=256,
|
184 |
+
source_classes=20, transfer_graph=None):
|
185 |
+
super(deeplab_xception_transfer_basemodel_savememory_synbn, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes,
|
186 |
+
os=os,)
|
187 |
+
|
188 |
+
|
189 |
+
def load_source_model(self,state_dict):
|
190 |
+
own_state = self.state_dict()
|
191 |
+
# for name in own_state:
|
192 |
+
# print name
|
193 |
+
new_state_dict = OrderedDict()
|
194 |
+
for name, param in state_dict.items():
|
195 |
+
name = name.replace('module.', '')
|
196 |
+
if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_graph' not in name \
|
197 |
+
and 'transpose_graph' not in name and 'middle' not in name:
|
198 |
+
if 'featuremap_2_graph' in name:
|
199 |
+
name = name.replace('featuremap_2_graph','source_featuremap_2_graph')
|
200 |
+
else:
|
201 |
+
name = name.replace('graph','source_graph')
|
202 |
+
new_state_dict[name] = 0
|
203 |
+
if name not in own_state:
|
204 |
+
if 'num_batch' in name:
|
205 |
+
continue
|
206 |
+
print('unexpected key "{}" in state_dict'
|
207 |
+
.format(name))
|
208 |
+
continue
|
209 |
+
# if isinstance(param, own_state):
|
210 |
+
if isinstance(param, Parameter):
|
211 |
+
# backwards compatibility for serialized parameters
|
212 |
+
param = param.data
|
213 |
+
try:
|
214 |
+
own_state[name].copy_(param)
|
215 |
+
except Exception:
|
216 |
+
print('While copying the parameter named {}, whose dimensions in the model are'
|
217 |
+
' {} and whose dimensions in the checkpoint are {}, ...'.format(
|
218 |
+
name, own_state[name].size(), param.size()))
|
219 |
+
continue  # added 2018/02/01
|
220 |
+
own_state[name].copy_(param)
|
221 |
+
# print 'copying %s' %name
|
222 |
+
|
223 |
+
missing = set(own_state.keys()) - set(new_state_dict.keys())
|
224 |
+
if len(missing) > 0:
|
225 |
+
print('missing keys in state_dict: "{}"'.format(missing))
|
226 |
+
|
227 |
+
def get_target_parameter(self):
|
228 |
+
l = []
|
229 |
+
other = []
|
230 |
+
for name, k in self.named_parameters():
|
231 |
+
if 'target' in name or 'semantic' in name:
|
232 |
+
l.append(k)
|
233 |
+
else:
|
234 |
+
other.append(k)
|
235 |
+
return l, other
|
236 |
+
|
237 |
+
def get_semantic_parameter(self):
|
238 |
+
l = []
|
239 |
+
for name, k in self.named_parameters():
|
240 |
+
if 'semantic' in name:
|
241 |
+
l.append(k)
|
242 |
+
return l
|
243 |
+
|
244 |
+
def get_source_parameter(self):
|
245 |
+
l = []
|
246 |
+
for name, k in self.named_parameters():
|
247 |
+
if 'source' in name:
|
248 |
+
l.append(k)
|
249 |
+
return l
|
250 |
+
|
251 |
+
def top_forward(self, input, adj1_target=None, adj2_source=None,adj3_transfer=None ):
|
252 |
+
x, low_level_features = self.xception_features(input)
|
253 |
+
# print(x.size())
|
254 |
+
x1 = self.aspp1(x)
|
255 |
+
x2 = self.aspp2(x)
|
256 |
+
x3 = self.aspp3(x)
|
257 |
+
x4 = self.aspp4(x)
|
258 |
+
x5 = self.global_avg_pool(x)
|
259 |
+
x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
|
260 |
+
|
261 |
+
x = torch.cat((x1, x2, x3, x4, x5), dim=1)
|
262 |
+
|
263 |
+
x = self.concat_projection_conv1(x)
|
264 |
+
x = self.concat_projection_bn1(x)
|
265 |
+
x = self.relu(x)
|
266 |
+
# print(x.size())
|
267 |
+
x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
|
268 |
+
|
269 |
+
low_level_features = self.feature_projection_conv1(low_level_features)
|
270 |
+
low_level_features = self.feature_projection_bn1(low_level_features)
|
271 |
+
low_level_features = self.relu(low_level_features)
|
272 |
+
# print(low_level_features.size())
|
273 |
+
# print(x.size())
|
274 |
+
x = torch.cat((x, low_level_features), dim=1)
|
275 |
+
x = self.decoder(x)
|
276 |
+
|
277 |
+
### source graph
|
278 |
+
source_graph = self.source_featuremap_2_graph(x)
|
279 |
+
|
280 |
+
source_graph1 = self.source_graph_conv1.forward(source_graph, adj=adj2_source, relu=True)
|
281 |
+
source_graph2 = self.source_graph_conv2.forward(source_graph1, adj=adj2_source, relu=True)
|
282 |
+
source_graph3 = self.source_graph_conv3.forward(source_graph2, adj=adj2_source, relu=True)
|
283 |
+
|
284 |
+
### target graph
|
285 |
+
graph = self.target_featuremap_2_graph(x)
|
286 |
+
|
287 |
+
# graph combine
|
288 |
+
# print(graph.size(),source_2_target_graph.size())
|
289 |
+
# graph = self.fc_graph.forward(graph,relu=True)
|
290 |
+
# print(graph.size())
|
291 |
+
|
292 |
+
graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
|
293 |
+
graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
|
294 |
+
graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
|
295 |
+
|
296 |
+
|
297 |
+
def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ):
|
298 |
+
x, low_level_features = self.xception_features(input)
|
299 |
+
# print(x.size())
|
300 |
+
x1 = self.aspp1(x)
|
301 |
+
x2 = self.aspp2(x)
|
302 |
+
x3 = self.aspp3(x)
|
303 |
+
x4 = self.aspp4(x)
|
304 |
+
x5 = self.global_avg_pool(x)
|
305 |
+
x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
|
306 |
+
|
307 |
+
x = torch.cat((x1, x2, x3, x4, x5), dim=1)
|
308 |
+
|
309 |
+
x = self.concat_projection_conv1(x)
|
310 |
+
x = self.concat_projection_bn1(x)
|
311 |
+
x = self.relu(x)
|
312 |
+
# print(x.size())
|
313 |
+
x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
|
314 |
+
|
315 |
+
low_level_features = self.feature_projection_conv1(low_level_features)
|
316 |
+
low_level_features = self.feature_projection_bn1(low_level_features)
|
317 |
+
low_level_features = self.relu(low_level_features)
|
318 |
+
# print(low_level_features.size())
|
319 |
+
# print(x.size())
|
320 |
+
x = torch.cat((x, low_level_features), dim=1)
|
321 |
+
x = self.decoder(x)
|
322 |
+
|
323 |
+
### add graph
|
324 |
+
|
325 |
+
|
326 |
+
# target graph
|
327 |
+
# print('x size',x.size(),adj1.size())
|
328 |
+
graph = self.target_featuremap_2_graph(x)
|
329 |
+
|
330 |
+
# graph combine
|
331 |
+
# print(graph.size(),source_2_target_graph.size())
|
332 |
+
# graph = self.fc_graph.forward(graph,relu=True)
|
333 |
+
# print(graph.size())
|
334 |
+
|
335 |
+
graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
|
336 |
+
graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
|
337 |
+
graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
|
338 |
+
# print(graph.size(),x.size())
|
339 |
+
# graph = self.gcn_encode.forward(graph,relu=True)
|
340 |
+
# graph = self.graph_conv2.forward(graph,adj=adj2,relu=True)
|
341 |
+
# graph = self.gcn_decode.forward(graph,relu=True)
|
342 |
+
graph = self.target_graph_2_fea.forward(graph, x)
|
343 |
+
x = self.target_skip_conv(x)
|
344 |
+
x = x + graph
|
345 |
+
|
346 |
+
###
|
347 |
+
x = self.semantic(x)
|
348 |
+
x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)
|
349 |
+
|
350 |
+
return x
|
351 |
+
|
352 |
+
|
353 |
+
class deeplab_xception_end2end_3d(deeplab_xception_transfer_basemodel_savememory):
|
354 |
+
def __init__(self, nInputChannels=3, n_classes=20, os=16, input_channels=256, hidden_layers=128, out_channels=256,
|
355 |
+
source_classes=7, middle_classes=18, transfer_graph=None):
|
356 |
+
super(deeplab_xception_end2end_3d, self).__init__(nInputChannels=nInputChannels,
|
357 |
+
n_classes=n_classes,
|
358 |
+
os=os, )
|
359 |
+
### source graph
|
360 |
+
self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels,
|
361 |
+
hidden_layers=hidden_layers,
|
362 |
+
nodes=source_classes)
|
363 |
+
self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
|
364 |
+
self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
|
365 |
+
self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
|
366 |
+
|
367 |
+
self.source_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels,
|
368 |
+
output_channels=out_channels,
|
369 |
+
hidden_layers=hidden_layers, nodes=source_classes
|
370 |
+
)
|
371 |
+
self.source_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
|
372 |
+
nn.ReLU(True)])
|
373 |
+
self.source_semantic = nn.Conv2d(out_channels,source_classes,1)
|
374 |
+
self.middle_semantic = nn.Conv2d(out_channels, middle_classes, 1)
|
375 |
+
|
376 |
+
### target graph 1
|
377 |
+
self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels,
|
378 |
+
hidden_layers=hidden_layers,
|
379 |
+
nodes=n_classes)
|
380 |
+
self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
|
381 |
+
self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
|
382 |
+
self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
|
383 |
+
|
384 |
+
self.target_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels,
|
385 |
+
output_channels=out_channels,
|
386 |
+
hidden_layers=hidden_layers, nodes=n_classes
|
387 |
+
)
|
388 |
+
self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
|
389 |
+
nn.ReLU(True)])
|
390 |
+
|
391 |
+
### middle
|
392 |
+
self.middle_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels,
|
393 |
+
hidden_layers=hidden_layers,
|
394 |
+
nodes=middle_classes)
|
395 |
+
self.middle_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
|
396 |
+
self.middle_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
|
397 |
+
self.middle_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
|
398 |
+
|
399 |
+
self.middle_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels,
|
400 |
+
output_channels=out_channels,
|
401 |
+
hidden_layers=hidden_layers, nodes=n_classes
|
402 |
+
)
|
403 |
+
self.middle_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
|
404 |
+
nn.ReLU(True)])
|
405 |
+
|
406 |
+
### multi transpose
|
407 |
+
self.transpose_graph_source2target = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
|
408 |
+
adj=transfer_graph,
|
409 |
+
begin_nodes=source_classes, end_nodes=n_classes)
|
410 |
+
self.transpose_graph_target2source = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
|
411 |
+
adj=transfer_graph,
|
412 |
+
begin_nodes=n_classes, end_nodes=source_classes)
|
413 |
+
|
414 |
+
self.transpose_graph_middle2source = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
|
415 |
+
adj=transfer_graph,
|
416 |
+
begin_nodes=middle_classes, end_nodes=source_classes)
|
417 |
+
self.transpose_graph_middle2target = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
|
418 |
+
adj=transfer_graph,
|
419 |
+
begin_nodes=middle_classes, end_nodes=n_classes)
|
420 |
+
|
421 |
+
self.transpose_graph_source2middle = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
|
422 |
+
adj=transfer_graph,
|
423 |
+
begin_nodes=source_classes, end_nodes=middle_classes)
|
424 |
+
self.transpose_graph_target2middle = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
|
425 |
+
adj=transfer_graph,
|
426 |
+
begin_nodes=n_classes, end_nodes=middle_classes)
|
427 |
+
|
428 |
+
|
429 |
+
self.fc_graph_source = gcn.GraphConvolution(hidden_layers * 5, hidden_layers)
|
430 |
+
self.fc_graph_target = gcn.GraphConvolution(hidden_layers * 5, hidden_layers)
|
431 |
+
self.fc_graph_middle = gcn.GraphConvolution(hidden_layers * 5, hidden_layers)
|
432 |
+
|
433 |
+
def freeze_totally_bn(self):
|
434 |
+
for m in self.modules():
|
435 |
+
if isinstance(m, nn.BatchNorm2d):
|
436 |
+
m.eval()
|
437 |
+
m.weight.requires_grad = False
|
438 |
+
m.bias.requires_grad = False
|
439 |
+
|
440 |
+
def freeze_backbone_bn(self):
|
441 |
+
for m in self.xception_features.modules():
|
442 |
+
if isinstance(m, nn.BatchNorm2d):
|
443 |
+
m.eval()
|
444 |
+
m.weight.requires_grad = False
|
445 |
+
m.bias.requires_grad = False
|
446 |
+
|
447 |
+
def top_forward(self, input, adj1_target=None, adj2_source=None, adj3_transfer_s2t=None, adj3_transfer_t2s=None,
|
448 |
+
adj4_middle=None,adj5_transfer_s2m=None,adj6_transfer_t2m=None,adj5_transfer_m2s=None,adj6_transfer_m2t=None,):
|
449 |
+
x, low_level_features = self.xception_features(input)
|
450 |
+
# print(x.size())
|
451 |
+
x1 = self.aspp1(x)
|
452 |
+
x2 = self.aspp2(x)
|
453 |
+
x3 = self.aspp3(x)
|
454 |
+
x4 = self.aspp4(x)
|
455 |
+
x5 = self.global_avg_pool(x)
|
456 |
+
x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
|
457 |
+
|
458 |
+
x = torch.cat((x1, x2, x3, x4, x5), dim=1)
|
459 |
+
|
460 |
+
x = self.concat_projection_conv1(x)
|
461 |
+
x = self.concat_projection_bn1(x)
|
462 |
+
x = self.relu(x)
|
463 |
+
# print(x.size())
|
464 |
+
x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
|
465 |
+
|
466 |
+
low_level_features = self.feature_projection_conv1(low_level_features)
|
467 |
+
low_level_features = self.feature_projection_bn1(low_level_features)
|
468 |
+
low_level_features = self.relu(low_level_features)
|
469 |
+
# print(low_level_features.size())
|
470 |
+
# print(x.size())
|
471 |
+
x = torch.cat((x, low_level_features), dim=1)
|
472 |
+
x = self.decoder(x)
|
473 |
+
|
474 |
+
### source graph
|
475 |
+
source_graph = self.source_featuremap_2_graph(x)
|
476 |
+
### target graph
|
477 |
+
target_graph = self.target_featuremap_2_graph(x)
|
478 |
+
### middle graph
|
479 |
+
middle_graph = self.middle_featuremap_2_graph(x)
|
480 |
+
|
481 |
+
##### end2end multi task
|
482 |
+
|
483 |
+
### first task
|
484 |
+
# print(source_graph.size(),target_graph.size())
|
485 |
+
source_graph1 = self.source_graph_conv1.forward(source_graph, adj=adj2_source, relu=True)
|
486 |
+
target_graph1 = self.target_graph_conv1.forward(target_graph, adj=adj1_target, relu=True)
|
487 |
+
middle_graph1 = self.target_graph_conv1.forward(middle_graph, adj=adj4_middle, relu=True)
|
488 |
+
|
489 |
+
# source 2 target & middle
|
490 |
+
source_2_target_graph1_v5 = self.transpose_graph_source2target.forward(source_graph1, adj=adj3_transfer_s2t,
|
491 |
+
relu=True)
|
492 |
+
source_2_middle_graph1_v5 = self.transpose_graph_source2middle.forward(source_graph1,adj=adj5_transfer_s2m,
|
493 |
+
relu=True)
|
494 |
+
# target 2 source & middle
|
495 |
+
target_2_source_graph1_v5 = self.transpose_graph_target2source.forward(target_graph1, adj=adj3_transfer_t2s,
|
496 |
+
relu=True)
|
497 |
+
target_2_middle_graph1_v5 = self.transpose_graph_target2middle.forward(target_graph1, adj=adj6_transfer_t2m,
|
498 |
+
relu=True)
|
499 |
+
# middle 2 source & target
|
500 |
+
middle_2_source_graph1_v5 = self.transpose_graph_middle2source.forward(middle_graph1, adj=adj5_transfer_m2s,
|
501 |
+
relu=True)
|
502 |
+
middle_2_target_graph1_v5 = self.transpose_graph_middle2target.forward(middle_graph1, adj=adj6_transfer_m2t,
|
503 |
+
relu=True)
|
504 |
+
# source 2 middle target
|
505 |
+
source_2_target_graph1 = self.similarity_trans(source_graph1, target_graph1)
|
506 |
+
source_2_middle_graph1 = self.similarity_trans(source_graph1, middle_graph1)
|
507 |
+
# target 2 source middle
|
508 |
+
target_2_source_graph1 = self.similarity_trans(target_graph1, source_graph1)
|
509 |
+
target_2_middle_graph1 = self.similarity_trans(target_graph1, middle_graph1)
|
510 |
+
# middle 2 source target
|
511 |
+
middle_2_source_graph1 = self.similarity_trans(middle_graph1, source_graph1)
|
512 |
+
middle_2_target_graph1 = self.similarity_trans(middle_graph1, target_graph1)
|
513 |
+
|
514 |
+
## concat
|
515 |
+
# print(source_graph1.size(), target_2_source_graph1.size(), )
|
516 |
+
source_graph1 = torch.cat(
|
517 |
+
(source_graph1, target_2_source_graph1, target_2_source_graph1_v5,
|
518 |
+
middle_2_source_graph1, middle_2_source_graph1_v5), dim=-1)
|
519 |
+
source_graph1 = self.fc_graph_source.forward(source_graph1, relu=True)
|
520 |
+
# target
|
521 |
+
target_graph1 = torch.cat(
|
522 |
+
(target_graph1, source_2_target_graph1, source_2_target_graph1_v5,
|
523 |
+
middle_2_target_graph1, middle_2_target_graph1_v5), dim=-1)
|
524 |
+
target_graph1 = self.fc_graph_target.forward(target_graph1, relu=True)
|
525 |
+
# middle
|
526 |
+
middle_graph1 = torch.cat((middle_graph1, source_2_middle_graph1, source_2_middle_graph1_v5,
|
527 |
+
target_2_middle_graph1, target_2_middle_graph1_v5), dim=-1)
|
528 |
+
middle_graph1 = self.fc_graph_middle.forward(middle_graph1, relu=True)
|
529 |
+
|
530 |
+
|
531 |
+
### second task
|
532 |
+
source_graph2 = self.source_graph_conv1.forward(source_graph1, adj=adj2_source, relu=True)
|
533 |
+
target_graph2 = self.target_graph_conv1.forward(target_graph1, adj=adj1_target, relu=True)
|
534 |
+
middle_graph2 = self.target_graph_conv1.forward(middle_graph1, adj=adj4_middle, relu=True)
|
535 |
+
|
536 |
+
# source 2 target & middle
|
537 |
+
source_2_target_graph2_v5 = self.transpose_graph_source2target.forward(source_graph2, adj=adj3_transfer_s2t,
|
538 |
+
relu=True)
|
539 |
+
source_2_middle_graph2_v5 = self.transpose_graph_source2middle.forward(source_graph2, adj=adj5_transfer_s2m,
|
540 |
+
relu=True)
|
541 |
+
# target 2 source & middle
|
542 |
+
target_2_source_graph2_v5 = self.transpose_graph_target2source.forward(target_graph2, adj=adj3_transfer_t2s,
|
543 |
+
relu=True)
|
544 |
+
target_2_middle_graph2_v5 = self.transpose_graph_target2middle.forward(target_graph2, adj=adj6_transfer_t2m,
|
545 |
+
relu=True)
|
546 |
+
# middle 2 source & target
|
547 |
+
middle_2_source_graph2_v5 = self.transpose_graph_middle2source.forward(middle_graph2, adj=adj5_transfer_m2s,
|
548 |
+
relu=True)
|
549 |
+
middle_2_target_graph2_v5 = self.transpose_graph_middle2target.forward(middle_graph2, adj=adj6_transfer_m2t,
|
550 |
+
relu=True)
|
551 |
+
# source 2 middle target
|
552 |
+
source_2_target_graph2 = self.similarity_trans(source_graph2, target_graph2)
|
553 |
+
source_2_middle_graph2 = self.similarity_trans(source_graph2, middle_graph2)
|
554 |
+
# target 2 source middle
|
555 |
+
target_2_source_graph2 = self.similarity_trans(target_graph2, source_graph2)
|
556 |
+
target_2_middle_graph2 = self.similarity_trans(target_graph2, middle_graph2)
|
557 |
+
# middle 2 source target
|
558 |
+
middle_2_source_graph2 = self.similarity_trans(middle_graph2, source_graph2)
|
559 |
+
middle_2_target_graph2 = self.similarity_trans(middle_graph2, target_graph2)
|
560 |
+
|
561 |
+
## concat
|
562 |
+
# print(source_graph1.size(), target_2_source_graph1.size(), )
|
563 |
+
source_graph2 = torch.cat(
|
564 |
+
(source_graph2, target_2_source_graph2, target_2_source_graph2_v5,
|
565 |
+
middle_2_source_graph2, middle_2_source_graph2_v5), dim=-1)
|
566 |
+
source_graph2 = self.fc_graph_source.forward(source_graph2, relu=True)
|
567 |
+
# target
|
568 |
+
target_graph2 = torch.cat(
|
569 |
+
(target_graph2, source_2_target_graph2, source_2_target_graph2_v5,
|
570 |
+
middle_2_target_graph2, middle_2_target_graph2_v5), dim=-1)
|
571 |
+
target_graph2 = self.fc_graph_target.forward(target_graph2, relu=True)
|
572 |
+
# middle
|
573 |
+
middle_graph2 = torch.cat((middle_graph2, source_2_middle_graph2, source_2_middle_graph2_v5,
|
574 |
+
target_2_middle_graph2, target_2_middle_graph2_v5), dim=-1)
|
575 |
+
middle_graph2 = self.fc_graph_middle.forward(middle_graph2, relu=True)
|
576 |
+
|
577 |
+
|
578 |
+
### third task
|
579 |
+
source_graph3 = self.source_graph_conv1.forward(source_graph2, adj=adj2_source, relu=True)
|
580 |
+
target_graph3 = self.target_graph_conv1.forward(target_graph2, adj=adj1_target, relu=True)
|
581 |
+
middle_graph3 = self.target_graph_conv1.forward(middle_graph2, adj=adj4_middle, relu=True)
|
582 |
+
|
583 |
+
# source 2 target & middle
|
584 |
+
source_2_target_graph3_v5 = self.transpose_graph_source2target.forward(source_graph3, adj=adj3_transfer_s2t,
|
585 |
+
relu=True)
|
586 |
+
source_2_middle_graph3_v5 = self.transpose_graph_source2middle.forward(source_graph3, adj=adj5_transfer_s2m,
|
587 |
+
relu=True)
|
588 |
+
# target 2 source & middle
|
589 |
+
target_2_source_graph3_v5 = self.transpose_graph_target2source.forward(target_graph3, adj=adj3_transfer_t2s,
|
590 |
+
relu=True)
|
591 |
+
target_2_middle_graph3_v5 = self.transpose_graph_target2middle.forward(target_graph3, adj=adj6_transfer_t2m,
|
592 |
+
relu=True)
|
593 |
+
# middle 2 source & target
|
594 |
+
middle_2_source_graph3_v5 = self.transpose_graph_middle2source.forward(middle_graph3, adj=adj5_transfer_m2s,
|
595 |
+
relu=True)
|
596 |
+
middle_2_target_graph3_v5 = self.transpose_graph_middle2target.forward(middle_graph3, adj=adj6_transfer_m2t,
|
597 |
+
relu=True)
|
598 |
+
# source 2 middle target
|
599 |
+
source_2_target_graph3 = self.similarity_trans(source_graph3, target_graph3)
|
600 |
+
source_2_middle_graph3 = self.similarity_trans(source_graph3, middle_graph3)
|
601 |
+
# target 2 source middle
|
602 |
+
target_2_source_graph3 = self.similarity_trans(target_graph3, source_graph3)
|
603 |
+
target_2_middle_graph3 = self.similarity_trans(target_graph3, middle_graph3)
|
604 |
+
# middle 2 source target
|
605 |
+
middle_2_source_graph3 = self.similarity_trans(middle_graph3, source_graph3)
|
606 |
+
middle_2_target_graph3 = self.similarity_trans(middle_graph3, target_graph3)
|
607 |
+
|
608 |
+
## concat
|
609 |
+
# print(source_graph1.size(), target_2_source_graph1.size(), )
|
610 |
+
source_graph3 = torch.cat(
|
611 |
+
(source_graph3, target_2_source_graph3, target_2_source_graph3_v5,
|
612 |
+
middle_2_source_graph3, middle_2_source_graph3_v5), dim=-1)
|
613 |
+
source_graph3 = self.fc_graph_source.forward(source_graph3, relu=True)
|
614 |
+
# target
|
615 |
+
target_graph3 = torch.cat(
|
616 |
+
(target_graph3, source_2_target_graph3, source_2_target_graph3_v5,
|
617 |
+
middle_2_target_graph3, middle_2_target_graph3_v5), dim=-1)
|
618 |
+
target_graph3 = self.fc_graph_target.forward(target_graph3, relu=True)
|
619 |
+
# middle
|
620 |
+
middle_graph3 = torch.cat((middle_graph3, source_2_middle_graph3, source_2_middle_graph3_v5,
|
621 |
+
target_2_middle_graph3, target_2_middle_graph3_v5), dim=-1)
|
622 |
+
middle_graph3 = self.fc_graph_middle.forward(middle_graph3, relu=True)
|
623 |
+
|
624 |
+
return source_graph3, target_graph3, middle_graph3, x
|
625 |
+
|
626 |
+
def similarity_trans(self,source,target):
|
627 |
+
sim = torch.matmul(F.normalize(target, p=2, dim=-1), F.normalize(source, p=2, dim=-1).transpose(-1, -2))
|
628 |
+
sim = F.softmax(sim, dim=-1)
|
629 |
+
return torch.matmul(sim, source)
|
630 |
+
|
631 |
+
def bottom_forward_source(self, input, source_graph):
|
632 |
+
# print('input size')
|
633 |
+
# print(input.size())
|
634 |
+
# print(source_graph.size())
|
635 |
+
graph = self.source_graph_2_fea.forward(source_graph, input)
|
636 |
+
x = self.source_skip_conv(input)
|
637 |
+
x = x + graph
|
638 |
+
x = self.source_semantic(x)
|
639 |
+
return x
|
640 |
+
|
641 |
+
def bottom_forward_target(self, input, target_graph):
|
642 |
+
graph = self.target_graph_2_fea.forward(target_graph, input)
|
643 |
+
x = self.target_skip_conv(input)
|
644 |
+
x = x + graph
|
645 |
+
x = self.semantic(x)
|
646 |
+
return x
|
647 |
+
|
648 |
+
def bottom_forward_middle(self, input, target_graph):
|
649 |
+
graph = self.middle_graph_2_fea.forward(target_graph, input)
|
650 |
+
x = self.middle_skip_conv(input)
|
651 |
+
x = x + graph
|
652 |
+
x = self.middle_semantic(x)
|
653 |
+
return x
|
654 |
+
|
655 |
+
def forward(self, input_source, input_target=None, input_middle=None, adj1_target=None, adj2_source=None,
|
656 |
+
adj3_transfer_s2t=None, adj3_transfer_t2s=None, adj4_middle=None,adj5_transfer_s2m=None,
|
657 |
+
adj6_transfer_t2m=None,adj5_transfer_m2s=None,adj6_transfer_m2t=None,):
|
658 |
+
if input_source is None and input_target is not None and input_middle is None:
|
659 |
+
# target
|
660 |
+
target_batch = input_target.size(0)
|
661 |
+
input = input_target
|
662 |
+
|
663 |
+
source_graph, target_graph, middle_graph, x = self.top_forward(input, adj1_target=adj1_target, adj2_source=adj2_source,
|
664 |
+
adj3_transfer_s2t=adj3_transfer_s2t,
|
665 |
+
adj3_transfer_t2s=adj3_transfer_t2s,
|
666 |
+
adj4_middle=adj4_middle,
|
667 |
+
adj5_transfer_s2m=adj5_transfer_s2m,
|
668 |
+
adj6_transfer_t2m=adj6_transfer_t2m,
|
669 |
+
adj5_transfer_m2s=adj5_transfer_m2s,
|
670 |
+
adj6_transfer_m2t=adj6_transfer_m2t)
|
671 |
+
|
672 |
+
# source_x = self.bottom_forward_source(source_x, source_graph)
|
673 |
+
target_x = self.bottom_forward_target(x, target_graph)
|
674 |
+
|
675 |
+
target_x = F.upsample(target_x, size=input.size()[2:], mode='bilinear', align_corners=True)
|
676 |
+
return None, target_x, None
|
677 |
+
|
678 |
+
if input_source is not None and input_target is None and input_middle is None:
|
679 |
+
# source
|
680 |
+
source_batch = input_source.size(0)
|
681 |
+
source_list = range(source_batch)
|
682 |
+
input = input_source
|
683 |
+
|
684 |
+
source_graph, target_graph, middle_graph, x = self.top_forward(input, adj1_target=adj1_target,
|
685 |
+
adj2_source=adj2_source,
|
686 |
+
adj3_transfer_s2t=adj3_transfer_s2t,
|
687 |
+
adj3_transfer_t2s=adj3_transfer_t2s,
|
688 |
+
adj4_middle=adj4_middle,
|
689 |
+
adj5_transfer_s2m=adj5_transfer_s2m,
|
690 |
+
adj6_transfer_t2m=adj6_transfer_t2m,
|
691 |
+
adj5_transfer_m2s=adj5_transfer_m2s,
|
692 |
+
adj6_transfer_m2t=adj6_transfer_m2t)
|
693 |
+
|
694 |
+
source_x = self.bottom_forward_source(x, source_graph)
|
695 |
+
source_x = F.upsample(source_x, size=input.size()[2:], mode='bilinear', align_corners=True)
|
696 |
+
return source_x, None, None
|
697 |
+
|
698 |
+
if input_middle is not None and input_source is None and input_target is None:
|
699 |
+
# middle
|
700 |
+
input = input_middle
|
701 |
+
|
702 |
+
source_graph, target_graph, middle_graph, x = self.top_forward(input, adj1_target=adj1_target,
|
703 |
+
adj2_source=adj2_source,
|
704 |
+
adj3_transfer_s2t=adj3_transfer_s2t,
|
705 |
+
adj3_transfer_t2s=adj3_transfer_t2s,
|
706 |
+
adj4_middle=adj4_middle,
|
707 |
+
adj5_transfer_s2m=adj5_transfer_s2m,
|
708 |
+
adj6_transfer_t2m=adj6_transfer_t2m,
|
709 |
+
adj5_transfer_m2s=adj5_transfer_m2s,
|
710 |
+
adj6_transfer_m2t=adj6_transfer_m2t)
|
711 |
+
|
712 |
+
middle_x = self.bottom_forward_middle(x, source_graph)
|
713 |
+
middle_x = F.upsample(middle_x, size=input.size()[2:], mode='bilinear', align_corners=True)
|
714 |
+
return None, None, middle_x
|
715 |
+
|
716 |
+
|
717 |
+
class deeplab_xception_end2end_3d_synbn(deeplab_xception_transfer_basemodel_savememory_synbn):
|
718 |
+
def __init__(self, nInputChannels=3, n_classes=20, os=16, input_channels=256, hidden_layers=128, out_channels=256,
|
719 |
+
source_classes=7, middle_classes=18, transfer_graph=None):
|
720 |
+
super(deeplab_xception_end2end_3d_synbn, self).__init__(nInputChannels=nInputChannels,
|
721 |
+
n_classes=n_classes,
|
722 |
+
os=os, )
|
723 |
+
### source graph
|
724 |
+
self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels,
|
725 |
+
hidden_layers=hidden_layers,
|
726 |
+
nodes=source_classes)
|
727 |
+
self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
|
728 |
+
self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
|
729 |
+
self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
|
730 |
+
|
731 |
+
self.source_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels,
|
732 |
+
output_channels=out_channels,
|
733 |
+
hidden_layers=hidden_layers, nodes=source_classes
|
734 |
+
)
|
735 |
+
self.source_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
|
736 |
+
nn.ReLU(True)])
|
737 |
+
self.source_semantic = nn.Conv2d(out_channels,source_classes,1)
|
738 |
+
self.middle_semantic = nn.Conv2d(out_channels, middle_classes, 1)
|
739 |
+
|
740 |
+
### target graph 1
|
741 |
+
self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels,
|
742 |
+
hidden_layers=hidden_layers,
|
743 |
+
nodes=n_classes)
|
744 |
+
self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
|
745 |
+
self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
|
746 |
+
self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
|
747 |
+
|
748 |
+
self.target_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels,
|
749 |
+
output_channels=out_channels,
|
750 |
+
hidden_layers=hidden_layers, nodes=n_classes
|
751 |
+
)
|
752 |
+
self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
|
753 |
+
nn.ReLU(True)])
|
754 |
+
|
755 |
+
### middle
|
756 |
+
self.middle_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels,
|
757 |
+
hidden_layers=hidden_layers,
|
758 |
+
nodes=middle_classes)
|
759 |
+
self.middle_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
|
760 |
+
self.middle_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
|
761 |
+
self.middle_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
|
762 |
+
|
763 |
+
self.middle_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels,
|
764 |
+
output_channels=out_channels,
|
765 |
+
hidden_layers=hidden_layers, nodes=n_classes
|
766 |
+
)
|
767 |
+
self.middle_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
|
768 |
+
nn.ReLU(True)])
|
769 |
+
|
770 |
+
### multi transpose
|
771 |
+
self.transpose_graph_source2target = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
|
772 |
+
adj=transfer_graph,
|
773 |
+
begin_nodes=source_classes, end_nodes=n_classes)
|
774 |
+
self.transpose_graph_target2source = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
|
775 |
+
adj=transfer_graph,
|
776 |
+
begin_nodes=n_classes, end_nodes=source_classes)
|
777 |
+
|
778 |
+
self.transpose_graph_middle2source = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
|
779 |
+
adj=transfer_graph,
|
780 |
+
begin_nodes=middle_classes, end_nodes=source_classes)
|
781 |
+
self.transpose_graph_middle2target = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
|
782 |
+
adj=transfer_graph,
|
783 |
+
begin_nodes=middle_classes, end_nodes=n_classes)
|
784 |
+
|
785 |
+
self.transpose_graph_source2middle = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
|
786 |
+
adj=transfer_graph,
|
787 |
+
begin_nodes=source_classes, end_nodes=middle_classes)
|
788 |
+
self.transpose_graph_target2middle = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
|
789 |
+
adj=transfer_graph,
|
790 |
+
begin_nodes=n_classes, end_nodes=middle_classes)
|
791 |
+
|
792 |
+
|
793 |
+
self.fc_graph_source = gcn.GraphConvolution(hidden_layers * 5, hidden_layers)
|
794 |
+
self.fc_graph_target = gcn.GraphConvolution(hidden_layers * 5, hidden_layers)
|
795 |
+
self.fc_graph_middle = gcn.GraphConvolution(hidden_layers * 5, hidden_layers)
|
796 |
+
|
797 |
+
|
798 |
+
def top_forward(self, input, adj1_target=None, adj2_source=None, adj3_transfer_s2t=None, adj3_transfer_t2s=None,
|
799 |
+
adj4_middle=None,adj5_transfer_s2m=None,adj6_transfer_t2m=None,adj5_transfer_m2s=None,adj6_transfer_m2t=None,):
|
800 |
+
x, low_level_features = self.xception_features(input)
|
801 |
+
# print(x.size())
|
802 |
+
x1 = self.aspp1(x)
|
803 |
+
x2 = self.aspp2(x)
|
804 |
+
x3 = self.aspp3(x)
|
805 |
+
x4 = self.aspp4(x)
|
806 |
+
x5 = self.global_avg_pool(x)
|
807 |
+
x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
|
808 |
+
|
809 |
+
x = torch.cat((x1, x2, x3, x4, x5), dim=1)
|
810 |
+
|
811 |
+
x = self.concat_projection_conv1(x)
|
812 |
+
x = self.concat_projection_bn1(x)
|
813 |
+
x = self.relu(x)
|
814 |
+
# print(x.size())
|
815 |
+
x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
|
816 |
+
|
817 |
+
low_level_features = self.feature_projection_conv1(low_level_features)
|
818 |
+
low_level_features = self.feature_projection_bn1(low_level_features)
|
819 |
+
low_level_features = self.relu(low_level_features)
|
820 |
+
# print(low_level_features.size())
|
821 |
+
# print(x.size())
|
822 |
+
x = torch.cat((x, low_level_features), dim=1)
|
823 |
+
x = self.decoder(x)
|
824 |
+
|
825 |
+
### source graph
|
826 |
+
source_graph = self.source_featuremap_2_graph(x)
|
827 |
+
### target graph
|
828 |
+
target_graph = self.target_featuremap_2_graph(x)
|
829 |
+
### middle graph
|
830 |
+
middle_graph = self.middle_featuremap_2_graph(x)
|
831 |
+
|
832 |
+
##### end2end multi task
|
833 |
+
|
834 |
+
### first task
|
835 |
+
# print(source_graph.size(),target_graph.size())
|
836 |
+
source_graph1 = self.source_graph_conv1.forward(source_graph, adj=adj2_source, relu=True)
|
837 |
+
target_graph1 = self.target_graph_conv1.forward(target_graph, adj=adj1_target, relu=True)
|
838 |
+
middle_graph1 = self.target_graph_conv1.forward(middle_graph, adj=adj4_middle, relu=True)
|
839 |
+
|
840 |
+
# source 2 target & middle
|
841 |
+
source_2_target_graph1_v5 = self.transpose_graph_source2target.forward(source_graph1, adj=adj3_transfer_s2t,
|
842 |
+
relu=True)
|
843 |
+
source_2_middle_graph1_v5 = self.transpose_graph_source2middle.forward(source_graph1,adj=adj5_transfer_s2m,
|
844 |
+
relu=True)
|
845 |
+
# target 2 source & middle
|
846 |
+
target_2_source_graph1_v5 = self.transpose_graph_target2source.forward(target_graph1, adj=adj3_transfer_t2s,
|
847 |
+
relu=True)
|
848 |
+
target_2_middle_graph1_v5 = self.transpose_graph_target2middle.forward(target_graph1, adj=adj6_transfer_t2m,
|
849 |
+
relu=True)
|
850 |
+
# middle 2 source & target
|
851 |
+
middle_2_source_graph1_v5 = self.transpose_graph_middle2source.forward(middle_graph1, adj=adj5_transfer_m2s,
|
852 |
+
relu=True)
|
853 |
+
middle_2_target_graph1_v5 = self.transpose_graph_middle2target.forward(middle_graph1, adj=adj6_transfer_m2t,
|
854 |
+
relu=True)
|
855 |
+
# source 2 middle target
|
856 |
+
source_2_target_graph1 = self.similarity_trans(source_graph1, target_graph1)
|
857 |
+
source_2_middle_graph1 = self.similarity_trans(source_graph1, middle_graph1)
|
858 |
+
# target 2 source middle
|
859 |
+
target_2_source_graph1 = self.similarity_trans(target_graph1, source_graph1)
|
860 |
+
target_2_middle_graph1 = self.similarity_trans(target_graph1, middle_graph1)
|
861 |
+
# middle 2 source target
|
862 |
+
middle_2_source_graph1 = self.similarity_trans(middle_graph1, source_graph1)
|
863 |
+
middle_2_target_graph1 = self.similarity_trans(middle_graph1, target_graph1)
|
864 |
+
|
865 |
+
## concat
|
866 |
+
# print(source_graph1.size(), target_2_source_graph1.size(), )
|
867 |
+
source_graph1 = torch.cat(
|
868 |
+
(source_graph1, target_2_source_graph1, target_2_source_graph1_v5,
|
869 |
+
middle_2_source_graph1, middle_2_source_graph1_v5), dim=-1)
|
870 |
+
source_graph1 = self.fc_graph_source.forward(source_graph1, relu=True)
|
871 |
+
# target
|
872 |
+
target_graph1 = torch.cat(
|
873 |
+
(target_graph1, source_2_target_graph1, source_2_target_graph1_v5,
|
874 |
+
middle_2_target_graph1, middle_2_target_graph1_v5), dim=-1)
|
875 |
+
target_graph1 = self.fc_graph_target.forward(target_graph1, relu=True)
|
876 |
+
# middle
|
877 |
+
middle_graph1 = torch.cat((middle_graph1, source_2_middle_graph1, source_2_middle_graph1_v5,
|
878 |
+
target_2_middle_graph1, target_2_middle_graph1_v5), dim=-1)
|
879 |
+
middle_graph1 = self.fc_graph_middle.forward(middle_graph1, relu=True)
|
880 |
+
|
881 |
+
|
882 |
+
### second task
|
883 |
+
source_graph2 = self.source_graph_conv1.forward(source_graph1, adj=adj2_source, relu=True)
|
884 |
+
target_graph2 = self.target_graph_conv1.forward(target_graph1, adj=adj1_target, relu=True)
|
885 |
+
middle_graph2 = self.target_graph_conv1.forward(middle_graph1, adj=adj4_middle, relu=True)
|
886 |
+
|
887 |
+
# source 2 target & middle
|
888 |
+
source_2_target_graph2_v5 = self.transpose_graph_source2target.forward(source_graph2, adj=adj3_transfer_s2t,
|
889 |
+
relu=True)
|
890 |
+
source_2_middle_graph2_v5 = self.transpose_graph_source2middle.forward(source_graph2, adj=adj5_transfer_s2m,
|
891 |
+
relu=True)
|
892 |
+
# target 2 source & middle
|
893 |
+
target_2_source_graph2_v5 = self.transpose_graph_target2source.forward(target_graph2, adj=adj3_transfer_t2s,
|
894 |
+
relu=True)
|
895 |
+
target_2_middle_graph2_v5 = self.transpose_graph_target2middle.forward(target_graph2, adj=adj6_transfer_t2m,
|
896 |
+
relu=True)
|
897 |
+
# middle 2 source & target
|
898 |
+
middle_2_source_graph2_v5 = self.transpose_graph_middle2source.forward(middle_graph2, adj=adj5_transfer_m2s,
|
899 |
+
relu=True)
|
900 |
+
middle_2_target_graph2_v5 = self.transpose_graph_middle2target.forward(middle_graph2, adj=adj6_transfer_m2t,
|
901 |
+
relu=True)
|
902 |
+
# source 2 middle target
|
903 |
+
source_2_target_graph2 = self.similarity_trans(source_graph2, target_graph2)
|
904 |
+
source_2_middle_graph2 = self.similarity_trans(source_graph2, middle_graph2)
|
905 |
+
# target 2 source middle
|
906 |
+
target_2_source_graph2 = self.similarity_trans(target_graph2, source_graph2)
|
907 |
+
target_2_middle_graph2 = self.similarity_trans(target_graph2, middle_graph2)
|
908 |
+
# middle 2 source target
|
909 |
+
middle_2_source_graph2 = self.similarity_trans(middle_graph2, source_graph2)
|
910 |
+
middle_2_target_graph2 = self.similarity_trans(middle_graph2, target_graph2)
|
911 |
+
|
912 |
+
## concat
|
913 |
+
# print(source_graph1.size(), target_2_source_graph1.size(), )
|
914 |
+
source_graph2 = torch.cat(
|
915 |
+
(source_graph2, target_2_source_graph2, target_2_source_graph2_v5,
|
916 |
+
middle_2_source_graph2, middle_2_source_graph2_v5), dim=-1)
|
917 |
+
source_graph2 = self.fc_graph_source.forward(source_graph2, relu=True)
|
918 |
+
# target
|
919 |
+
target_graph2 = torch.cat(
|
920 |
+
(target_graph2, source_2_target_graph2, source_2_target_graph2_v5,
|
921 |
+
middle_2_target_graph2, middle_2_target_graph2_v5), dim=-1)
|
922 |
+
target_graph2 = self.fc_graph_target.forward(target_graph2, relu=True)
|
923 |
+
# middle
|
924 |
+
middle_graph2 = torch.cat((middle_graph2, source_2_middle_graph2, source_2_middle_graph2_v5,
|
925 |
+
target_2_middle_graph2, target_2_middle_graph2_v5), dim=-1)
|
926 |
+
middle_graph2 = self.fc_graph_middle.forward(middle_graph2, relu=True)
|
927 |
+
|
928 |
+
|
929 |
+
### third task
|
930 |
+
source_graph3 = self.source_graph_conv1.forward(source_graph2, adj=adj2_source, relu=True)
|
931 |
+
target_graph3 = self.target_graph_conv1.forward(target_graph2, adj=adj1_target, relu=True)
|
932 |
+
middle_graph3 = self.target_graph_conv1.forward(middle_graph2, adj=adj4_middle, relu=True)
|
933 |
+
|
934 |
+
# source 2 target & middle
|
935 |
+
source_2_target_graph3_v5 = self.transpose_graph_source2target.forward(source_graph3, adj=adj3_transfer_s2t,
|
936 |
+
relu=True)
|
937 |
+
source_2_middle_graph3_v5 = self.transpose_graph_source2middle.forward(source_graph3, adj=adj5_transfer_s2m,
|
938 |
+
relu=True)
|
939 |
+
# target 2 source & middle
|
940 |
+
target_2_source_graph3_v5 = self.transpose_graph_target2source.forward(target_graph3, adj=adj3_transfer_t2s,
|
941 |
+
relu=True)
|
942 |
+
target_2_middle_graph3_v5 = self.transpose_graph_target2middle.forward(target_graph3, adj=adj6_transfer_t2m,
|
943 |
+
relu=True)
|
944 |
+
# middle 2 source & target
|
945 |
+
middle_2_source_graph3_v5 = self.transpose_graph_middle2source.forward(middle_graph3, adj=adj5_transfer_m2s,
|
946 |
+
relu=True)
|
947 |
+
middle_2_target_graph3_v5 = self.transpose_graph_middle2target.forward(middle_graph3, adj=adj6_transfer_m2t,
|
948 |
+
relu=True)
|
949 |
+
# source 2 middle target
|
950 |
+
source_2_target_graph3 = self.similarity_trans(source_graph3, target_graph3)
|
951 |
+
source_2_middle_graph3 = self.similarity_trans(source_graph3, middle_graph3)
|
952 |
+
# target 2 source middle
|
953 |
+
target_2_source_graph3 = self.similarity_trans(target_graph3, source_graph3)
|
954 |
+
target_2_middle_graph3 = self.similarity_trans(target_graph3, middle_graph3)
|
955 |
+
# middle 2 source target
|
956 |
+
middle_2_source_graph3 = self.similarity_trans(middle_graph3, source_graph3)
|
957 |
+
middle_2_target_graph3 = self.similarity_trans(middle_graph3, target_graph3)
|
958 |
+
|
959 |
+
## concat
|
960 |
+
# print(source_graph1.size(), target_2_source_graph1.size(), )
|
961 |
+
source_graph3 = torch.cat(
|
962 |
+
(source_graph3, target_2_source_graph3, target_2_source_graph3_v5,
|
963 |
+
middle_2_source_graph3, middle_2_source_graph3_v5), dim=-1)
|
964 |
+
source_graph3 = self.fc_graph_source.forward(source_graph3, relu=True)
|
965 |
+
# target
|
966 |
+
target_graph3 = torch.cat(
|
967 |
+
(target_graph3, source_2_target_graph3, source_2_target_graph3_v5,
|
968 |
+
middle_2_target_graph3, middle_2_target_graph3_v5), dim=-1)
|
969 |
+
target_graph3 = self.fc_graph_target.forward(target_graph3, relu=True)
|
970 |
+
# middle
|
971 |
+
middle_graph3 = torch.cat((middle_graph3, source_2_middle_graph3, source_2_middle_graph3_v5,
|
972 |
+
target_2_middle_graph3, target_2_middle_graph3_v5), dim=-1)
|
973 |
+
middle_graph3 = self.fc_graph_middle.forward(middle_graph3, relu=True)
|
974 |
+
|
975 |
+
return source_graph3, target_graph3, middle_graph3, x
|
976 |
+
|
977 |
+
def similarity_trans(self,source,target):
|
978 |
+
sim = torch.matmul(F.normalize(target, p=2, dim=-1), F.normalize(source, p=2, dim=-1).transpose(-1, -2))
|
979 |
+
sim = F.softmax(sim, dim=-1)
|
980 |
+
return torch.matmul(sim, source)
|
981 |
+
|
982 |
+
def bottom_forward_source(self, input, source_graph):
|
983 |
+
# print('input size')
|
984 |
+
# print(input.size())
|
985 |
+
# print(source_graph.size())
|
986 |
+
graph = self.source_graph_2_fea.forward(source_graph, input)
|
987 |
+
x = self.source_skip_conv(input)
|
988 |
+
x = x + graph
|
989 |
+
x = self.source_semantic(x)
|
990 |
+
return x
|
991 |
+
|
992 |
+
def bottom_forward_target(self, input, target_graph):
|
993 |
+
graph = self.target_graph_2_fea.forward(target_graph, input)
|
994 |
+
x = self.target_skip_conv(input)
|
995 |
+
x = x + graph
|
996 |
+
x = self.semantic(x)
|
997 |
+
return x
|
998 |
+
|
999 |
+
def bottom_forward_middle(self, input, target_graph):
|
1000 |
+
graph = self.middle_graph_2_fea.forward(target_graph, input)
|
1001 |
+
x = self.middle_skip_conv(input)
|
1002 |
+
x = x + graph
|
1003 |
+
x = self.middle_semantic(x)
|
1004 |
+
return x
|
1005 |
+
|
1006 |
+
def forward(self, input_source, input_target=None, input_middle=None, adj1_target=None, adj2_source=None,
|
1007 |
+
adj3_transfer_s2t=None, adj3_transfer_t2s=None, adj4_middle=None,adj5_transfer_s2m=None,
|
1008 |
+
adj6_transfer_t2m=None,adj5_transfer_m2s=None,adj6_transfer_m2t=None,):
|
1009 |
+
|
1010 |
+
if input_source is None and input_target is not None and input_middle is None:
|
1011 |
+
# target
|
1012 |
+
target_batch = input_target.size(0)
|
1013 |
+
input = input_target
|
1014 |
+
|
1015 |
+
source_graph, target_graph, middle_graph, x = self.top_forward(input, adj1_target=adj1_target, adj2_source=adj2_source,
|
1016 |
+
adj3_transfer_s2t=adj3_transfer_s2t,
|
1017 |
+
adj3_transfer_t2s=adj3_transfer_t2s,
|
1018 |
+
adj4_middle=adj4_middle,
|
1019 |
+
adj5_transfer_s2m=adj5_transfer_s2m,
|
1020 |
+
adj6_transfer_t2m=adj6_transfer_t2m,
|
1021 |
+
adj5_transfer_m2s=adj5_transfer_m2s,
|
1022 |
+
adj6_transfer_m2t=adj6_transfer_m2t)
|
1023 |
+
|
1024 |
+
# source_x = self.bottom_forward_source(source_x, source_graph)
|
1025 |
+
target_x = self.bottom_forward_target(x, target_graph)
|
1026 |
+
|
1027 |
+
target_x = F.upsample(target_x, size=input.size()[2:], mode='bilinear', align_corners=True)
|
1028 |
+
return None, target_x, None
|
1029 |
+
|
1030 |
+
if input_source is not None and input_target is None and input_middle is None:
|
1031 |
+
# source
|
1032 |
+
source_batch = input_source.size(0)
|
1033 |
+
source_list = range(source_batch)
|
1034 |
+
input = input_source
|
1035 |
+
|
1036 |
+
source_graph, target_graph, middle_graph, x = self.top_forward(input, adj1_target=adj1_target,
|
1037 |
+
adj2_source=adj2_source,
|
1038 |
+
adj3_transfer_s2t=adj3_transfer_s2t,
|
1039 |
+
adj3_transfer_t2s=adj3_transfer_t2s,
|
1040 |
+
adj4_middle=adj4_middle,
|
1041 |
+
adj5_transfer_s2m=adj5_transfer_s2m,
|
1042 |
+
adj6_transfer_t2m=adj6_transfer_t2m,
|
1043 |
+
adj5_transfer_m2s=adj5_transfer_m2s,
|
1044 |
+
adj6_transfer_m2t=adj6_transfer_m2t)
|
1045 |
+
|
1046 |
+
source_x = self.bottom_forward_source(x, source_graph)
|
1047 |
+
source_x = F.upsample(source_x, size=input.size()[2:], mode='bilinear', align_corners=True)
|
1048 |
+
return source_x, None, None
|
1049 |
+
|
1050 |
+
if input_middle is not None and input_source is None and input_target is None:
|
1051 |
+
# middle
|
1052 |
+
input = input_middle
|
1053 |
+
|
1054 |
+
source_graph, target_graph, middle_graph, x = self.top_forward(input, adj1_target=adj1_target,
|
1055 |
+
adj2_source=adj2_source,
|
1056 |
+
adj3_transfer_s2t=adj3_transfer_s2t,
|
1057 |
+
adj3_transfer_t2s=adj3_transfer_t2s,
|
1058 |
+
adj4_middle=adj4_middle,
|
1059 |
+
adj5_transfer_s2m=adj5_transfer_s2m,
|
1060 |
+
adj6_transfer_t2m=adj6_transfer_t2m,
|
1061 |
+
adj5_transfer_m2s=adj5_transfer_m2s,
|
1062 |
+
adj6_transfer_m2t=adj6_transfer_m2t)
|
1063 |
+
|
1064 |
+
middle_x = self.bottom_forward_middle(x, source_graph)
|
1065 |
+
middle_x = F.upsample(middle_x, size=input.size()[2:], mode='bilinear', align_corners=True)
|
1066 |
+
return None, None, middle_x
|
1067 |
+
|
1068 |
+
|
1069 |
+
if __name__ == '__main__':
|
1070 |
+
net = deeplab_xception_end2end_3d()
|
1071 |
+
net.freeze_totally_bn()
|
1072 |
+
img1 = torch.rand((1,3,128,128))
|
1073 |
+
img2 = torch.rand((1, 3, 128, 128))
|
1074 |
+
a1 = torch.ones((1,1,7,20))
|
1075 |
+
a2 = torch.ones((1,1,20,7))
|
1076 |
+
net.eval()
|
1077 |
+
net.forward(img1,img2,adj3_transfer_t2s=a2,adj3_transfer_s2t=a1)
|
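For orientation, a minimal usage sketch (not part of the repository) of the source-only branch of deeplab_xception_end2end_3d with dummy adjacency matrices, assuming the backbone builds without pretrained weights; the adjacency shapes follow the Graph_trans/GraphConvolution definitions in networks/gcn.py (rows = destination nodes, columns = origin nodes):

import torch

net = deeplab_xception_end2end_3d(n_classes=20, source_classes=7, middle_classes=18)
net.eval()

img = torch.rand(1, 3, 128, 128)
adjs = dict(
    adj1_target=torch.ones(20, 20), adj2_source=torch.ones(7, 7),
    adj3_transfer_s2t=torch.ones(20, 7), adj3_transfer_t2s=torch.ones(7, 20),
    adj4_middle=torch.ones(18, 18),
    adj5_transfer_s2m=torch.ones(18, 7), adj5_transfer_m2s=torch.ones(7, 18),
    adj6_transfer_t2m=torch.ones(18, 20), adj6_transfer_m2t=torch.ones(20, 18),
)
with torch.no_grad():
    source_x, _, _ = net(img, **adjs)   # input_source only, so only the source head is produced
# source_x has shape (1, 7, 128, 128): per-pixel logits over the 7 source classes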
TryYours-Virtual-Try-On/Graphonomy-master/networks/gcn.py
ADDED
@@ -0,0 +1,279 @@
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch.nn.parameter import Parameter
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from networks import graph
|
7 |
+
# import pdb
|
8 |
+
|
9 |
+
class GraphConvolution(nn.Module):
|
10 |
+
|
11 |
+
def __init__(self,in_features,out_features,bias=False):
|
12 |
+
super(GraphConvolution, self).__init__()
|
13 |
+
self.in_features = in_features
|
14 |
+
self.out_features = out_features
|
15 |
+
self.weight = Parameter(torch.FloatTensor(in_features,out_features))
|
16 |
+
if bias:
|
17 |
+
self.bias = Parameter(torch.FloatTensor(out_features))
|
18 |
+
else:
|
19 |
+
self.register_parameter('bias',None)
|
20 |
+
self.reset_parameters()
|
21 |
+
|
22 |
+
def reset_parameters(self):
|
23 |
+
# stdv = 1./math.sqrt(self.weight(1))
|
24 |
+
# self.weight.data.uniform_(-stdv,stdv)
|
25 |
+
torch.nn.init.xavier_uniform_(self.weight)
|
26 |
+
# if self.bias is not None:
|
27 |
+
# self.bias.data.uniform_(-stdv,stdv)
|
28 |
+
|
29 |
+
def forward(self, input,adj=None,relu=False):
|
30 |
+
support = torch.matmul(input, self.weight)
|
31 |
+
# print(support.size(),adj.size())
|
32 |
+
if adj is not None:
|
33 |
+
output = torch.matmul(adj, support)
|
34 |
+
else:
|
35 |
+
output = support
|
36 |
+
# print(output.size())
|
37 |
+
if self.bias is not None:
|
38 |
+
return output + self.bias
|
39 |
+
else:
|
40 |
+
if relu:
|
41 |
+
return F.relu(output)
|
42 |
+
else:
|
43 |
+
return output
|
44 |
+
|
45 |
+
def __repr__(self):
|
46 |
+
return self.__class__.__name__ + ' (' \
|
47 |
+
+ str(self.in_features) + ' -> ' \
|
48 |
+
+ str(self.out_features) + ')'
|
49 |
+
|
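A minimal standalone sketch of the propagation rule implemented by GraphConvolution (sizes are illustrative):

import torch

gc = GraphConvolution(in_features=128, out_features=128)
nodes = torch.rand(1, 7, 128)                    # batch x nodes x features
adj = torch.softmax(torch.rand(7, 7), dim=-1)    # any (nodes x nodes) adjacency
out = gc(nodes, adj=adj, relu=True)              # ReLU(adj @ (nodes @ W)), shape (1, 7, 128)
linear_only = gc(nodes)                          # with adj=None the layer is a per-node linear map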
50 |
+
class Featuremaps_to_Graph(nn.Module):
|
51 |
+
|
52 |
+
def __init__(self,input_channels,hidden_layers,nodes=7):
|
53 |
+
super(Featuremaps_to_Graph, self).__init__()
|
54 |
+
self.pre_fea = Parameter(torch.FloatTensor(input_channels,nodes))
|
55 |
+
self.weight = Parameter(torch.FloatTensor(input_channels,hidden_layers))
|
56 |
+
self.reset_parameters()
|
57 |
+
|
58 |
+
def forward(self, input):
|
59 |
+
n,c,h,w = input.size()
|
60 |
+
# print('fea input',input.size())
|
61 |
+
input1 = input.view(n,c,h*w)
|
62 |
+
input1 = input1.transpose(1,2) # n x hw x c
|
63 |
+
# print('fea input1', input1.size())
|
64 |
+
############## Feature maps to node ################
|
65 |
+
fea_node = torch.matmul(input1,self.pre_fea) # n x hw x n_classes
|
66 |
+
weight_node = torch.matmul(input1,self.weight) # n x hw x hidden_layer
|
67 |
+
# softmax fea_node
|
68 |
+
fea_node = F.softmax(fea_node,dim=-1)
|
69 |
+
# print(fea_node.size(),weight_node.size())
|
70 |
+
graph_node = F.relu(torch.matmul(fea_node.transpose(1,2),weight_node))
|
71 |
+
return graph_node # n x n_class x hidden_layer
|
72 |
+
|
73 |
+
def reset_parameters(self):
|
74 |
+
for ww in self.parameters():
|
75 |
+
torch.nn.init.xavier_uniform_(ww)
|
76 |
+
# if self.bias is not None:
|
77 |
+
# self.bias.data.uniform_(-stdv,stdv)
|
78 |
+
|
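A shape sketch of the pixel-to-node projection performed by Featuremaps_to_Graph (sizes are illustrative):

import torch

f2g = Featuremaps_to_Graph(input_channels=256, hidden_layers=128, nodes=7)
feat = torch.rand(2, 256, 32, 32)   # batch x channels x h x w feature map
graph = f2g(feat)                   # soft-assigns every pixel to one of the 7 nodes
# graph.shape == (2, 7, 128): one 128-d descriptor per semantic node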
79 |
+
class Featuremaps_to_Graph_transfer(nn.Module):
|
80 |
+
|
81 |
+
def __init__(self,input_channels,hidden_layers,nodes=7, source_nodes=20):
|
82 |
+
super(Featuremaps_to_Graph_transfer, self).__init__()
|
83 |
+
self.pre_fea = Parameter(torch.FloatTensor(input_channels,nodes))
|
84 |
+
self.weight = Parameter(torch.FloatTensor(input_channels,hidden_layers))
|
85 |
+
self.pre_fea_transfer = nn.Sequential(*[nn.Linear(source_nodes, source_nodes),nn.LeakyReLU(True),
|
86 |
+
nn.Linear(source_nodes, nodes), nn.LeakyReLU(True)])
|
87 |
+
self.reset_parameters()
|
88 |
+
|
89 |
+
def forward(self, input, source_pre_fea):
|
90 |
+
self.pre_fea.data = self.pre_fea_learn(source_pre_fea)
|
91 |
+
n,c,h,w = input.size()
|
92 |
+
# print('fea input',input.size())
|
93 |
+
input1 = input.view(n,c,h*w)
|
94 |
+
input1 = input1.transpose(1,2) # n x hw x c
|
95 |
+
# print('fea input1', input1.size())
|
96 |
+
############## Feature maps to node ################
|
97 |
+
fea_node = torch.matmul(input1,self.pre_fea) # n x hw x n_classes
|
98 |
+
weight_node = torch.matmul(input1,self.weight) # n x hw x hidden_layer
|
99 |
+
# softmax fea_node
|
100 |
+
fea_node = F.softmax(fea_node,dim=1)
|
101 |
+
# print(fea_node.size(),weight_node.size())
|
102 |
+
graph_node = F.relu(torch.matmul(fea_node.transpose(1,2),weight_node))
|
103 |
+
return graph_node # n x n_class x hidden_layer
|
104 |
+
|
105 |
+
def pre_fea_learn(self, input):
|
106 |
+
pre_fea = self.pre_fea_transfer.forward(input.unsqueeze(0)).squeeze(0)
|
107 |
+
return self.pre_fea.data + pre_fea
|
108 |
+
|
109 |
+
class Graph_to_Featuremaps(nn.Module):
|
110 |
+
# this is a special version
|
111 |
+
def __init__(self,input_channels,output_channels,hidden_layers,nodes=7):
|
112 |
+
super(Graph_to_Featuremaps, self).__init__()
|
113 |
+
self.node_fea = Parameter(torch.FloatTensor(input_channels+hidden_layers,1))
|
114 |
+
self.weight = Parameter(torch.FloatTensor(hidden_layers,output_channels))
|
115 |
+
self.reset_parameters()
|
116 |
+
|
117 |
+
def reset_parameters(self):
|
118 |
+
for ww in self.parameters():
|
119 |
+
torch.nn.init.xavier_uniform_(ww)
|
120 |
+
|
121 |
+
def forward(self, input, res_feature):
|
122 |
+
'''
|
123 |
+
|
124 |
+
:param input: 1 x batch x nodes x hidden_layer
|
125 |
+
:param res_feature: batch x channels x h x w
|
126 |
+
:return:
|
127 |
+
'''
|
128 |
+
batchi,channeli,hi,wi = res_feature.size()
|
129 |
+
# print(res_feature.size())
|
130 |
+
# print(input.size())
|
131 |
+
try:
|
132 |
+
_,batch,nodes,hidden = input.size()
|
133 |
+
except Exception:
|
134 |
+
# print(input.size())
|
135 |
+
input = input.unsqueeze(0)
|
136 |
+
_,batch, nodes, hidden = input.size()
|
137 |
+
|
138 |
+
assert batch == batchi
|
139 |
+
input1 = input.transpose(0,1).expand(batch,hi*wi,nodes,hidden)
|
140 |
+
res_feature_after_view = res_feature.view(batch,channeli,hi*wi).transpose(1,2)
|
141 |
+
res_feature_after_view1 = res_feature_after_view.unsqueeze(2).expand(batch,hi*wi,nodes,channeli)
|
142 |
+
new_fea = torch.cat((res_feature_after_view1,input1),dim=3)
|
143 |
+
|
144 |
+
# print(self.node_fea.size(),new_fea.size())
|
145 |
+
new_node = torch.matmul(new_fea, self.node_fea) # batch x hw x nodes x 1
|
146 |
+
new_weight = torch.matmul(input, self.weight) # batch x node x channel
|
147 |
+
new_node = new_node.view(batch, hi*wi, nodes)
|
148 |
+
# 0721
|
149 |
+
new_node = F.softmax(new_node, dim=-1)
|
150 |
+
#
|
151 |
+
feature_out = torch.matmul(new_node,new_weight)
|
152 |
+
# print(feature_out.size())
|
153 |
+
feature_out = feature_out.transpose(2,3).contiguous().view(res_feature.size())
|
154 |
+
return F.relu(feature_out)
|
155 |
+
|
156 |
+
class Graph_to_Featuremaps_savemem(nn.Module):
|
157 |
+
# this is a memory-saving variant; the process is the same as Graph_to_Featuremaps.
|
158 |
+
def __init__(self, input_channels, output_channels, hidden_layers, nodes=7):
|
159 |
+
super(Graph_to_Featuremaps_savemem, self).__init__()
|
160 |
+
self.node_fea_for_res = Parameter(torch.FloatTensor(input_channels, 1))
|
161 |
+
self.node_fea_for_hidden = Parameter(torch.FloatTensor(hidden_layers, 1))
|
162 |
+
self.weight = Parameter(torch.FloatTensor(hidden_layers,output_channels))
|
163 |
+
self.reset_parameters()
|
164 |
+
|
165 |
+
def reset_parameters(self):
|
166 |
+
for ww in self.parameters():
|
167 |
+
torch.nn.init.xavier_uniform_(ww)
|
168 |
+
|
169 |
+
def forward(self, input, res_feature):
|
170 |
+
'''
|
171 |
+
|
172 |
+
:param input: 1 x batch x nodes x hidden_layer
|
173 |
+
:param res_feature: batch x channels x h x w
|
174 |
+
:return:
|
175 |
+
'''
|
176 |
+
batchi,channeli,hi,wi = res_feature.size()
|
177 |
+
# print(res_feature.size())
|
178 |
+
# print(input.size())
|
179 |
+
try:
|
180 |
+
_,batch,nodes,hidden = input.size()
|
181 |
+
except Exception:
|
182 |
+
# print(input.size())
|
183 |
+
input = input.unsqueeze(0)
|
184 |
+
_,batch, nodes, hidden = input.size()
|
185 |
+
|
186 |
+
assert batch == batchi
|
187 |
+
input1 = input.transpose(0,1).expand(batch,hi*wi,nodes,hidden)
|
188 |
+
res_feature_after_view = res_feature.view(batch,channeli,hi*wi).transpose(1,2)
|
189 |
+
res_feature_after_view1 = res_feature_after_view.unsqueeze(2).expand(batch,hi*wi,nodes,channeli)
|
190 |
+
# new_fea = torch.cat((res_feature_after_view1,input1),dim=3)
|
191 |
+
## sim
|
192 |
+
new_node1 = torch.matmul(res_feature_after_view1, self.node_fea_for_res)
|
193 |
+
new_node2 = torch.matmul(input1, self.node_fea_for_hidden)
|
194 |
+
new_node = new_node1 + new_node2
|
195 |
+
## sim end
|
196 |
+
# print(self.node_fea.size(),new_fea.size())
|
197 |
+
# new_node = torch.matmul(new_fea, self.node_fea) # batch x hw x nodes x 1
|
198 |
+
new_weight = torch.matmul(input, self.weight) # batch x node x channel
|
199 |
+
new_node = new_node.view(batch, hi*wi, nodes)
|
200 |
+
# 0721
|
201 |
+
new_node = F.softmax(new_node, dim=-1)
|
202 |
+
#
|
203 |
+
feature_out = torch.matmul(new_node,new_weight)
|
204 |
+
# print(feature_out.size())
|
205 |
+
feature_out = feature_out.transpose(2,3).contiguous().view(res_feature.size())
|
206 |
+
return F.relu(feature_out)
|
207 |
+
|
208 |
+
|
209 |
+
class Graph_trans(nn.Module):
|
210 |
+
|
211 |
+
def __init__(self,in_features,out_features,begin_nodes=7,end_nodes=2,bias=False,adj=None):
|
212 |
+
super(Graph_trans, self).__init__()
|
213 |
+
self.in_features = in_features
|
214 |
+
self.out_features = out_features
|
215 |
+
self.weight = Parameter(torch.FloatTensor(in_features,out_features))
|
216 |
+
if adj is not None:
|
217 |
+
h,w = adj.size()
|
218 |
+
assert (h == end_nodes) and (w == begin_nodes)
|
219 |
+
self.adj = torch.autograd.Variable(adj,requires_grad=False)
|
220 |
+
else:
|
221 |
+
self.adj = Parameter(torch.FloatTensor(end_nodes,begin_nodes))
|
222 |
+
if bias:
|
223 |
+
self.bias = Parameter(torch.FloatTensor(out_features))
|
224 |
+
else:
|
225 |
+
self.register_parameter('bias',None)
|
226 |
+
# self.reset_parameters()
|
227 |
+
|
228 |
+
def reset_parameters(self):
|
229 |
+
# stdv = 1./math.sqrt(self.weight(1))
|
230 |
+
# self.weight.data.uniform_(-stdv,stdv)
|
231 |
+
torch.nn.init.xavier_uniform_(self.weight)
|
232 |
+
# if self.bias is not None:
|
233 |
+
# self.bias.data.uniform_(-stdv,stdv)
|
234 |
+
|
235 |
+
def forward(self, input, relu=False, adj_return=False, adj=None):
|
236 |
+
support = torch.matmul(input,self.weight)
|
237 |
+
# print(support.size(),self.adj.size())
|
238 |
+
if adj is None:
|
239 |
+
adj = self.adj
|
240 |
+
adj1 = self.norm_trans_adj(adj)
|
241 |
+
output = torch.matmul(adj1,support)
|
242 |
+
if adj_return:
|
243 |
+
output1 = F.normalize(output,p=2,dim=-1)
|
244 |
+
self.adj_mat = torch.matmul(output1,output1.transpose(-2,-1))
|
245 |
+
if self.bias is not None:
|
246 |
+
return output + self.bias
|
247 |
+
else:
|
248 |
+
if relu:
|
249 |
+
return F.relu(output)
|
250 |
+
else:
|
251 |
+
return output
|
252 |
+
|
253 |
+
def get_adj_mat(self):
|
254 |
+
adj = graph.normalize_adj_torch(F.relu(self.adj_mat))
|
255 |
+
return adj
|
256 |
+
|
257 |
+
def get_encode_adj(self):
|
258 |
+
return self.adj
|
259 |
+
|
260 |
+
def norm_trans_adj(self,adj): # maybe can use softmax
|
261 |
+
adj = F.relu(adj)
|
262 |
+
r = F.softmax(adj,dim=-1)
|
263 |
+
# print(adj.size())
|
264 |
+
# row_sum = adj.sum(-1).unsqueeze(-1)
|
265 |
+
# d_mat = row_sum.expand(adj.size())
|
266 |
+
# r = torch.div(row_sum,d_mat)
|
267 |
+
# r[torch.isnan(r)] = 0
|
268 |
+
|
269 |
+
return r
|
270 |
+
|
271 |
+
|
272 |
+
if __name__ == '__main__':
|
273 |
+
|
274 |
+
graph = torch.randn((7,128))
|
275 |
+
en = GraphConvolution(128,128)
|
276 |
+
a = en.forward(graph)
|
277 |
+
print(a)
|
278 |
+
# a = en.forward(graph,pred)
|
279 |
+
# print(a.size())
|
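    # Illustrative sketch (added; not part of the original Graphonomy code): a quick
    # shape check of the two modules above with made-up hyper-parameters.
    # Graph_to_Featuremaps expects `input` of shape 1 x batch x nodes x hidden and
    # `res_feature` of shape batch x channels x h x w; its output_channels must
    # equal the channels of `res_feature` for the final .view() to succeed.
    nodes, hidden, channels, h, w = 7, 128, 32, 8, 8
    node_fea = torch.randn(1, 2, nodes, hidden)      # 1 x batch x nodes x hidden
    res_fea = torch.randn(2, channels, h, w)         # batch x channels x h x w
    g2f = Graph_to_Featuremaps(input_channels=channels, output_channels=channels,
                               hidden_layers=hidden, nodes=nodes)
    print(g2f.forward(node_fea, res_fea).size())     # torch.Size([2, 32, 8, 8])
    # Graph_trans maps 7 source nodes onto 2 target nodes through a (here random)
    # 2 x 7 transfer adjacency; note its weights are not initialised in __init__.
    trans = Graph_trans(in_features=hidden, out_features=hidden,
                        begin_nodes=nodes, end_nodes=2)
    trans.reset_parameters()
    print(trans.forward(node_fea.squeeze(0), relu=True,
                        adj=torch.rand(2, nodes)).size())  # torch.Size([2, 2, 128])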
TryYours-Virtual-Try-On/Graphonomy-master/networks/graph.py
ADDED
@@ -0,0 +1,261 @@
import numpy as np
import pickle as pkl
import networkx as nx
import scipy.sparse as sp
import torch

pascal_graph = {0: [0],
                1: [1, 2],
                2: [1, 2, 3, 5],
                3: [2, 3, 4],
                4: [3, 4],
                5: [2, 5, 6],
                6: [5, 6]}

cihp_graph = {0: [],
              1: [2, 13],
              2: [1, 13],
              3: [14, 15],
              4: [13],
              5: [6, 7, 9, 10, 11, 12, 14, 15],
              6: [5, 7, 10, 11, 14, 15, 16, 17],
              7: [5, 6, 9, 10, 11, 12, 14, 15],
              8: [16, 17, 18, 19],
              9: [5, 7, 10, 16, 17, 18, 19],
              10: [5, 6, 7, 9, 11, 12, 13, 14, 15, 16, 17],
              11: [5, 6, 7, 10, 13],
              12: [5, 7, 10, 16, 17],
              13: [1, 2, 4, 10, 11],
              14: [3, 5, 6, 7, 10],
              15: [3, 5, 6, 7, 10],
              16: [6, 8, 9, 10, 12, 18],
              17: [6, 8, 9, 10, 12, 19],
              18: [8, 9, 16],
              19: [8, 9, 17]}

atr_graph = {0: [],
             1: [2, 11],
             2: [1, 11],
             3: [11],
             4: [5, 6, 7, 11, 14, 15, 17],
             5: [4, 6, 7, 8, 12, 13],
             6: [4, 5, 7, 8, 9, 10, 12, 13],
             7: [4, 11, 12, 13, 14, 15],
             8: [5, 6],
             9: [6, 12],
             10: [6, 13],
             11: [1, 2, 3, 4, 7, 14, 15, 17],
             12: [5, 6, 7, 9],
             13: [5, 6, 7, 10],
             14: [4, 7, 11, 16],
             15: [4, 7, 11, 16],
             16: [14, 15],
             17: [4, 11],
             }
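
# The dicts above are intra-dataset part-label adjacency lists (PASCAL-Person-Part:
# 7 labels, CIHP: 20, ATR: 18, with label 0 as background); preprocess_adj() further
# below turns such a dict into a normalized dense adjacency matrix.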
cihp2pascal_adj = np.array([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                            [0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
                            [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])

cihp2pascal_nlp_adj = \
    np.array([[1., 0.35333052, 0.32727194, 0.17418084, 0.18757584,
               0.40608522, 0.37503981, 0.35448462, 0.22598555, 0.23893579,
               0.33064262, 0.28923404, 0.27986573, 0.4211553, 0.36915778,
               0.41377746, 0.32485771, 0.37248222, 0.36865639, 0.41500332],
              [0.39615879, 0.46201529, 0.52321467, 0.30826114, 0.25669527,
               0.54747773, 0.3670523, 0.3901983, 0.27519473, 0.3433325,
               0.52728509, 0.32771333, 0.34819325, 0.63882953, 0.68042925,
               0.69368576, 0.63395791, 0.65344337, 0.59538781, 0.6071375],
              [0.16373166, 0.21663339, 0.3053872, 0.28377612, 0.1372435,
               0.4448808, 0.29479995, 0.31092595, 0.22703953, 0.33983576,
               0.75778818, 0.2619818, 0.37069392, 0.35184867, 0.49877512,
               0.49979437, 0.51853277, 0.52517541, 0.32517741, 0.32377309],
              [0.32687232, 0.38482461, 0.37693463, 0.41610834, 0.20415749,
               0.76749079, 0.35139853, 0.3787411, 0.28411737, 0.35155421,
               0.58792618, 0.31141718, 0.40585111, 0.51189218, 0.82042737,
               0.8342413, 0.70732188, 0.72752501, 0.60327325, 0.61431337],
              [0.34069369, 0.34817292, 0.37525998, 0.36497069, 0.17841617,
               0.69746208, 0.31731463, 0.34628951, 0.25167277, 0.32072379,
               0.56711286, 0.24894776, 0.37000453, 0.52600859, 0.82483993,
               0.84966274, 0.7033991, 0.73449378, 0.56649608, 0.58888791],
              [0.28477487, 0.35139564, 0.42742352, 0.41664321, 0.20004676,
               0.78566833, 0.42237487, 0.41048549, 0.37933812, 0.46542516,
               0.62444759, 0.3274493, 0.49466009, 0.49314658, 0.71244233,
               0.71497003, 0.8234787, 0.83566589, 0.62597135, 0.62626812],
              [0.3011378, 0.31775977, 0.42922647, 0.36896257, 0.17597556,
               0.72214655, 0.39162804, 0.38137872, 0.34980296, 0.43818419,
               0.60879174, 0.26762545, 0.46271161, 0.51150476, 0.72318109,
               0.73678399, 0.82620388, 0.84942166, 0.5943811, 0.60607602]])

pascal2atr_nlp_adj = \
    np.array([[1., 0.35333052, 0.32727194, 0.18757584, 0.40608522,
               0.27986573, 0.23893579, 0.27600672, 0.30964391, 0.36865639,
               0.41500332, 0.4211553, 0.32485771, 0.37248222, 0.36915778,
               0.41377746, 0.32006291, 0.28923404],
              [0.39615879, 0.46201529, 0.52321467, 0.25669527, 0.54747773,
               0.34819325, 0.3433325, 0.26603942, 0.45162929, 0.59538781,
               0.6071375, 0.63882953, 0.63395791, 0.65344337, 0.68042925,
               0.69368576, 0.44354613, 0.32771333],
              [0.16373166, 0.21663339, 0.3053872, 0.1372435, 0.4448808,
               0.37069392, 0.33983576, 0.26563416, 0.35443504, 0.32517741,
               0.32377309, 0.35184867, 0.51853277, 0.52517541, 0.49877512,
               0.49979437, 0.21750868, 0.2619818],
              [0.32687232, 0.38482461, 0.37693463, 0.20415749, 0.76749079,
               0.40585111, 0.35155421, 0.28271333, 0.52684576, 0.60327325,
               0.61431337, 0.51189218, 0.70732188, 0.72752501, 0.82042737,
               0.8342413, 0.40137029, 0.31141718],
              [0.34069369, 0.34817292, 0.37525998, 0.17841617, 0.69746208,
               0.37000453, 0.32072379, 0.27268885, 0.47426719, 0.56649608,
               0.58888791, 0.52600859, 0.7033991, 0.73449378, 0.82483993,
               0.84966274, 0.37830796, 0.24894776],
              [0.28477487, 0.35139564, 0.42742352, 0.20004676, 0.78566833,
               0.49466009, 0.46542516, 0.32662614, 0.55780359, 0.62597135,
               0.62626812, 0.49314658, 0.8234787, 0.83566589, 0.71244233,
               0.71497003, 0.41223219, 0.3274493],
              [0.3011378, 0.31775977, 0.42922647, 0.17597556, 0.72214655,
               0.46271161, 0.43818419, 0.3192333, 0.50979216, 0.5943811,
               0.60607602, 0.51150476, 0.82620388, 0.84942166, 0.72318109,
               0.73678399, 0.39259827, 0.26762545]])

cihp2atr_nlp_adj = \
    np.array([[1., 0.35333052, 0.32727194, 0.18757584, 0.40608522,
               0.27986573, 0.23893579, 0.27600672, 0.30964391, 0.36865639,
               0.41500332, 0.4211553, 0.32485771, 0.37248222, 0.36915778,
               0.41377746, 0.32006291, 0.28923404],
              [0.35333052, 1., 0.39206695, 0.42143438, 0.4736689,
               0.47139544, 0.51999208, 0.38354847, 0.45628529, 0.46514124,
               0.50083501, 0.4310595, 0.39371443, 0.4319752, 0.42938598,
               0.46384034, 0.44833757, 0.6153155],
              [0.32727194, 0.39206695, 1., 0.32836702, 0.52603065,
               0.39543695, 0.3622627, 0.43575346, 0.33866223, 0.45202552,
               0.48421, 0.53669903, 0.47266611, 0.50925436, 0.42286557,
               0.45403656, 0.37221304, 0.40999322],
              [0.17418084, 0.46892601, 0.25774838, 0.31816231, 0.39330317,
               0.34218382, 0.48253904, 0.22084125, 0.41335728, 0.52437572,
               0.5191713, 0.33576117, 0.44230914, 0.44250678, 0.44330833,
               0.43887264, 0.50693611, 0.39278795],
              [0.18757584, 0.42143438, 0.32836702, 1., 0.35030067,
               0.30110947, 0.41055555, 0.34338879, 0.34336307, 0.37704433,
               0.38810141, 0.34702081, 0.24171562, 0.25433078, 0.24696241,
               0.2570884, 0.4465962, 0.45263213],
              [0.40608522, 0.4736689, 0.52603065, 0.35030067, 1.,
               0.54372584, 0.58300258, 0.56674191, 0.555266, 0.66599594,
               0.68567555, 0.55716359, 0.62997328, 0.65638548, 0.61219615,
               0.63183318, 0.54464151, 0.44293752],
              [0.37503981, 0.50675565, 0.4761106, 0.37561813, 0.60419403,
               0.77912403, 0.64595517, 0.85939662, 0.46037144, 0.52348817,
               0.55875094, 0.37741886, 0.455671, 0.49434392, 0.38479954,
               0.41804074, 0.47285709, 0.57236283],
              [0.35448462, 0.50576632, 0.51030446, 0.35841033, 0.55106903,
               0.50257274, 0.52591451, 0.4283053, 0.39991808, 0.42327211,
               0.42853819, 0.42071825, 0.41240559, 0.42259136, 0.38125352,
               0.3868255, 0.47604934, 0.51811717],
              [0.22598555, 0.5053299, 0.36301185, 0.38002282, 0.49700941,
               0.45625243, 0.62876479, 0.4112051, 0.33944371, 0.48322639,
               0.50318714, 0.29207815, 0.38801966, 0.41119094, 0.29199072,
               0.31021029, 0.41594871, 0.54961962],
              [0.23893579, 0.51999208, 0.3622627, 0.41055555, 0.58300258,
               0.68874251, 1., 0.56977937, 0.49918447, 0.48484363,
               0.51615925, 0.41222306, 0.49535971, 0.53134951, 0.3807616,
               0.41050298, 0.48675801, 0.51112664],
              [0.33064262, 0.306412, 0.60679935, 0.25592294, 0.58738706,
               0.40379627, 0.39679161, 0.33618385, 0.39235148, 0.45474013,
               0.4648476, 0.59306762, 0.58976007, 0.60778661, 0.55400397,
               0.56551297, 0.3698029, 0.33860535],
              [0.28923404, 0.6153155, 0.40999322, 0.45263213, 0.44293752,
               0.60359359, 0.51112664, 0.46578181, 0.45656936, 0.38142307,
               0.38525582, 0.33327223, 0.35360175, 0.36156453, 0.3384992,
               0.34261229, 0.49297863, 1.],
              [0.27986573, 0.47139544, 0.39543695, 0.30110947, 0.54372584,
               1., 0.68874251, 0.67765588, 0.48690078, 0.44010641,
               0.44921156, 0.32321099, 0.48311542, 0.4982002, 0.39378102,
               0.40297733, 0.45309735, 0.60359359],
              [0.4211553, 0.4310595, 0.53669903, 0.34702081, 0.55716359,
               0.32321099, 0.41222306, 0.25721705, 0.36633509, 0.5397475,
               0.56429928, 1., 0.55796926, 0.58842844, 0.57930828,
               0.60410597, 0.41615326, 0.33327223],
              [0.36915778, 0.42938598, 0.42286557, 0.24696241, 0.61219615,
               0.39378102, 0.3807616, 0.28089866, 0.48450394, 0.77400821,
               0.68813814, 0.57930828, 0.8856886, 0.81673412, 1.,
               0.92279623, 0.46969152, 0.3384992],
              [0.41377746, 0.46384034, 0.45403656, 0.2570884, 0.63183318,
               0.40297733, 0.41050298, 0.332879, 0.48799542, 0.69231828,
               0.77015091, 0.60410597, 0.79788484, 0.88232104, 0.92279623,
               1., 0.45685017, 0.34261229],
              [0.32485771, 0.39371443, 0.47266611, 0.24171562, 0.62997328,
               0.48311542, 0.49535971, 0.32477932, 0.51486622, 0.79353556,
               0.69768738, 0.55796926, 1., 0.92373745, 0.8856886,
               0.79788484, 0.47883134, 0.35360175],
              [0.37248222, 0.4319752, 0.50925436, 0.25433078, 0.65638548,
               0.4982002, 0.53134951, 0.38057074, 0.52403969, 0.72035243,
               0.78711147, 0.58842844, 0.92373745, 1., 0.81673412,
               0.88232104, 0.47109935, 0.36156453],
              [0.36865639, 0.46514124, 0.45202552, 0.37704433, 0.66599594,
               0.44010641, 0.48484363, 0.39636574, 0.50175258, 1.,
               0.91320249, 0.5397475, 0.79353556, 0.72035243, 0.77400821,
               0.69231828, 0.59087008, 0.38142307],
              [0.41500332, 0.50083501, 0.48421, 0.38810141, 0.68567555,
               0.44921156, 0.51615925, 0.45156472, 0.50438158, 0.91320249,
               1., 0.56429928, 0.69768738, 0.78711147, 0.68813814,
               0.77015091, 0.57698754, 0.38525582]])

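# Judging by their shapes, cihp2pascal_adj / cihp2pascal_nlp_adj are 7 x 20
# (PASCAL-Person-Part parts x CIHP labels), pascal2atr_nlp_adj is 7 x 18
# (PASCAL x ATR) and cihp2atr_nlp_adj is 20 x 18 (CIHP x ATR); the *_nlp_*
# variants presumably carry soft semantic similarities rather than 0/1 links.
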
def normalize_adj(adj):
    """Symmetrically normalize adjacency matrix."""
    adj = sp.coo_matrix(adj)
    rowsum = np.array(adj.sum(1))
    d_inv_sqrt = np.power(rowsum, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()


def preprocess_adj(adj):
    """Build the adjacency matrix of a graph dict, add self-loops and symmetrically normalize it; returns a dense matrix."""
    adj = nx.adjacency_matrix(nx.from_dict_of_lists(adj))  # scipy sparse adjacency of the graph dict
    adj_normalized = normalize_adj(adj + sp.eye(adj.shape[0]))
    # return sparse_to_tuple(adj_normalized)
    return adj_normalized.todense()


def row_norm(inputs):
    outputs = []
    for x in inputs:
        xsum = x.sum()
        x = x / xsum
        outputs.append(x)
    return outputs


def normalize_adj_torch(adj):
    if len(adj.size()) == 4:
        new_r = torch.zeros(adj.size()).type_as(adj)
        for i in range(adj.size(1)):
            adj_item = adj[0, i]
            rowsum = adj_item.sum(1)
            d_inv_sqrt = rowsum.pow_(-0.5)
            d_inv_sqrt[torch.isnan(d_inv_sqrt)] = 0
            d_mat_inv_sqrt = torch.diag(d_inv_sqrt)
            r = torch.matmul(torch.matmul(d_mat_inv_sqrt, adj_item), d_mat_inv_sqrt)
            new_r[0, i, ...] = r
        return new_r
    rowsum = adj.sum(1)
    d_inv_sqrt = rowsum.pow_(-0.5)
    d_inv_sqrt[torch.isnan(d_inv_sqrt)] = 0
    d_mat_inv_sqrt = torch.diag(d_inv_sqrt)
    r = torch.matmul(torch.matmul(d_mat_inv_sqrt, adj), d_mat_inv_sqrt)
    return r


if __name__ == '__main__':
    a = row_norm(cihp2pascal_adj)
    print(a)
    print(cihp2pascal_adj)
    # print(a.shape)
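    # Illustrative sketch (added; not part of the original file): preprocess_adj
    # turns one of the label-graph dicts above into a dense, symmetrically
    # normalized adjacency with self-loops; normalize_adj_torch does the analogous
    # normalization for a torch tensor.
    pascal_adj = preprocess_adj(pascal_graph)
    print(np.asarray(pascal_adj).shape)                      # (7, 7)
    adj_t = torch.from_numpy(np.asarray(pascal_adj)).float()
    print(normalize_adj_torch(adj_t).size())                 # torch.Size([7, 7])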