wizzseen committed on
Commit 8a6df40 · verified · 1 Parent(s): 452382e

Upload 948 files

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +5 -0
  2. TryYours-Virtual-Try-On/.gitignore +5 -0
  3. TryYours-Virtual-Try-On/Demo.ipynb +267 -0
  4. TryYours-Virtual-Try-On/Graphonomy-master/LICENSE +21 -0
  5. TryYours-Virtual-Try-On/Graphonomy-master/README.md +124 -0
  6. TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/__init__.py +0 -0
  7. TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/__pycache__/__init__.cpython-310.pyc +0 -0
  8. TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/__pycache__/__init__.cpython-39.pyc +0 -0
  9. TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/__pycache__/custom_transforms.cpython-310.pyc +0 -0
  10. TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/__pycache__/custom_transforms.cpython-39.pyc +0 -0
  11. TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/atr.py +109 -0
  12. TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/cihp.py +107 -0
  13. TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/cihp_pascal_atr.py +219 -0
  14. TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/custom_transforms.py +491 -0
  15. TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/mypath_atr.py +8 -0
  16. TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/mypath_cihp.py +8 -0
  17. TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/mypath_pascal.py +8 -0
  18. TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/pascal.py +106 -0
  19. TryYours-Virtual-Try-On/Graphonomy-master/eval_cihp.sh +5 -0
  20. TryYours-Virtual-Try-On/Graphonomy-master/eval_pascal.sh +5 -0
  21. TryYours-Virtual-Try-On/Graphonomy-master/exp/inference/.ipynb_checkpoints/inference-checkpoint.py +203 -0
  22. TryYours-Virtual-Try-On/Graphonomy-master/exp/inference/inference.py +206 -0
  23. TryYours-Virtual-Try-On/Graphonomy-master/exp/test/__init__.py +3 -0
  24. TryYours-Virtual-Try-On/Graphonomy-master/exp/test/eval_show_cihp2pascal.py +268 -0
  25. TryYours-Virtual-Try-On/Graphonomy-master/exp/test/eval_show_pascal2cihp.py +268 -0
  26. TryYours-Virtual-Try-On/Graphonomy-master/exp/test/test_from_disk.py +65 -0
  27. TryYours-Virtual-Try-On/Graphonomy-master/exp/transfer/train_cihp_from_pascal.py +331 -0
  28. TryYours-Virtual-Try-On/Graphonomy-master/exp/universal/pascal_atr_cihp_uni.py +493 -0
  29. TryYours-Virtual-Try-On/Graphonomy-master/inference.sh +1 -0
  30. TryYours-Virtual-Try-On/Graphonomy-master/networks/__init__.py +3 -0
  31. TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/__init__.cpython-310.pyc +0 -0
  32. TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/__init__.cpython-39.pyc +0 -0
  33. TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception.cpython-310.pyc +0 -0
  34. TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception.cpython-39.pyc +0 -0
  35. TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_synBN.cpython-310.pyc +0 -0
  36. TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_synBN.cpython-39.pyc +0 -0
  37. TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_transfer.cpython-310.pyc +0 -0
  38. TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_transfer.cpython-39.pyc +0 -0
  39. TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_universal.cpython-310.pyc +0 -0
  40. TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_universal.cpython-39.pyc +0 -0
  41. TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/gcn.cpython-310.pyc +0 -0
  42. TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/gcn.cpython-39.pyc +0 -0
  43. TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/graph.cpython-310.pyc +0 -0
  44. TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/graph.cpython-39.pyc +0 -0
  45. TryYours-Virtual-Try-On/Graphonomy-master/networks/deeplab_xception.py +684 -0
  46. TryYours-Virtual-Try-On/Graphonomy-master/networks/deeplab_xception_synBN.py +596 -0
  47. TryYours-Virtual-Try-On/Graphonomy-master/networks/deeplab_xception_transfer.py +1003 -0
  48. TryYours-Virtual-Try-On/Graphonomy-master/networks/deeplab_xception_universal.py +1077 -0
  49. TryYours-Virtual-Try-On/Graphonomy-master/networks/gcn.py +279 -0
  50. TryYours-Virtual-Try-On/Graphonomy-master/networks/graph.py +261 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ TryYours-Virtual-Try-On/HR-VITON-main/figures/fig.jpg filter=lfs diff=lfs merge=lfs -text
37
+ TryYours-Virtual-Try-On/HR-VITON-main/Output/00001_00_00001_00.png filter=lfs diff=lfs merge=lfs -text
38
+ TryYours-Virtual-Try-On/static/finalimg.png filter=lfs diff=lfs merge=lfs -text
39
+ TryYours-Virtual-Try-On/static/origin_web.jpg filter=lfs diff=lfs merge=lfs -text
40
+ TryYours-Virtual-Try-On/TryYours_presentation_kr.pdf filter=lfs diff=lfs merge=lfs -text
TryYours-Virtual-Try-On/.gitignore ADDED
@@ -0,0 +1,5 @@
1
+ # don't upload large pth files
2
+
3
+ HR-VITON-main/gen.pth
4
+ HR-VITON-main/mtviton.pth
5
+ Graphonomy-master/inference.pth
TryYours-Virtual-Try-On/Demo.ipynb ADDED
@@ -0,0 +1,267 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "803b037a-183e-46e4-bfde-a144dc779cf0",
7
+ "metadata": {
8
+ "execution": {
9
+ "iopub.execute_input": "2022-12-29T07:03:01.120935Z",
10
+ "iopub.status.busy": "2022-12-29T07:03:01.120639Z",
11
+ "iopub.status.idle": "2022-12-29T07:03:03.538100Z",
12
+ "shell.execute_reply": "2022-12-29T07:03:03.537115Z",
13
+ "shell.execute_reply.started": "2022-12-29T07:03:01.120878Z"
14
+ }
15
+ },
16
+ "outputs": [
17
+ {
18
+ "name": "stdout",
19
+ "output_type": "stream",
20
+ "text": [
21
+ "Collecting tensorboardX\n",
22
+ " Downloading tensorboardX-2.5.1-py2.py3-none-any.whl (125 kB)\n",
23
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m125.4/125.4 kB\u001b[0m \u001b[31m23.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
24
+ "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.9/dist-packages (from tensorboardX) (1.23.1)\n",
25
+ "Requirement already satisfied: protobuf<=3.20.1,>=3.8.0 in /usr/local/lib/python3.9/dist-packages (from tensorboardX) (3.19.4)\n",
26
+ "Installing collected packages: tensorboardX\n",
27
+ "Successfully installed tensorboardX-2.5.1\n",
28
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
29
+ "\u001b[0m"
30
+ ]
31
+ }
32
+ ],
33
+ "source": [
34
+ "!pip install tensorboardX av torchgeometry flask flask-ngrok iglovikov_helper_functions cloths_segmentation albumentations"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": 8,
40
+ "id": "756d9b98-41d4-4f4f-85bc-51c77b99507c",
41
+ "metadata": {
42
+ "execution": {
43
+ "iopub.execute_input": "2022-12-29T07:07:46.679818Z",
44
+ "iopub.status.busy": "2022-12-29T07:07:46.679527Z",
45
+ "iopub.status.idle": "2022-12-29T07:10:34.279721Z",
46
+ "shell.execute_reply": "2022-12-29T07:10:34.278988Z",
47
+ "shell.execute_reply.started": "2022-12-29T07:07:46.679796Z"
48
+ }
49
+ },
50
+ "outputs": [
51
+ {
52
+ "name": "stdout",
53
+ "output_type": "stream",
54
+ "text": [
55
+ "Collecting git+https://github.com/facebookresearch/detectron2.git\n",
56
+ " Cloning https://github.com/facebookresearch/detectron2.git to /tmp/pip-req-build-llo0z_rd\n",
57
+ " Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/detectron2.git /tmp/pip-req-build-llo0z_rd\n",
58
+ " Resolved https://github.com/facebookresearch/detectron2.git to commit 857d5de21a7789d1bba46694cf608b1cb2ea128a\n",
59
+ " Preparing metadata (setup.py) ... \u001b[?25ldone\n",
60
+ "\u001b[?25hRequirement already satisfied: Pillow>=7.1 in /usr/local/lib/python3.9/dist-packages (from detectron2==0.6) (9.2.0)\n",
61
+ "Requirement already satisfied: matplotlib in /usr/local/lib/python3.9/dist-packages (from detectron2==0.6) (3.5.2)\n",
62
+ "Collecting pycocotools>=2.0.2\n",
63
+ " Downloading pycocotools-2.0.6.tar.gz (24 kB)\n",
64
+ " Installing build dependencies ... \u001b[?25ldone\n",
65
+ "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n",
66
+ "\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n",
67
+ "\u001b[?25hRequirement already satisfied: termcolor>=1.1 in /usr/local/lib/python3.9/dist-packages (from detectron2==0.6) (1.1.0)\n",
68
+ "Collecting yacs>=0.1.8\n",
69
+ " Downloading yacs-0.1.8-py3-none-any.whl (14 kB)\n",
70
+ "Collecting tabulate\n",
71
+ " Downloading tabulate-0.9.0-py3-none-any.whl (35 kB)\n",
72
+ "Requirement already satisfied: cloudpickle in /usr/local/lib/python3.9/dist-packages (from detectron2==0.6) (2.1.0)\n",
73
+ "Requirement already satisfied: tqdm>4.29.0 in /usr/local/lib/python3.9/dist-packages (from detectron2==0.6) (4.64.0)\n",
74
+ "Requirement already satisfied: tensorboard in /usr/local/lib/python3.9/dist-packages (from detectron2==0.6) (2.9.1)\n",
75
+ "Collecting fvcore<0.1.6,>=0.1.5\n",
76
+ " Downloading fvcore-0.1.5.post20221221.tar.gz (50 kB)\n",
77
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m50.2/50.2 kB\u001b[0m \u001b[31m15.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
78
+ "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25ldone\n",
79
+ "\u001b[?25hCollecting iopath<0.1.10,>=0.1.7\n",
80
+ " Downloading iopath-0.1.9-py3-none-any.whl (27 kB)\n",
81
+ "Collecting omegaconf>=2.1\n",
82
+ " Downloading omegaconf-2.3.0-py3-none-any.whl (79 kB)\n",
83
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m79.5/79.5 kB\u001b[0m \u001b[31m20.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
84
+ "\u001b[?25hCollecting hydra-core>=1.1\n",
85
+ " Downloading hydra_core-1.3.1-py3-none-any.whl (154 kB)\n",
86
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m154.1/154.1 kB\u001b[0m \u001b[31m31.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
87
+ "\u001b[?25hCollecting black\n",
88
+ " Downloading black-22.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.6 MB)\n",
89
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m82.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
90
+ "\u001b[?25hCollecting timm\n",
91
+ " Downloading timm-0.6.12-py3-none-any.whl (549 kB)\n",
92
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m549.1/549.1 kB\u001b[0m \u001b[31m51.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
93
+ "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.9/dist-packages (from detectron2==0.6) (21.3)\n",
94
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.9/dist-packages (from fvcore<0.1.6,>=0.1.5->detectron2==0.6) (1.23.1)\n",
95
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.9/dist-packages (from fvcore<0.1.6,>=0.1.5->detectron2==0.6) (5.4.1)\n",
96
+ "Collecting antlr4-python3-runtime==4.9.*\n",
97
+ " Downloading antlr4-python3-runtime-4.9.3.tar.gz (117 kB)\n",
98
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m117.0/117.0 kB\u001b[0m \u001b[31m28.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
99
+ "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25ldone\n",
100
+ "\u001b[?25hCollecting portalocker\n",
101
+ " Downloading portalocker-2.6.0-py2.py3-none-any.whl (15 kB)\n",
102
+ "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.9/dist-packages (from matplotlib->detectron2==0.6) (0.11.0)\n",
103
+ "Requirement already satisfied: pyparsing>=2.2.1 in /usr/local/lib/python3.9/dist-packages (from matplotlib->detectron2==0.6) (3.0.9)\n",
104
+ "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.9/dist-packages (from matplotlib->detectron2==0.6) (2.8.2)\n",
105
+ "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->detectron2==0.6) (4.34.4)\n",
106
+ "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.9/dist-packages (from matplotlib->detectron2==0.6) (1.4.3)\n",
107
+ "Collecting mypy-extensions>=0.4.3\n",
108
+ " Downloading mypy_extensions-0.4.3-py2.py3-none-any.whl (4.5 kB)\n",
109
+ "Collecting pathspec>=0.9.0\n",
110
+ " Downloading pathspec-0.10.3-py3-none-any.whl (29 kB)\n",
111
+ "Collecting platformdirs>=2\n",
112
+ " Downloading platformdirs-2.6.2-py3-none-any.whl (14 kB)\n",
113
+ "Requirement already satisfied: click>=8.0.0 in /usr/local/lib/python3.9/dist-packages (from black->detectron2==0.6) (8.1.3)\n",
114
+ "Collecting tomli>=1.1.0\n",
115
+ " Downloading tomli-2.0.1-py3-none-any.whl (12 kB)\n",
116
+ "Requirement already satisfied: typing-extensions>=3.10.0.0 in /usr/local/lib/python3.9/dist-packages (from black->detectron2==0.6) (4.3.0)\n",
117
+ "Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.9/dist-packages (from tensorboard->detectron2==0.6) (0.6.1)\n",
118
+ "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.9/dist-packages (from tensorboard->detectron2==0.6) (2.2.2)\n",
119
+ "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.9/dist-packages (from tensorboard->detectron2==0.6) (1.8.1)\n",
120
+ "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.9/dist-packages (from tensorboard->detectron2==0.6) (3.3.7)\n",
121
+ "Requirement already satisfied: protobuf<3.20,>=3.9.2 in /usr/local/lib/python3.9/dist-packages (from tensorboard->detectron2==0.6) (3.19.4)\n",
122
+ "Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.9/dist-packages (from tensorboard->detectron2==0.6) (63.1.0)\n",
123
+ "Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.9/dist-packages (from tensorboard->detectron2==0.6) (2.28.1)\n",
124
+ "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.9/dist-packages (from tensorboard->detectron2==0.6) (0.4.6)\n",
125
+ "Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.9/dist-packages (from tensorboard->detectron2==0.6) (2.9.0)\n",
126
+ "Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.9/dist-packages (from tensorboard->detectron2==0.6) (1.47.0)\n",
127
+ "Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.9/dist-packages (from tensorboard->detectron2==0.6) (0.35.1)\n",
128
+ "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.9/dist-packages (from tensorboard->detectron2==0.6) (1.1.0)\n",
129
+ "Requirement already satisfied: torchvision in /usr/local/lib/python3.9/dist-packages (from timm->detectron2==0.6) (0.13.0+cu116)\n",
130
+ "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.9/dist-packages (from timm->detectron2==0.6) (0.8.1)\n",
131
+ "Requirement already satisfied: torch>=1.7 in /usr/local/lib/python3.9/dist-packages (from timm->detectron2==0.6) (1.12.0+cu116)\n",
132
+ "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.9/dist-packages (from google-auth<3,>=1.6.3->tensorboard->detectron2==0.6) (0.2.8)\n",
133
+ "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.9/dist-packages (from google-auth<3,>=1.6.3->tensorboard->detectron2==0.6) (5.2.0)\n",
134
+ "Requirement already satisfied: six>=1.9.0 in /usr/lib/python3/dist-packages (from google-auth<3,>=1.6.3->tensorboard->detectron2==0.6) (1.14.0)\n",
135
+ "Requirement already satisfied: rsa<5,>=3.1.4 in /usr/lib/python3/dist-packages (from google-auth<3,>=1.6.3->tensorboard->detectron2==0.6) (4.0)\n",
136
+ "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.9/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard->detectron2==0.6) (1.3.1)\n",
137
+ "Requirement already satisfied: importlib-metadata>=4.4 in /usr/local/lib/python3.9/dist-packages (from markdown>=2.6.8->tensorboard->detectron2==0.6) (4.12.0)\n",
138
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests<3,>=2.21.0->tensorboard->detectron2==0.6) (1.26.10)\n",
139
+ "Requirement already satisfied: charset-normalizer<3,>=2 in /usr/local/lib/python3.9/dist-packages (from requests<3,>=2.21.0->tensorboard->detectron2==0.6) (2.1.0)\n",
140
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/lib/python3/dist-packages (from requests<3,>=2.21.0->tensorboard->detectron2==0.6) (2.8)\n",
141
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/lib/python3/dist-packages (from requests<3,>=2.21.0->tensorboard->detectron2==0.6) (2019.11.28)\n",
142
+ "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.9/dist-packages (from werkzeug>=1.0.1->tensorboard->detectron2==0.6) (2.1.1)\n",
143
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from huggingface-hub->timm->detectron2==0.6) (3.7.1)\n",
144
+ "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.9/dist-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard->detectron2==0.6) (3.8.1)\n",
145
+ "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.9/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard->detectron2==0.6) (0.4.8)\n",
146
+ "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.9/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard->detectron2==0.6) (3.2.0)\n",
147
+ "Building wheels for collected packages: detectron2, fvcore, antlr4-python3-runtime, pycocotools\n",
148
+ " Building wheel for detectron2 (setup.py) ... \u001b[?25ldone\n",
149
+ "\u001b[?25h Created wheel for detectron2: filename=detectron2-0.6-cp39-cp39-linux_x86_64.whl size=5889223 sha256=ed1f701e1bc42d870d0d6fafd338e2a8e663f2847c8960619ecf38a4db0338ad\n",
150
+ " Stored in directory: /tmp/pip-ephem-wheel-cache-bzzpgw82/wheels/59/b4/83/84bfca751fa4dcc59998468be8688eb50e97408a83af171d42\n",
151
+ " Building wheel for fvcore (setup.py) ... \u001b[?25ldone\n",
152
+ "\u001b[?25h Created wheel for fvcore: filename=fvcore-0.1.5.post20221221-py3-none-any.whl size=61406 sha256=669c4db5ec6509578bb39829693a8e5438fde1d0116ec668be64bf0261e0eb9f\n",
153
+ " Stored in directory: /root/.cache/pip/wheels/83/42/02/66178d16e5c44dc26d309931834956baeda371956e86fbd876\n",
154
+ " Building wheel for antlr4-python3-runtime (setup.py) ... \u001b[?25ldone\n",
155
+ "\u001b[?25h Created wheel for antlr4-python3-runtime: filename=antlr4_python3_runtime-4.9.3-py3-none-any.whl size=144554 sha256=ddfc41c7e0307f85e68b773551e1be87484cbe4ac68ba04bd4e2dc501dfb6c95\n",
156
+ " Stored in directory: /root/.cache/pip/wheels/23/cf/80/f3efa822e6ab23277902ee9165fe772eeb1dfb8014f359020a\n",
157
+ " Building wheel for pycocotools (pyproject.toml) ... \u001b[?25ldone\n",
158
+ "\u001b[?25h Created wheel for pycocotools: filename=pycocotools-2.0.6-cp39-cp39-linux_x86_64.whl size=400228 sha256=a356c0450d2148ac9d590e3a6adc7b6e913d050be851a84ec830219795200de9\n",
159
+ " Stored in directory: /root/.cache/pip/wheels/2f/58/25/e78f1f766e904a9071266661d20d0bc6644df86bcd160aba11\n",
160
+ "Successfully built detectron2 fvcore antlr4-python3-runtime pycocotools\n",
161
+ "Installing collected packages: mypy-extensions, antlr4-python3-runtime, yacs, tomli, tabulate, portalocker, platformdirs, pathspec, omegaconf, iopath, hydra-core, black, timm, pycocotools, fvcore, detectron2\n",
162
+ "Successfully installed antlr4-python3-runtime-4.9.3 black-22.12.0 detectron2-0.6 fvcore-0.1.5.post20221221 hydra-core-1.3.1 iopath-0.1.9 mypy-extensions-0.4.3 omegaconf-2.3.0 pathspec-0.10.3 platformdirs-2.6.2 portalocker-2.6.0 pycocotools-2.0.6 tabulate-0.9.0 timm-0.6.12 tomli-2.0.1 yacs-0.1.8\n",
163
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
164
+ "\u001b[0m"
165
+ ]
166
+ }
167
+ ],
168
+ "source": [
169
+ "!python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": 6,
175
+ "id": "ba6c8b48-7243-4602-8c62-862f7a8f00cb",
176
+ "metadata": {
177
+ "execution": {
178
+ "iopub.execute_input": "2022-12-29T07:03:17.515189Z",
179
+ "iopub.status.busy": "2022-12-29T07:03:17.514793Z",
180
+ "iopub.status.idle": "2022-12-29T07:03:21.008057Z",
181
+ "shell.execute_reply": "2022-12-29T07:03:21.006831Z",
182
+ "shell.execute_reply.started": "2022-12-29T07:03:17.515153Z"
183
+ }
184
+ },
185
+ "outputs": [
186
+ {
187
+ "name": "stdout",
188
+ "output_type": "stream",
189
+ "text": [
190
+ "Collecting pyngrok==4.1.1\n",
191
+ " Downloading pyngrok-4.1.1.tar.gz (18 kB)\n",
192
+ " Preparing metadata (setup.py) ... \u001b[?25ldone\n",
193
+ "\u001b[?25hRequirement already satisfied: future in /usr/lib/python3/dist-packages (from pyngrok==4.1.1) (0.18.2)\n",
194
+ "Requirement already satisfied: PyYAML in /usr/local/lib/python3.9/dist-packages (from pyngrok==4.1.1) (5.4.1)\n",
195
+ "Building wheels for collected packages: pyngrok\n",
196
+ " Building wheel for pyngrok (setup.py) ... \u001b[?25ldone\n",
197
+ "\u001b[?25h Created wheel for pyngrok: filename=pyngrok-4.1.1-py3-none-any.whl size=15965 sha256=3669af38b11fcc66e95001ad9cdd77ad1da63d79252e42fe38331b06321ef59d\n",
198
+ " Stored in directory: /root/.cache/pip/wheels/89/2d/c2/abe6bcfde6bce368c00ecd73310c11edb672c3eda09a090cfa\n",
199
+ "Successfully built pyngrok\n",
200
+ "Installing collected packages: pyngrok\n",
201
+ "Successfully installed pyngrok-4.1.1\n",
202
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
203
+ "\u001b[0m"
204
+ ]
205
+ }
206
+ ],
207
+ "source": [
208
+ "!pip install pyngrok==4.1.1"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": 7,
214
+ "id": "e163eab2-841c-4372-8f75-3d60cbbe247a",
215
+ "metadata": {
216
+ "execution": {
217
+ "iopub.execute_input": "2022-12-29T07:03:21.010722Z",
218
+ "iopub.status.busy": "2022-12-29T07:03:21.009782Z",
219
+ "iopub.status.idle": "2022-12-29T07:03:22.584809Z",
220
+ "shell.execute_reply": "2022-12-29T07:03:22.583859Z",
221
+ "shell.execute_reply.started": "2022-12-29T07:03:21.010677Z"
222
+ }
223
+ },
224
+ "outputs": [
225
+ {
226
+ "name": "stdout",
227
+ "output_type": "stream",
228
+ "text": [
229
+ "Authtoken saved to configuration file: /root/.ngrok2/ngrok.yml \n"
230
+ ]
231
+ }
232
+ ],
233
+ "source": [
234
+ "!ngrok authtoken yourtoken"
235
+ ]
236
+ },
237
+ {
238
+ "cell_type": "code",
239
+ "execution_count": null,
240
+ "id": "70b46c0a-7f56-4cb0-a84c-836caf911ca4",
241
+ "metadata": {},
242
+ "outputs": [],
243
+ "source": []
244
+ }
245
+ ],
246
+ "metadata": {
247
+ "kernelspec": {
248
+ "display_name": "Python 3 (ipykernel)",
249
+ "language": "python",
250
+ "name": "python3"
251
+ },
252
+ "language_info": {
253
+ "codemirror_mode": {
254
+ "name": "ipython",
255
+ "version": 3
256
+ },
257
+ "file_extension": ".py",
258
+ "mimetype": "text/x-python",
259
+ "name": "python",
260
+ "nbconvert_exporter": "python",
261
+ "pygments_lexer": "ipython3",
262
+ "version": "3.9.13"
263
+ }
264
+ },
265
+ "nbformat": 4,
266
+ "nbformat_minor": 5
267
+ }
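The notebook above only installs dependencies (tensorboardX, torchgeometry, flask, flask-ngrok, detectron2, pyngrok) and registers an ngrok authtoken; `yourtoken` is a placeholder for your own token. As a rough sketch of why flask-ngrok is pulled in, a local Flask app is typically exposed through an ngrok tunnel as shown below; the route is purely illustrative and is not part of this commit:

```python
# Minimal sketch (not from this repo): expose a local Flask app via ngrok.
# Assumes `pip install flask flask-ngrok` and `ngrok authtoken <token>`
# have already been run, as in the notebook cells above.
from flask import Flask
from flask_ngrok import run_with_ngrok

app = Flask(__name__)
run_with_ngrok(app)  # opens an ngrok tunnel when app.run() is called

@app.route("/")
def index():
    return "Virtual try-on server is up"

if __name__ == "__main__":
    app.run()  # flask-ngrok prints the public URL on startup
```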
TryYours-Virtual-Try-On/Graphonomy-master/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2016 Vladimir Nekrasov
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
TryYours-Virtual-Try-On/Graphonomy-master/README.md ADDED
@@ -0,0 +1,124 @@
1
+ # Graphonomy: Universal Human Parsing via Graph Transfer Learning
2
+
3
+ This repository contains the code for the paper:
4
+
5
+ [**Graphonomy: Universal Human Parsing via Graph Transfer Learning**](https://arxiv.org/abs/1904.04536)
6
+ by Ke Gong, Yiming Gao, Xiaodan Liang, Xiaohui Shen, Meng Wang, and Liang Lin.
7
+
8
+
9
+ # Environment and installation
10
+ + PyTorch = 0.4.0
11
+ + torchvision
12
+ + scipy
13
+ + tensorboardX
14
+ + numpy
15
+ + opencv-python
16
+ + matplotlib
17
+ + networkx
18
+
19
+ You can install the above packages with `pip install -r requirements.txt`.
20
+
21
+ # Getting Started
22
+ ### Data Preparation
23
+ + You need to download the human parsing datasets, prepare the images, and store them in `/data/datasets/dataset_name/`.
24
+ We recommend symlinking the datasets into `data/datasets/` as follows:
25
+
26
+ ```
27
+ # symlink the Pascal-Person-Part dataset for example
28
+ ln -s /path_to_Pascal_Person_Part/* data/datasets/pascal/
29
+ ```
30
+ + The file structure should look like:
31
+ ```
32
+ /Graphonomy
33
+ /data
34
+ /datasets
35
+ /pascal
36
+ /JPEGImages
37
+ /list
38
+ /SegmentationPart
39
+ /CIHP_4w
40
+ /Images
41
+ /lists
42
+ ...
43
+ ```
44
+ + The datasets (CIHP & ATR) are available at [google drive](https://drive.google.com/drive/folders/0BzvH3bSnp3E9ZW9paE9kdkJtM3M?usp=sharing)
45
+ and [baidu drive](http://pan.baidu.com/s/1nvqmZBN).
46
+ You also need to download the flipped labels.
47
+ Download [cihp_flipped](https://drive.google.com/file/d/1aaJyQH-hlZEAsA7iH-mYeK1zLfQi8E2j/view?usp=sharing), unzip and store in `data/datasets/CIHP_4w/`.
48
+ Download [atr_flip](https://drive.google.com/file/d/1iR8Tn69IbDSM7gq_GG-_s11HCnhPkyG3/view?usp=sharing), unzip and store in `data/datasets/ATR/`.
49
+
50
+ ### Inference
51
+ We provide a simple script to get visualization results on the CIHP dataset using the [trained](https://drive.google.com/file/d/1O9YD4kHgs3w2DUcWxtHiEFyWjCBeS_Vc/view?usp=sharing)
52
+ model, as follows:
53
+ ```shell
54
+ # Example of inference
55
+ python exp/inference/inference.py \
56
+ --loadmodel /path_to_inference_model \
57
+ --img_path ./img/messi.jpg \
58
+ --output_path ./img/ \
59
+ --output_name /output_file_name
60
+ ```
61
+
62
+ ### Training
63
+ #### Transfer learning
64
+ 1. Download the Pascal pretrained model (available soon).
65
+ 2. Run `sh train_transfer_cihp.sh`.
66
+ 3. The results and models are saved in exp/transfer/run/.
67
+ 4. The evaluation and visualization script is `eval_cihp.sh`; you only need to change the `--loadmodel` argument before running it.
68
+
69
+ #### Universal training
70
+ 1. Download the [pretrained](https://drive.google.com/file/d/18WiffKnxaJo50sCC9zroNyHjcnTxGCbk/view?usp=sharing) model and store it in `/data/pretrained_model/`.
71
+ 2. Run `sh train_universal.sh`.
72
+ 3. The results and models are saved in exp/universal/run/.
73
+
74
+ ### Testing
75
+ If you want to evaluate a pre-trained model on the PASCAL-Person-Part or CIHP val/test set,
76
+ simply run the corresponding script: `sh eval_cihp.sh` or `sh eval_pascal.sh`.
77
+ Specify the model to evaluate via `--loadmodel`. We also provide final models that you can download and store in `/data/pretrained_model/`.
78
+
79
+ ### Models
80
+ **Pascal-Person-Part trained model**
81
+
82
+ |Model|Google Cloud|Baidu Yun|
83
+ |--------|--------------|-----------|
84
+ |Graphonomy(CIHP)| [Download](https://drive.google.com/file/d/1E_V_gVDWfAJFPfe-LLu2RQaYQMdhjv9h/view?usp=sharing)| Available soon|
85
+
86
+ **CIHP trained model**
87
+
88
+ |Model|Google Cloud|Baidu Yun|
89
+ |--------|--------------|-----------|
90
+ |Graphonomy(PASCAL)| [Download](https://drive.google.com/file/d/1eUe18HoH05p0yFUd_sN6GXdTj82aW0m9/view?usp=sharing)| Available soon|
91
+
92
+ **Universal trained model**
93
+
94
+ |Model|Google Cloud|Baidu Yun|
95
+ |--------|--------------|-----------|
96
+ |Universal| [Download](https://drive.google.com/file/d/1sWJ54lCBFnzCNz5RTCGQmkVovkY9x8_D/view?usp=sharing)|Available soon|
97
+
98
+ ### Todo:
99
+ - [ ] release pretrained and trained models
100
+ - [ ] update universal eval code&script
101
+
102
+ # Citation
103
+
104
+ ```
105
+ @inproceedings{Gong2019Graphonomy,
106
+ author = {Ke Gong and Yiming Gao and Xiaodan Liang and Xiaohui Shen and Meng Wang and Liang Lin},
107
+ title = {Graphonomy: Universal Human Parsing via Graph Transfer Learning},
108
+ booktitle = {CVPR},
109
+ year = {2019},
110
+ }
111
+
112
+ ```
113
+
114
+ # Contact
115
+ If you have any questions about this repo, please feel free to contact the authors.
116
117
+
118
+ ##
119
+
120
+ ## Related work
121
+ + Self-supervised Structure-sensitive Learning [SSL](https://github.com/Engineering-Course/LIP_SSL)
122
+ + Joint Body Parsing & Pose Estimation Network [JPPNet](https://github.com/Engineering-Course/LIP_JPPNet)
123
+ + Instance-level Human Parsing via Part Grouping Network [PGN](https://github.com/Engineering-Course/CIHP_PGN)
124
+ + Graphonomy: Universal Image Parsing via Graph Reasoning and Transfer [paper](https://arxiv.org/abs/2101.10620) [code](https://github.com/Gaoyiminggithub/Graphonomy-Panoptic)
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/__init__.py ADDED
File without changes
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (177 Bytes).
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (153 Bytes).
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/__pycache__/custom_transforms.cpython-310.pyc ADDED
Binary file (14.3 kB).
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/__pycache__/custom_transforms.cpython-39.pyc ADDED
Binary file (15 kB).
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/atr.py ADDED
@@ -0,0 +1,109 @@
1
+ from __future__ import print_function, division
2
+ import os
3
+ from PIL import Image
4
+ from torch.utils.data import Dataset
5
+ from .mypath_atr import Path
6
+ import random
7
+ from PIL import ImageFile
8
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
9
+
10
+ class VOCSegmentation(Dataset):
11
+ """
12
+ ATR dataset
13
+ """
14
+
15
+ def __init__(self,
16
+ base_dir=Path.db_root_dir('atr'),
17
+ split='train',
18
+ transform=None,
19
+ flip=False,
20
+ ):
21
+ """
22
+ :param base_dir: path to ATR dataset directory
23
+ :param split: train/val
24
+ :param transform: transform to apply
25
+ """
26
+ super(VOCSegmentation, self).__init__()
27
+ self._flip_flag = flip
28
+
29
+ self._base_dir = base_dir
30
+ self._image_dir = os.path.join(self._base_dir, 'JPEGImages')
31
+ self._cat_dir = os.path.join(self._base_dir, 'SegmentationClassAug')
32
+ self._flip_dir = os.path.join(self._base_dir,'SegmentationClassAug_rev')
33
+
34
+ if isinstance(split, str):
35
+ self.split = [split]
36
+ else:
37
+ split.sort()
38
+ self.split = split
39
+
40
+ self.transform = transform
41
+
42
+ _splits_dir = os.path.join(self._base_dir, 'list')
43
+
44
+ self.im_ids = []
45
+ self.images = []
46
+ self.categories = []
47
+ self.flip_categories = []
48
+
49
+ for splt in self.split:
50
+ with open(os.path.join(os.path.join(_splits_dir, splt + '_id.txt')), "r") as f:
51
+ lines = f.read().splitlines()
52
+
53
+ for ii, line in enumerate(lines):
54
+
55
+ _image = os.path.join(self._image_dir, line+'.jpg' )
56
+ _cat = os.path.join(self._cat_dir, line +'.png')
57
+ _flip = os.path.join(self._flip_dir,line + '.png')
58
+ # print(self._image_dir,_image)
59
+ assert os.path.isfile(_image)
60
+ # print(_cat)
61
+ assert os.path.isfile(_cat)
62
+ assert os.path.isfile(_flip)
63
+ self.im_ids.append(line)
64
+ self.images.append(_image)
65
+ self.categories.append(_cat)
66
+ self.flip_categories.append(_flip)
67
+
68
+
69
+ assert (len(self.images) == len(self.categories))
70
+ assert len(self.flip_categories) == len(self.categories)
71
+
72
+ # Display stats
73
+ print('Number of images in {}: {:d}'.format(split, len(self.images)))
74
+
75
+ def __len__(self):
76
+ return len(self.images)
77
+
78
+
79
+ def __getitem__(self, index):
80
+ _img, _target= self._make_img_gt_point_pair(index)
81
+ sample = {'image': _img, 'label': _target}
82
+
83
+ if self.transform is not None:
84
+ sample = self.transform(sample)
85
+
86
+ return sample
87
+
88
+ def _make_img_gt_point_pair(self, index):
89
+ # Read Image and Target
90
+ # _img = np.array(Image.open(self.images[index]).convert('RGB')).astype(np.float32)
91
+ # _target = np.array(Image.open(self.categories[index])).astype(np.float32)
92
+
93
+ _img = Image.open(self.images[index]).convert('RGB')  # returned image is RGB
94
+ if self._flip_flag:
95
+ if random.random() < 0.5:
96
+ _target = Image.open(self.flip_categories[index])
97
+ _img = _img.transpose(Image.FLIP_LEFT_RIGHT)
98
+ else:
99
+ _target = Image.open(self.categories[index])
100
+ else:
101
+ _target = Image.open(self.categories[index])
102
+
103
+ return _img, _target
104
+
105
+ def __str__(self):
106
+ return 'ATR(split=' + str(self.split) + ')'
107
+
108
+
109
+
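The `VOCSegmentation` class above yields `{'image', 'label'}` pairs of PIL images, so it is normally composed with the transforms from `dataloaders/custom_transforms.py` and wrapped in a `DataLoader`. A minimal usage sketch, assuming the ATR data is laid out under `Path.db_root_dir('atr')`; the transform choice mirrors the `__main__` example in `cihp_pascal_atr.py` and is not prescribed by this file:

```python
# Minimal usage sketch for the ATR dataloader defined above (assumed layout:
# JPEGImages/, SegmentationClassAug/, SegmentationClassAug_rev/, list/).
from torch.utils.data import DataLoader
from torchvision import transforms

from dataloaders import custom_transforms as tr
from dataloaders.atr import VOCSegmentation

composed = transforms.Compose([
    tr.RandomSized_new(512),  # random scale, then crop/pad to 512x512
    tr.RandomRotate(15),      # random rotation of up to 15 degrees
    tr.ToTensor_(),           # PIL HWC image -> float CHW tensor (RGB -> BGR)
])

atr_train = VOCSegmentation(split='train', transform=composed, flip=True)
loader = DataLoader(atr_train, batch_size=4, shuffle=True, num_workers=2)

for sample in loader:
    images, labels = sample['image'], sample['label']
    # images: (B, 3, 512, 512), labels: (B, 1, 512, 512)
    break
```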
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/cihp.py ADDED
@@ -0,0 +1,107 @@
1
+ from __future__ import print_function, division
2
+ import os
3
+ from PIL import Image
4
+ from torch.utils.data import Dataset
5
+ from .mypath_cihp import Path
6
+ import random
7
+
8
+ class VOCSegmentation(Dataset):
9
+ """
10
+ CIHP dataset
11
+ """
12
+
13
+ def __init__(self,
14
+ base_dir=Path.db_root_dir('cihp'),
15
+ split='train',
16
+ transform=None,
17
+ flip=False,
18
+ ):
19
+ """
20
+ :param base_dir: path to CIHP dataset directory
21
+ :param split: train/val/test
22
+ :param transform: transform to apply
23
+ """
24
+ super(VOCSegmentation, self).__init__()
25
+ self._flip_flag = flip
26
+
27
+ self._base_dir = base_dir
28
+ self._image_dir = os.path.join(self._base_dir, 'Images')
29
+ self._cat_dir = os.path.join(self._base_dir, 'Category_ids')
30
+ self._flip_dir = os.path.join(self._base_dir,'Category_rev_ids')
31
+
32
+ if isinstance(split, str):
33
+ self.split = [split]
34
+ else:
35
+ split.sort()
36
+ self.split = split
37
+
38
+ self.transform = transform
39
+
40
+ _splits_dir = os.path.join(self._base_dir, 'lists')
41
+
42
+ self.im_ids = []
43
+ self.images = []
44
+ self.categories = []
45
+ self.flip_categories = []
46
+
47
+ for splt in self.split:
48
+ with open(os.path.join(os.path.join(_splits_dir, splt + '_id.txt')), "r") as f:
49
+ lines = f.read().splitlines()
50
+
51
+ for ii, line in enumerate(lines):
52
+
53
+ _image = os.path.join(self._image_dir, line+'.jpg' )
54
+ _cat = os.path.join(self._cat_dir, line +'.png')
55
+ _flip = os.path.join(self._flip_dir,line + '.png')
56
+ # print(self._image_dir,_image)
57
+ assert os.path.isfile(_image)
58
+ # print(_cat)
59
+ assert os.path.isfile(_cat)
60
+ assert os.path.isfile(_flip)
61
+ self.im_ids.append(line)
62
+ self.images.append(_image)
63
+ self.categories.append(_cat)
64
+ self.flip_categories.append(_flip)
65
+
66
+
67
+ assert (len(self.images) == len(self.categories))
68
+ assert len(self.flip_categories) == len(self.categories)
69
+
70
+ # Display stats
71
+ print('Number of images in {}: {:d}'.format(split, len(self.images)))
72
+
73
+ def __len__(self):
74
+ return len(self.images)
75
+
76
+
77
+ def __getitem__(self, index):
78
+ _img, _target= self._make_img_gt_point_pair(index)
79
+ sample = {'image': _img, 'label': _target}
80
+
81
+ if self.transform is not None:
82
+ sample = self.transform(sample)
83
+
84
+ return sample
85
+
86
+ def _make_img_gt_point_pair(self, index):
87
+ # Read Image and Target
88
+ # _img = np.array(Image.open(self.images[index]).convert('RGB')).astype(np.float32)
89
+ # _target = np.array(Image.open(self.categories[index])).astype(np.float32)
90
+
91
+ _img = Image.open(self.images[index]).convert('RGB')  # returned image is RGB
92
+ if self._flip_flag:
93
+ if random.random() < 0.5:
94
+ _target = Image.open(self.flip_categories[index])
95
+ _img = _img.transpose(Image.FLIP_LEFT_RIGHT)
96
+ else:
97
+ _target = Image.open(self.categories[index])
98
+ else:
99
+ _target = Image.open(self.categories[index])
100
+
101
+ return _img, _target
102
+
103
+ def __str__(self):
104
+ return 'CIHP(split=' + str(self.split) + ')'
105
+
106
+
107
+
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/cihp_pascal_atr.py ADDED
@@ -0,0 +1,219 @@
1
+ from __future__ import print_function, division
2
+ import os
3
+ from PIL import Image
4
+ import numpy as np
5
+ from torch.utils.data import Dataset
6
+ from .mypath_cihp import Path
7
+ from .mypath_pascal import Path as PP
8
+ from .mypath_atr import Path as PA
9
+ import random
10
+ from PIL import ImageFile
11
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
12
+
13
+ class VOCSegmentation(Dataset):
14
+ """
15
+ Pascal dataset
16
+ """
17
+
18
+ def __init__(self,
19
+ cihp_dir=Path.db_root_dir('cihp'),
20
+ split='train',
21
+ transform=None,
22
+ flip=False,
23
+ pascal_dir = PP.db_root_dir('pascal'),
24
+ atr_dir = PA.db_root_dir('atr'),
25
+ ):
26
+ """
27
+ :param cihp_dir: path to CIHP dataset directory
28
+ :param pascal_dir: path to PASCAL dataset directory
29
+ :param atr_dir: path to ATR dataset directory
30
+ :param split: train/val
31
+ :param transform: transform to apply
32
+ """
33
+ super(VOCSegmentation, self).__init__()
34
+ ## for cihp
35
+ self._flip_flag = flip
36
+ self._base_dir = cihp_dir
37
+ self._image_dir = os.path.join(self._base_dir, 'Images')
38
+ self._cat_dir = os.path.join(self._base_dir, 'Category_ids')
39
+ self._flip_dir = os.path.join(self._base_dir,'Category_rev_ids')
40
+ ## for Pascal
41
+ self._base_dir_pascal = pascal_dir
42
+ self._image_dir_pascal = os.path.join(self._base_dir_pascal, 'JPEGImages')
43
+ self._cat_dir_pascal = os.path.join(self._base_dir_pascal, 'SegmentationPart')
44
+ # self._flip_dir_pascal = os.path.join(self._base_dir_pascal, 'Category_rev_ids')
45
+ ## for atr
46
+ self._base_dir_atr = atr_dir
47
+ self._image_dir_atr = os.path.join(self._base_dir_atr, 'JPEGImages')
48
+ self._cat_dir_atr = os.path.join(self._base_dir_atr, 'SegmentationClassAug')
49
+ self._flip_dir_atr = os.path.join(self._base_dir_atr, 'SegmentationClassAug_rev')
50
+
51
+ if isinstance(split, str):
52
+ self.split = [split]
53
+ else:
54
+ split.sort()
55
+ self.split = split
56
+
57
+ self.transform = transform
58
+
59
+ _splits_dir = os.path.join(self._base_dir, 'lists')
60
+ _splits_dir_pascal = os.path.join(self._base_dir_pascal, 'list')
61
+ _splits_dir_atr = os.path.join(self._base_dir_atr, 'list')
62
+
63
+ self.im_ids = []
64
+ self.images = []
65
+ self.categories = []
66
+ self.flip_categories = []
67
+ self.datasets_lbl = []
68
+
69
+ # num
70
+ self.num_cihp = 0
71
+ self.num_pascal = 0
72
+ self.num_atr = 0
73
+ # for cihp is 0
74
+ for splt in self.split:
75
+ with open(os.path.join(os.path.join(_splits_dir, splt + '_id.txt')), "r") as f:
76
+ lines = f.read().splitlines()
77
+ self.num_cihp += len(lines)
78
+ for ii, line in enumerate(lines):
79
+
80
+ _image = os.path.join(self._image_dir, line+'.jpg' )
81
+ _cat = os.path.join(self._cat_dir, line +'.png')
82
+ _flip = os.path.join(self._flip_dir,line + '.png')
83
+ # print(self._image_dir,_image)
84
+ assert os.path.isfile(_image)
85
+ # print(_cat)
86
+ assert os.path.isfile(_cat)
87
+ assert os.path.isfile(_flip)
88
+ self.im_ids.append(line)
89
+ self.images.append(_image)
90
+ self.categories.append(_cat)
91
+ self.flip_categories.append(_flip)
92
+ self.datasets_lbl.append(0)
93
+
94
+ # for pascal is 1
95
+ for splt in self.split:
96
+ if splt == 'test':
97
+ splt='val'
98
+ with open(os.path.join(os.path.join(_splits_dir_pascal, splt + '_id.txt')), "r") as f:
99
+ lines = f.read().splitlines()
100
+ self.num_pascal += len(lines)
101
+ for ii, line in enumerate(lines):
102
+
103
+ _image = os.path.join(self._image_dir_pascal, line+'.jpg' )
104
+ _cat = os.path.join(self._cat_dir_pascal, line +'.png')
105
+ # _flip = os.path.join(self._flip_dir,line + '.png')
106
+ # print(self._image_dir,_image)
107
+ assert os.path.isfile(_image)
108
+ # print(_cat)
109
+ assert os.path.isfile(_cat)
110
+ # assert os.path.isfile(_flip)
111
+ self.im_ids.append(line)
112
+ self.images.append(_image)
113
+ self.categories.append(_cat)
114
+ self.flip_categories.append([])
115
+ self.datasets_lbl.append(1)
116
+
117
+ # for atr is 2
118
+ for splt in self.split:
119
+ with open(os.path.join(os.path.join(_splits_dir_atr, splt + '_id.txt')), "r") as f:
120
+ lines = f.read().splitlines()
121
+ self.num_atr += len(lines)
122
+ for ii, line in enumerate(lines):
123
+ _image = os.path.join(self._image_dir_atr, line + '.jpg')
124
+ _cat = os.path.join(self._cat_dir_atr, line + '.png')
125
+ _flip = os.path.join(self._flip_dir_atr, line + '.png')
126
+ # print(self._image_dir,_image)
127
+ assert os.path.isfile(_image)
128
+ # print(_cat)
129
+ assert os.path.isfile(_cat)
130
+ assert os.path.isfile(_flip)
131
+ self.im_ids.append(line)
132
+ self.images.append(_image)
133
+ self.categories.append(_cat)
134
+ self.flip_categories.append(_flip)
135
+ self.datasets_lbl.append(2)
136
+
137
+ assert (len(self.images) == len(self.categories))
138
+ # assert len(self.flip_categories) == len(self.categories)
139
+
140
+ # Display stats
141
+ print('Number of images in {}: {:d}'.format(split, len(self.images)))
142
+
143
+ def __len__(self):
144
+ return len(self.images)
145
+
146
+ def get_class_num(self):
147
+ return self.num_cihp,self.num_pascal,self.num_atr
148
+
149
+
150
+
151
+ def __getitem__(self, index):
152
+ _img, _target,_lbl= self._make_img_gt_point_pair(index)
153
+ sample = {'image': _img, 'label': _target,}
154
+
155
+ if self.transform is not None:
156
+ sample = self.transform(sample)
157
+ sample['pascal'] = _lbl
158
+ return sample
159
+
160
+ def _make_img_gt_point_pair(self, index):
161
+ # Read Image and Target
162
+ # _img = np.array(Image.open(self.images[index]).convert('RGB')).astype(np.float32)
163
+ # _target = np.array(Image.open(self.categories[index])).astype(np.float32)
164
+
165
+ _img = Image.open(self.images[index]).convert('RGB')  # returned image is RGB
166
+ type_lbl = self.datasets_lbl[index]
167
+ if self._flip_flag:
168
+ if random.random() < 0.5 :
169
+ # _target = Image.open(self.flip_categories[index])
170
+ _img = _img.transpose(Image.FLIP_LEFT_RIGHT)
171
+ if type_lbl == 0 or type_lbl == 2:
172
+ _target = Image.open(self.flip_categories[index])
173
+ else:
174
+ _target = Image.open(self.categories[index])
175
+ _target = _target.transpose(Image.FLIP_LEFT_RIGHT)
176
+ else:
177
+ _target = Image.open(self.categories[index])
178
+ else:
179
+ _target = Image.open(self.categories[index])
180
+
181
+ return _img, _target,type_lbl
182
+
183
+ def __str__(self):
184
+ return 'datasets(split=' + str(self.split) + ')'
185
+
186
+
187
+
188
+
189
+
190
+
191
+
192
+
193
+
194
+
195
+
196
+
197
+ if __name__ == '__main__':
198
+ from dataloaders import custom_transforms as tr
199
+ from dataloaders.utils import decode_segmap
200
+ from torch.utils.data import DataLoader
201
+ from torchvision import transforms
202
+ import matplotlib.pyplot as plt
203
+
204
+ composed_transforms_tr = transforms.Compose([
205
+ # tr.RandomHorizontalFlip(),
206
+ tr.RandomSized_new(512),
207
+ tr.RandomRotate(15),
208
+ tr.ToTensor_()])
209
+
210
+
211
+
212
+ voc_train = VOCSegmentation(split='train',
213
+ transform=composed_transforms_tr)
214
+
215
+ dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=1)
216
+
217
+ for ii, sample in enumerate(dataloader):
218
+ if ii >10:
219
+ break
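Unlike the single-dataset loaders, this combined dataset tags every sample with its source in `sample['pascal']` (0 = CIHP, 1 = Pascal-Person-Part, 2 = ATR), which lets downstream code tell the datasets apart. A minimal sketch of reading that tag in a loop, assuming the three datasets are in place; any per-dataset model heads are left as user-defined placeholders:

```python
# Minimal sketch: split a mixed batch by source dataset using sample['pascal'].
from torch.utils.data import DataLoader
from torchvision import transforms

from dataloaders import custom_transforms as tr
from dataloaders.cihp_pascal_atr import VOCSegmentation

composed = transforms.Compose([tr.RandomSized_new(512), tr.ToTensor_()])
mixed_train = VOCSegmentation(split='train', transform=composed, flip=True)
loader = DataLoader(mixed_train, batch_size=8, shuffle=True, num_workers=2)

for sample in loader:
    images, labels = sample['image'], sample['label']
    sources = sample['pascal']        # shape (B,): 0=CIHP, 1=Pascal, 2=ATR
    cihp_mask = sources == 0          # boolean masks to route sub-batches
    pascal_mask = sources == 1
    atr_mask = sources == 2
    # e.g. forward images[cihp_mask] through a CIHP-specific head, etc.
    break
```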
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/custom_transforms.py ADDED
@@ -0,0 +1,491 @@
1
+ import torch
2
+ import math
3
+ import numbers
4
+ import random
5
+ import numpy as np
6
+
7
+ from PIL import Image, ImageOps
8
+ from torchvision import transforms
9
+
10
+ class RandomCrop(object):
11
+ def __init__(self, size, padding=0):
12
+ if isinstance(size, numbers.Number):
13
+ self.size = (int(size), int(size))
14
+ else:
15
+ self.size = size # h, w
16
+ self.padding = padding
17
+
18
+ def __call__(self, sample):
19
+ img, mask = sample['image'], sample['label']
20
+
21
+ if self.padding > 0:
22
+ img = ImageOps.expand(img, border=self.padding, fill=0)
23
+ mask = ImageOps.expand(mask, border=self.padding, fill=0)
24
+
25
+ assert img.size == mask.size
26
+ w, h = img.size
27
+ th, tw = self.size # target size
28
+ if w == tw and h == th:
29
+ return {'image': img,
30
+ 'label': mask}
31
+ if w < tw or h < th:
32
+ img = img.resize((tw, th), Image.BILINEAR)
33
+ mask = mask.resize((tw, th), Image.NEAREST)
34
+ return {'image': img,
35
+ 'label': mask}
36
+
37
+ x1 = random.randint(0, w - tw)
38
+ y1 = random.randint(0, h - th)
39
+ img = img.crop((x1, y1, x1 + tw, y1 + th))
40
+ mask = mask.crop((x1, y1, x1 + tw, y1 + th))
41
+
42
+ return {'image': img,
43
+ 'label': mask}
44
+
45
+ class RandomCrop_new(object):
46
+ def __init__(self, size, padding=0):
47
+ if isinstance(size, numbers.Number):
48
+ self.size = (int(size), int(size))
49
+ else:
50
+ self.size = size # h, w
51
+ self.padding = padding
52
+
53
+ def __call__(self, sample):
54
+ img, mask = sample['image'], sample['label']
55
+
56
+ if self.padding > 0:
57
+ img = ImageOps.expand(img, border=self.padding, fill=0)
58
+ mask = ImageOps.expand(mask, border=self.padding, fill=0)
59
+
60
+ assert img.size == mask.size
61
+ w, h = img.size
62
+ th, tw = self.size # target size
63
+ if w == tw and h == th:
64
+ return {'image': img,
65
+ 'label': mask}
66
+
67
+ new_img = Image.new('RGB',(tw,th),'black') # size is w x h; and 'white' is 255
68
+ new_mask = Image.new('L',(tw,th),'white') # same above
69
+
70
+ # if w > tw or h > th
71
+ x1 = y1 = 0
72
+ if w > tw:
73
+ x1 = random.randint(0,w - tw)
74
+ if h > th:
75
+ y1 = random.randint(0,h - th)
76
+ # crop
77
+ img = img.crop((x1,y1, x1 + tw, y1 + th))
78
+ mask = mask.crop((x1,y1, x1 + tw, y1 + th))
79
+ new_img.paste(img,(0,0))
80
+ new_mask.paste(mask,(0,0))
81
+
82
+ # x1 = random.randint(0, w - tw)
83
+ # y1 = random.randint(0, h - th)
84
+ # img = img.crop((x1, y1, x1 + tw, y1 + th))
85
+ # mask = mask.crop((x1, y1, x1 + tw, y1 + th))
86
+
87
+ return {'image': new_img,
88
+ 'label': new_mask}
89
+
90
+ class Paste(object):
91
+ def __init__(self, size,):
92
+ if isinstance(size, numbers.Number):
93
+ self.size = (int(size), int(size))
94
+ else:
95
+ self.size = size # h, w
96
+
97
+ def __call__(self, sample):
98
+ img, mask = sample['image'], sample['label']
99
+
100
+ assert img.size == mask.size
101
+ w, h = img.size
102
+ th, tw = self.size # target size
103
+ assert (w <=tw) and (h <= th)
104
+ if w == tw and h == th:
105
+ return {'image': img,
106
+ 'label': mask}
107
+
108
+ new_img = Image.new('RGB',(tw,th),'black') # size is w x h; and 'white' is 255
109
+ new_mask = Image.new('L',(tw,th),'white') # same above
110
+
111
+ new_img.paste(img,(0,0))
112
+ new_mask.paste(mask,(0,0))
113
+
114
+ return {'image': new_img,
115
+ 'label': new_mask}
116
+
117
+ class CenterCrop(object):
118
+ def __init__(self, size):
119
+ if isinstance(size, numbers.Number):
120
+ self.size = (int(size), int(size))
121
+ else:
122
+ self.size = size
123
+
124
+ def __call__(self, sample):
125
+ img = sample['image']
126
+ mask = sample['label']
127
+ assert img.size == mask.size
128
+ w, h = img.size
129
+ th, tw = self.size
130
+ x1 = int(round((w - tw) / 2.))
131
+ y1 = int(round((h - th) / 2.))
132
+ img = img.crop((x1, y1, x1 + tw, y1 + th))
133
+ mask = mask.crop((x1, y1, x1 + tw, y1 + th))
134
+
135
+ return {'image': img,
136
+ 'label': mask}
137
+
138
+ class RandomHorizontalFlip(object):
139
+ def __call__(self, sample):
140
+ img = sample['image']
141
+ mask = sample['label']
142
+ if random.random() < 0.5:
143
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
144
+ mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
145
+
146
+ return {'image': img,
147
+ 'label': mask}
148
+
149
+ class HorizontalFlip(object):
150
+ def __call__(self, sample):
151
+ img = sample['image']
152
+ mask = sample['label']
153
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
154
+ mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
155
+
156
+ return {'image': img,
157
+ 'label': mask}
158
+
159
+ class HorizontalFlip_only_img(object):
160
+ def __call__(self, sample):
161
+ img = sample['image']
162
+ mask = sample['label']
163
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
164
+ # mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
165
+
166
+ return {'image': img,
167
+ 'label': mask}
168
+
169
+ class RandomHorizontalFlip_cihp(object):
170
+ def __call__(self, sample):
171
+ img = sample['image']
172
+ mask = sample['label']
173
+ if random.random() < 0.5:
174
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
175
+ # mask = Image.open()
176
+
177
+ return {'image': img,
178
+ 'label': mask}
179
+
180
+ class Normalize(object):
181
+ """Normalize a tensor image with mean and standard deviation.
182
+ Args:
183
+ mean (tuple): means for each channel.
184
+ std (tuple): standard deviations for each channel.
185
+ """
186
+ def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
187
+ self.mean = mean
188
+ self.std = std
189
+
190
+ def __call__(self, sample):
191
+ img = np.array(sample['image']).astype(np.float32)
192
+ mask = np.array(sample['label']).astype(np.float32)
193
+ img /= 255.0
194
+ img -= self.mean
195
+ img /= self.std
196
+
197
+ return {'image': img,
198
+ 'label': mask}
199
+
200
+ class Normalize_255(object):
201
+ """Normalize a tensor image with mean and standard deviation. tf use 255.
202
+ Args:
203
+ mean (tuple): means for each channel.
204
+ std (tuple): standard deviations for each channel.
205
+ """
206
+ def __init__(self, mean=(123.15, 115.90, 103.06), std=(1., 1., 1.)):
207
+ self.mean = mean
208
+ self.std = std
209
+
210
+ def __call__(self, sample):
211
+ img = np.array(sample['image']).astype(np.float32)
212
+ mask = np.array(sample['label']).astype(np.float32)
213
+ # img = 255.0
214
+ img -= self.mean
215
+ img /= self.std
216
+ img = img
217
+ img = img[[0,3,2,1],...]
218
+ return {'image': img,
219
+ 'label': mask}
220
+
221
+ class Normalize_xception_tf(object):
222
+ # def __init__(self):
223
+ # self.rgb2bgr =
224
+
225
+ def __call__(self, sample):
226
+ img = np.array(sample['image']).astype(np.float32)
227
+ mask = np.array(sample['label']).astype(np.float32)
228
+ img = (img*2.0)/255.0 - 1
229
+ # print(img.shape)
230
+ # img = img[[0,3,2,1],...]
231
+ return {'image': img,
232
+ 'label': mask}
233
+
234
+ class Normalize_xception_tf_only_img(object):
235
+ # def __init__(self):
236
+ # self.rgb2bgr =
237
+
238
+ def __call__(self, sample):
239
+ img = np.array(sample['image']).astype(np.float32)
240
+ # mask = np.array(sample['label']).astype(np.float32)
241
+ img = (img*2.0)/255.0 - 1
242
+ # print(img.shape)
243
+ # img = img[[0,3,2,1],...]
244
+ return {'image': img,
245
+ 'label': sample['label']}
246
+
247
+ class Normalize_cityscapes(object):
248
+ """Normalize a tensor image with mean and standard deviation.
249
+ Args:
250
+ mean (tuple): means for each channel.
251
+ std (tuple): standard deviations for each channel.
252
+ """
253
+ def __init__(self, mean=(0., 0., 0.)):
254
+ self.mean = mean
255
+
256
+ def __call__(self, sample):
257
+ img = np.array(sample['image']).astype(np.float32)
258
+ mask = np.array(sample['label']).astype(np.float32)
259
+ img -= self.mean
260
+ img /= 255.0
261
+
262
+ return {'image': img,
263
+ 'label': mask}
264
+
265
+ class ToTensor_(object):
266
+ """Convert ndarrays in sample to Tensors."""
267
+ def __init__(self):
268
+ self.rgb2bgr = transforms.Lambda(lambda x:x[[2,1,0],...])
269
+
270
+ def __call__(self, sample):
271
+ # swap color axis because
272
+ # numpy image: H x W x C
273
+ # torch image: C X H X W
274
+ img = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1))
275
+ mask = np.expand_dims(np.array(sample['label']).astype(np.float32), -1).transpose((2, 0, 1))
276
+ # mask[mask == 255] = 0
277
+
278
+ img = torch.from_numpy(img).float()
279
+ img = self.rgb2bgr(img)
280
+ mask = torch.from_numpy(mask).float()
281
+
282
+
283
+ return {'image': img,
284
+ 'label': mask}
285
+
286
+ class ToTensor_only_img(object):
287
+ """Convert ndarrays in sample to Tensors."""
288
+ def __init__(self):
289
+ self.rgb2bgr = transforms.Lambda(lambda x:x[[2,1,0],...])
290
+
291
+ def __call__(self, sample):
292
+ # swap color axis because
293
+ # numpy image: H x W x C
294
+ # torch image: C X H X W
295
+ img = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1))
296
+ # mask = np.expand_dims(np.array(sample['label']).astype(np.float32), -1).transpose((2, 0, 1))
297
+ # mask[mask == 255] = 0
298
+
299
+ img = torch.from_numpy(img).float()
300
+ img = self.rgb2bgr(img)
301
+ # mask = torch.from_numpy(mask).float()
302
+
303
+
304
+ return {'image': img,
305
+ 'label': sample['label']}
306
+
307
+ class FixedResize(object):
308
+ def __init__(self, size):
309
+ self.size = tuple(reversed(size)) # size: (h, w)
310
+
311
+ def __call__(self, sample):
312
+ img = sample['image']
313
+ mask = sample['label']
314
+
315
+ assert img.size == mask.size
316
+
317
+ img = img.resize(self.size, Image.BILINEAR)
318
+ mask = mask.resize(self.size, Image.NEAREST)
319
+
320
+ return {'image': img,
321
+ 'label': mask}
322
+
323
+ class Keep_origin_size_Resize(object):
324
+ def __init__(self, max_size, scale=1.0):
325
+ self.size = tuple(reversed(max_size)) # size: (h, w)
326
+ self.scale = scale
327
+ self.paste = Paste(int(max_size[0]*scale))
328
+
329
+ def __call__(self, sample):
330
+ img = sample['image']
331
+ mask = sample['label']
332
+
333
+ assert img.size == mask.size
334
+ h, w = self.size
335
+ h = int(h*self.scale)
336
+ w = int(w*self.scale)
337
+ img = img.resize((h, w), Image.BILINEAR)
338
+ mask = mask.resize((h, w), Image.NEAREST)
339
+
340
+ return self.paste({'image': img,
341
+ 'label': mask})
342
+
343
+ class Scale(object):
344
+ def __init__(self, size):
345
+ if isinstance(size, numbers.Number):
346
+ self.size = (int(size), int(size))
347
+ else:
348
+ self.size = size
349
+
350
+ def __call__(self, sample):
351
+ img = sample['image']
352
+ mask = sample['label']
353
+ assert img.size == mask.size
354
+ w, h = img.size
355
+
356
+ if (w >= h and w == self.size[1]) or (h >= w and h == self.size[0]):
357
+ return {'image': img,
358
+ 'label': mask}
359
+ oh, ow = self.size
360
+ img = img.resize((ow, oh), Image.BILINEAR)
361
+ mask = mask.resize((ow, oh), Image.NEAREST)
362
+
363
+ return {'image': img,
364
+ 'label': mask}
365
+
366
+ class Scale_(object):
367
+ def __init__(self, scale):
368
+ self.scale = scale
369
+
370
+ def __call__(self, sample):
371
+ img = sample['image']
372
+ mask = sample['label']
373
+ assert img.size == mask.size
374
+ w, h = img.size
375
+ ow = int(w*self.scale)
376
+ oh = int(h*self.scale)
377
+ img = img.resize((ow, oh), Image.BILINEAR)
378
+ mask = mask.resize((ow, oh), Image.NEAREST)
379
+
380
+ return {'image': img,
381
+ 'label': mask}
382
+
383
+ class Scale_only_img(object):
384
+ def __init__(self, scale):
385
+ self.scale = scale
386
+
387
+ def __call__(self, sample):
388
+ img = sample['image']
389
+ mask = sample['label']
390
+ # assert img.size == mask.size
391
+ w, h = img.size
392
+ ow = int(w*self.scale)
393
+ oh = int(h*self.scale)
394
+ img = img.resize((ow, oh), Image.BILINEAR)
395
+ # mask = mask.resize((ow, oh), Image.NEAREST)
396
+
397
+ return {'image': img,
398
+ 'label': mask}
399
+
400
+ class RandomSizedCrop(object):
401
+ def __init__(self, size):
402
+ self.size = size
403
+
404
+ def __call__(self, sample):
405
+ img = sample['image']
406
+ mask = sample['label']
407
+ assert img.size == mask.size
408
+ for attempt in range(10):
409
+ area = img.size[0] * img.size[1]
410
+ target_area = random.uniform(0.45, 1.0) * area
411
+ aspect_ratio = random.uniform(0.5, 2)
412
+
413
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
414
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
415
+
416
+ if random.random() < 0.5:
417
+ w, h = h, w
418
+
419
+ if w <= img.size[0] and h <= img.size[1]:
420
+ x1 = random.randint(0, img.size[0] - w)
421
+ y1 = random.randint(0, img.size[1] - h)
422
+
423
+ img = img.crop((x1, y1, x1 + w, y1 + h))
424
+ mask = mask.crop((x1, y1, x1 + w, y1 + h))
425
+ assert (img.size == (w, h))
426
+
427
+ img = img.resize((self.size, self.size), Image.BILINEAR)
428
+ mask = mask.resize((self.size, self.size), Image.NEAREST)
429
+
430
+ return {'image': img,
431
+ 'label': mask}
432
+
433
+ # Fallback
434
+ scale = Scale(self.size)
435
+ crop = CenterCrop(self.size)
436
+ sample = crop(scale(sample))
437
+ return sample
438
+
439
+ class RandomRotate(object):
440
+ def __init__(self, degree):
441
+ self.degree = degree
442
+
443
+ def __call__(self, sample):
444
+ img = sample['image']
445
+ mask = sample['label']
446
+ rotate_degree = random.random() * 2 * self.degree - self.degree
447
+ img = img.rotate(rotate_degree, Image.BILINEAR)
448
+ mask = mask.rotate(rotate_degree, Image.NEAREST)
449
+
450
+ return {'image': img,
451
+ 'label': mask}
452
+
453
+ class RandomSized_new(object):
454
+ '''This is the augmentation class we actually use.'''
455
+ def __init__(self, size,scale1=0.5,scale2=2):
456
+ self.size = size
457
+ # self.scale = Scale(self.size)
458
+ self.crop = RandomCrop_new(self.size)
459
+ self.small_scale = scale1
460
+ self.big_scale = scale2
461
+
462
+ def __call__(self, sample):
463
+ img = sample['image']
464
+ mask = sample['label']
465
+ assert img.size == mask.size
466
+
467
+ w = int(random.uniform(self.small_scale, self.big_scale) * img.size[0])
468
+ h = int(random.uniform(self.small_scale, self.big_scale) * img.size[1])
469
+
470
+ img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST)
471
+ sample = {'image': img, 'label': mask}
472
+ # finish resize
473
+ return self.crop(sample)
474
+ # class Random
475
+
476
+ class RandomScale(object):
477
+ def __init__(self, limit):
478
+ self.limit = limit
479
+
480
+ def __call__(self, sample):
481
+ img = sample['image']
482
+ mask = sample['label']
483
+ assert img.size == mask.size
484
+
485
+ scale = random.uniform(self.limit[0], self.limit[1])
486
+ w = int(scale * img.size[0])
487
+ h = int(scale * img.size[1])
488
+
489
+ img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST)
490
+
491
+ return {'image': img, 'label': mask}
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/mypath_atr.py ADDED
@@ -0,0 +1,8 @@
1
+ class Path(object):
2
+ @staticmethod
3
+ def db_root_dir(database):
4
+ if database == 'atr':
5
+ return './data/datasets/ATR/' # folder that contains atr/.
6
+ else:
7
+ print('Database {} not available.'.format(database))
8
+ raise NotImplementedError
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/mypath_cihp.py ADDED
@@ -0,0 +1,8 @@
1
+ class Path(object):
2
+ @staticmethod
3
+ def db_root_dir(database):
4
+ if database == 'cihp':
5
+ return './data/datasets/CIHP_4w/'
6
+ else:
7
+ print('Database {} not available.'.format(database))
8
+ raise NotImplementedError
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/mypath_pascal.py ADDED
@@ -0,0 +1,8 @@
1
+ class Path(object):
2
+ @staticmethod
3
+ def db_root_dir(database):
4
+ if database == 'pascal':
5
+ return './data/datasets/pascal/' # folder that contains pascal/.
6
+ else:
7
+ print('Database {} not available.'.format(database))
8
+ raise NotImplementedError
TryYours-Virtual-Try-On/Graphonomy-master/dataloaders/pascal.py ADDED
@@ -0,0 +1,106 @@
1
+ from __future__ import print_function, division
2
+ import os
3
+ from PIL import Image
4
+ from torch.utils.data import Dataset
5
+ from .mypath_pascal import Path
6
+
7
+ class VOCSegmentation(Dataset):
8
+ """
9
+ Pascal dataset
10
+ """
11
+
12
+ def __init__(self,
13
+ base_dir=Path.db_root_dir('pascal'),
14
+ split='train',
15
+ transform=None
16
+ ):
17
+ """
18
+ :param base_dir: path to PASCAL dataset directory
19
+ :param split: train/val
20
+ :param transform: transform to apply
21
+ """
22
+ super(VOCSegmentation).__init__()
23
+ self._base_dir = base_dir
24
+ self._image_dir = os.path.join(self._base_dir, 'JPEGImages')
25
+ self._cat_dir = os.path.join(self._base_dir, 'SegmentationPart')
26
+
27
+ if isinstance(split, str):
28
+ self.split = [split]
29
+ else:
30
+ split.sort()
31
+ self.split = split
32
+
33
+ self.transform = transform
34
+
35
+ _splits_dir = os.path.join(self._base_dir, 'list')
36
+
37
+ self.im_ids = []
38
+ self.images = []
39
+ self.categories = []
40
+
41
+ for splt in self.split:
42
+ with open(os.path.join(os.path.join(_splits_dir, splt + '_id.txt')), "r") as f:
43
+ lines = f.read().splitlines()
44
+
45
+ for ii, line in enumerate(lines):
46
+
47
+ _image = os.path.join(self._image_dir, line+'.jpg' )
48
+ _cat = os.path.join(self._cat_dir, line +'.png')
49
+ # print(self._image_dir,_image)
50
+ assert os.path.isfile(_image)
51
+ # print(_cat)
52
+ assert os.path.isfile(_cat)
53
+ self.im_ids.append(line)
54
+ self.images.append(_image)
55
+ self.categories.append(_cat)
56
+
57
+ assert (len(self.images) == len(self.categories))
58
+
59
+ # Display stats
60
+ print('Number of images in {}: {:d}'.format(split, len(self.images)))
61
+
62
+ def __len__(self):
63
+ return len(self.images)
64
+
65
+
66
+ def __getitem__(self, index):
67
+ _img, _target= self._make_img_gt_point_pair(index)
68
+ sample = {'image': _img, 'label': _target}
69
+
70
+ if self.transform is not None:
71
+ sample = self.transform(sample)
72
+
73
+ return sample
74
+
75
+ def _make_img_gt_point_pair(self, index):
76
+ # Read Image and Target
77
+ # _img = np.array(Image.open(self.images[index]).convert('RGB')).astype(np.float32)
78
+ # _target = np.array(Image.open(self.categories[index])).astype(np.float32)
79
+
80
+ _img = Image.open(self.images[index]).convert('RGB') # returned image is RGB
81
+ _target = Image.open(self.categories[index])
82
+
83
+ return _img, _target
84
+
85
+ def __str__(self):
86
+ return 'PASCAL(split=' + str(self.split) + ')'
87
+
88
+ class test_segmentation(VOCSegmentation):
89
+ def __init__(self,base_dir=Path.db_root_dir('pascal'),
90
+ split='train',
91
+ transform=None,
92
+ flip=True):
93
+ super(test_segmentation, self).__init__(base_dir=base_dir,split=split,transform=transform)
94
+ self._flip_flag = flip
95
+
96
+ def __getitem__(self, index):
97
+ _img, _target= self._make_img_gt_point_pair(index)
98
+ sample = {'image': _img, 'label': _target}
99
+
100
+ if self.transform is not None:
101
+ sample = self.transform(sample)
102
+
103
+ return sample
104
+
105
+
106
+
TryYours-Virtual-Try-On/Graphonomy-master/eval_cihp.sh ADDED
@@ -0,0 +1,5 @@
1
+ python ./exp/test/eval_show_pascal2cihp.py \
2
+ --batch 1 --gpus 1 --classes 20 \
3
+ --gt_path './data/datasets/CIHP_4w/Category_ids' \
4
+ --txt_file './data/datasets/CIHP_4w/lists/val_id.txt' \
5
+ --loadmodel './data/pretrained_model/inference.pth'
TryYours-Virtual-Try-On/Graphonomy-master/eval_pascal.sh ADDED
@@ -0,0 +1,5 @@
1
+ python ./exp/test/eval_show_pascal.py \
2
+ --batch 1 --gpus 1 --classes 7 \
3
+ --gt_path './data/datasets/pascal/SegmentationPart/' \
4
+ --txt_file './data/datasets/pascal/list/val_id.txt' \
5
+ --loadmodel './cihp2pascal.pth'
TryYours-Virtual-Try-On/Graphonomy-master/exp/inference/.ipynb_checkpoints/inference-checkpoint.py ADDED
@@ -0,0 +1,203 @@
1
+ import socket
2
+ import timeit
3
+ import numpy as np
4
+ from PIL import Image
5
+ from datetime import datetime
6
+ import os
7
+ import sys
8
+ from collections import OrderedDict
9
+ sys.path.append('./')
10
+ # PyTorch includes
11
+ import torch
12
+ from torch.autograd import Variable
13
+ from torchvision import transforms
14
+ import cv2
15
+
16
+
17
+ # Custom includes
18
+ from networks import deeplab_xception_transfer, graph
19
+ from dataloaders import custom_transforms as tr
20
+
21
+ #
22
+ import argparse
23
+ import torch.nn.functional as F
24
+
25
+ label_colours = [(0,0,0)
26
+ , (128,0,0), (255,0,0), (0,85,0), (170,0,51), (255,85,0), (0,0,85), (0,119,221), (85,85,0), (0,85,85), (85,51,0), (52,86,128), (0,128,0)
27
+ , (0,0,255), (51,170,221), (0,255,255), (85,255,170), (170,255,85), (255,255,0), (255,170,0)]
28
+
29
+
30
+ def flip(x, dim):
31
+ indices = [slice(None)] * x.dim()
32
+ indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
33
+ dtype=torch.long, device=x.device)
34
+ return x[tuple(indices)]
35
+
36
+ def flip_cihp(tail_list):
37
+ '''
38
+
39
+ :param tail_list: tail_list size is 1 x n_class x h x w
40
+ :return:
41
+ '''
42
+ # tail_list = tail_list[0]
43
+ tail_list_rev = [None] * 20
44
+ for xx in range(14):
45
+ tail_list_rev[xx] = tail_list[xx].unsqueeze(0)
46
+ tail_list_rev[14] = tail_list[15].unsqueeze(0)
47
+ tail_list_rev[15] = tail_list[14].unsqueeze(0)
48
+ tail_list_rev[16] = tail_list[17].unsqueeze(0)
49
+ tail_list_rev[17] = tail_list[16].unsqueeze(0)
50
+ tail_list_rev[18] = tail_list[19].unsqueeze(0)
51
+ tail_list_rev[19] = tail_list[18].unsqueeze(0)
52
+ return torch.cat(tail_list_rev,dim=0)
53
+
54
+
55
+ def decode_labels(mask, num_images=1, num_classes=20):
56
+ """Decode batch of segmentation masks.
57
+
58
+ Args:
59
+ mask: result of inference after taking argmax.
60
+ num_images: number of images to decode from the batch.
61
+ num_classes: number of classes to predict (including background).
62
+
63
+ Returns:
64
+ A batch with num_images RGB images of the same size as the input.
65
+ """
66
+ n, h, w = mask.shape
67
+ assert (n >= num_images), 'Batch size %d should be greater or equal than number of images to save %d.' % (
68
+ n, num_images)
69
+ outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8)
70
+ for i in range(num_images):
71
+ img = Image.new('RGB', (len(mask[i, 0]), len(mask[i])))
72
+ pixels = img.load()
73
+ for j_, j in enumerate(mask[i, :, :]):
74
+ for k_, k in enumerate(j):
75
+ if k < num_classes:
76
+ pixels[k_, j_] = label_colours[k]
77
+ outputs[i] = np.array(img)
78
+ return outputs
79
+
80
+ def read_img(img_path):
81
+ _img = Image.open(img_path).convert('RGB') # return is RGB pic
82
+ return _img
83
+
84
+ def img_transform(img, transform=None):
85
+ sample = {'image': img, 'label': 0}
86
+
87
+ sample = transform(sample)
88
+ return sample
89
+
90
+ def inference(net, img_path='', output_path='./', output_name='f', use_gpu=True):
91
+ '''
92
+
93
+ :param net:
94
+ :param img_path:
95
+ :param output_path:
96
+ :return:
97
+ '''
98
+ # adj
99
+ adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float()
100
+ adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).cuda().transpose(2, 3)
101
+
102
+ adj1_ = Variable(torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float())
103
+ adj3_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7).cuda()
104
+
105
+ cihp_adj = graph.preprocess_adj(graph.cihp_graph)
106
+ adj3_ = Variable(torch.from_numpy(cihp_adj).float())
107
+ adj1_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20).cuda()
108
+
109
+ # multi-scale
110
+ scale_list = [1, 0.5, 0.75, 1.25, 1.5, 1.75]
111
+ img = read_img(img_path)
112
+ testloader_list = []
113
+ testloader_flip_list = []
114
+ for pv in scale_list:
115
+ composed_transforms_ts = transforms.Compose([
116
+ tr.Scale_only_img(pv),
117
+ tr.Normalize_xception_tf_only_img(),
118
+ tr.ToTensor_only_img()])
119
+
120
+ composed_transforms_ts_flip = transforms.Compose([
121
+ tr.Scale_only_img(pv),
122
+ tr.HorizontalFlip_only_img(),
123
+ tr.Normalize_xception_tf_only_img(),
124
+ tr.ToTensor_only_img()])
125
+
126
+ testloader_list.append(img_transform(img, composed_transforms_ts))
127
+ # print(img_transform(img, composed_transforms_ts))
128
+ testloader_flip_list.append(img_transform(img, composed_transforms_ts_flip))
129
+ # print(testloader_list)
130
+ start_time = timeit.default_timer()
131
+ # One testing epoch
132
+ net.eval()
133
+ # 1 0.5 0.75 1.25 1.5 1.75 ; flip:
134
+
135
+ for iii, sample_batched in enumerate(zip(testloader_list, testloader_flip_list)):
136
+ inputs, labels = sample_batched[0]['image'], sample_batched[0]['label']
137
+ inputs_f, _ = sample_batched[1]['image'], sample_batched[1]['label']
138
+ inputs = inputs.unsqueeze(0)
139
+ inputs_f = inputs_f.unsqueeze(0)
140
+ inputs = torch.cat((inputs, inputs_f), dim=0)
141
+ if iii == 0:
142
+ _, _, h, w = inputs.size()
143
+ # assert inputs.size() == inputs_f.size()
144
+
145
+ # Forward pass of the mini-batch
146
+ inputs = Variable(inputs, requires_grad=False)
147
+
148
+ with torch.no_grad():
149
+ if use_gpu >= 0:
150
+ inputs = inputs.cuda()
151
+ # outputs = net.forward(inputs)
152
+ outputs = net.forward(inputs, adj1_test.cuda(), adj3_test.cuda(), adj2_test.cuda())
153
+ outputs = (outputs[0] + flip(flip_cihp(outputs[1]), dim=-1)) / 2
154
+ outputs = outputs.unsqueeze(0)
155
+
156
+ if iii > 0:
157
+ outputs = F.upsample(outputs, size=(h, w), mode='bilinear', align_corners=True)
158
+ outputs_final = outputs_final + outputs
159
+ else:
160
+ outputs_final = outputs.clone()
161
+ ################ plot pic
162
+ predictions = torch.max(outputs_final, 1)[1]
163
+ results = predictions.cpu().numpy()
164
+ vis_res = decode_labels(results)
165
+
166
+ parsing_im = Image.fromarray(vis_res[0])
167
+ parsing_im.save(output_path+'/{}.png'.format(output_name))
168
+ cv2.imwrite(output_path+'/{}_gray.png'.format(output_name), results[0, :, :])
169
+
170
+ end_time = timeit.default_timer()
171
+ print('time used for the multi-scale image inference' + ' is :' + str(end_time - start_time))
172
+
173
+ if __name__ == '__main__':
174
+ '''argparse begin'''
175
+ parser = argparse.ArgumentParser()
176
+ # parser.add_argument('--loadmodel',default=None,type=str)
177
+ parser.add_argument('--loadmodel', default='', type=str)
178
+ parser.add_argument('--img_path', default='', type=str)
179
+ parser.add_argument('--output_path', default='', type=str)
180
+ parser.add_argument('--output_name', default='', type=str)
181
+ parser.add_argument('--use_gpu', default=1, type=int)
182
+ opts = parser.parse_args()
183
+
184
+ net = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(n_classes=20,
185
+ hidden_layers=128,
186
+ source_classes=7, )
187
+ if not opts.loadmodel == '':
188
+ x = torch.load(opts.loadmodel)
189
+ net.load_source_model(x)
190
+ print('load model:', opts.loadmodel)
191
+ else:
192
+ print('no model load !!!!!!!!')
193
+ raise RuntimeError('No model!!!!')
194
+
195
+ if opts.use_gpu >0 :
196
+ net.cuda()
197
+ use_gpu = True
198
+ else:
199
+ use_gpu = False
200
+ raise RuntimeError('must use the gpu!!!!')
201
+
202
+ inference(net=net, img_path=opts.img_path,output_path=opts.output_path , output_name=opts.output_name, use_gpu=use_gpu)
203
+
TryYours-Virtual-Try-On/Graphonomy-master/exp/inference/inference.py ADDED
@@ -0,0 +1,206 @@
1
+ import socket
2
+ import timeit
3
+ import numpy as np
4
+ from PIL import Image
5
+ from datetime import datetime
6
+ import os
7
+ import sys
8
+ from collections import OrderedDict
9
+ sys.path.append('./')
10
+ # PyTorch includes
11
+ import torch
12
+ from torch.autograd import Variable
13
+ from torchvision import transforms
14
+ import cv2
15
+
16
+
17
+ # Custom includes
18
+ from networks import deeplab_xception_transfer, graph
19
+ from dataloaders import custom_transforms as tr
20
+
21
+ #
22
+ import argparse
23
+ import torch.nn.functional as F
24
+
25
+ import warnings
26
+ warnings.filterwarnings("ignore")
27
+
28
+ label_colours = [(0,0,0)
29
+ , (128,0,0), (255,0,0), (0,85,0), (170,0,51), (255,85,0), (0,0,85), (0,119,221), (85,85,0), (0,85,85), (85,51,0), (52,86,128), (0,128,0)
30
+ , (0,0,255), (51,170,221), (0,255,255), (85,255,170), (170,255,85), (255,255,0), (255,170,0)]
31
+
32
+
33
+ def flip(x, dim):
34
+ indices = [slice(None)] * x.dim()
35
+ indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
36
+ dtype=torch.long, device=x.device)
37
+ return x[tuple(indices)]
38
+
39
+ def flip_cihp(tail_list):
40
+ '''
41
+
42
+ :param tail_list: tail_list size is 1 x n_class x h x w
43
+ :return:
44
+ '''
45
+ # tail_list = tail_list[0]
46
+ tail_list_rev = [None] * 20
47
+ for xx in range(14):
48
+ tail_list_rev[xx] = tail_list[xx].unsqueeze(0)
49
+ tail_list_rev[14] = tail_list[15].unsqueeze(0)
50
+ tail_list_rev[15] = tail_list[14].unsqueeze(0)
51
+ tail_list_rev[16] = tail_list[17].unsqueeze(0)
52
+ tail_list_rev[17] = tail_list[16].unsqueeze(0)
53
+ tail_list_rev[18] = tail_list[19].unsqueeze(0)
54
+ tail_list_rev[19] = tail_list[18].unsqueeze(0)
55
+ return torch.cat(tail_list_rev,dim=0)
56
+
57
+
58
+ def decode_labels(mask, num_images=1, num_classes=20):
59
+ """Decode batch of segmentation masks.
60
+
61
+ Args:
62
+ mask: result of inference after taking argmax.
63
+ num_images: number of images to decode from the batch.
64
+ num_classes: number of classes to predict (including background).
65
+
66
+ Returns:
67
+ A batch with num_images RGB images of the same size as the input.
68
+ """
69
+ n, h, w = mask.shape
70
+ assert (n >= num_images), 'Batch size %d should be greater or equal than number of images to save %d.' % (
71
+ n, num_images)
72
+ outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8)
73
+ for i in range(num_images):
74
+ img = Image.new('RGB', (len(mask[i, 0]), len(mask[i])))
75
+ pixels = img.load()
76
+ for j_, j in enumerate(mask[i, :, :]):
77
+ for k_, k in enumerate(j):
78
+ if k < num_classes:
79
+ pixels[k_, j_] = label_colours[k]
80
+ outputs[i] = np.array(img)
81
+ return outputs
82
+
83
+ def read_img(img_path):
84
+ _img = Image.open(img_path).convert('RGB') # return is RGB pic
85
+ return _img
86
+
87
+ def img_transform(img, transform=None):
88
+ sample = {'image': img, 'label': 0}
89
+
90
+ sample = transform(sample)
91
+ return sample
92
+
93
+ def inference(net, img_path='', output_path='./', output_name='f', use_gpu=True):
94
+ '''
95
+
96
+ :param net:
97
+ :param img_path:
98
+ :param output_path:
99
+ :return:
100
+ '''
101
+ # adj
102
+ adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float()
103
+ adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).cuda().transpose(2, 3)
104
+
105
+ adj1_ = Variable(torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float())
106
+ adj3_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7).cuda()
107
+
108
+ cihp_adj = graph.preprocess_adj(graph.cihp_graph)
109
+ adj3_ = Variable(torch.from_numpy(cihp_adj).float())
110
+ adj1_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20).cuda()
111
+
112
+ # multi-scale
113
+ scale_list = [1, 0.5, 0.75, 1.25, 1.5, 1.75]
114
+ img = read_img(img_path)
115
+ testloader_list = []
116
+ testloader_flip_list = []
117
+ for pv in scale_list:
118
+ composed_transforms_ts = transforms.Compose([
119
+ tr.Scale_only_img(pv),
120
+ tr.Normalize_xception_tf_only_img(),
121
+ tr.ToTensor_only_img()])
122
+
123
+ composed_transforms_ts_flip = transforms.Compose([
124
+ tr.Scale_only_img(pv),
125
+ tr.HorizontalFlip_only_img(),
126
+ tr.Normalize_xception_tf_only_img(),
127
+ tr.ToTensor_only_img()])
128
+
129
+ testloader_list.append(img_transform(img, composed_transforms_ts))
130
+ # print(img_transform(img, composed_transforms_ts))
131
+ testloader_flip_list.append(img_transform(img, composed_transforms_ts_flip))
132
+ # print(testloader_list)
133
+ start_time = timeit.default_timer()
134
+ # One testing epoch
135
+ net.eval()
136
+ # 1 0.5 0.75 1.25 1.5 1.75 ; flip:
137
+
138
+ for iii, sample_batched in enumerate(zip(testloader_list, testloader_flip_list)):
139
+ inputs, labels = sample_batched[0]['image'], sample_batched[0]['label']
140
+ inputs_f, _ = sample_batched[1]['image'], sample_batched[1]['label']
141
+ inputs = inputs.unsqueeze(0)
142
+ inputs_f = inputs_f.unsqueeze(0)
143
+ inputs = torch.cat((inputs, inputs_f), dim=0)
144
+ if iii == 0:
145
+ _, _, h, w = inputs.size()
146
+ # assert inputs.size() == inputs_f.size()
147
+
148
+ # Forward pass of the mini-batch
149
+ inputs = Variable(inputs, requires_grad=False)
150
+
151
+ with torch.no_grad():
152
+ if use_gpu >= 0:
153
+ inputs = inputs.cuda()
154
+ # outputs = net.forward(inputs)
155
+ outputs = net.forward(inputs, adj1_test.cuda(), adj3_test.cuda(), adj2_test.cuda())
156
+ outputs = (outputs[0] + flip(flip_cihp(outputs[1]), dim=-1)) / 2
157
+ outputs = outputs.unsqueeze(0)
158
+
159
+ if iii > 0:
160
+ outputs = F.upsample(outputs, size=(h, w), mode='bilinear', align_corners=True)
161
+ outputs_final = outputs_final + outputs
162
+ else:
163
+ outputs_final = outputs.clone()
164
+ ################ plot pic
165
+ predictions = torch.max(outputs_final, 1)[1]
166
+ results = predictions.cpu().numpy()
167
+ vis_res = decode_labels(results)
168
+
169
+ parsing_im = Image.fromarray(vis_res[0])
170
+ parsing_im.save(output_path+'/{}.png'.format(output_name))
171
+ cv2.imwrite(output_path+'/{}_gray.png'.format(output_name), results[0, :, :])
172
+
173
+ end_time = timeit.default_timer()
174
+ print('time used for the multi-scale image inference' + ' is :' + str(end_time - start_time))
175
+
176
+ if __name__ == '__main__':
177
+ '''argparse begin'''
178
+ parser = argparse.ArgumentParser()
179
+ # parser.add_argument('--loadmodel',default=None,type=str)
180
+ parser.add_argument('--loadmodel', default='', type=str)
181
+ parser.add_argument('--img_path', default='', type=str)
182
+ parser.add_argument('--output_path', default='', type=str)
183
+ parser.add_argument('--output_name', default='', type=str)
184
+ parser.add_argument('--use_gpu', default=1, type=int)
185
+ opts = parser.parse_args()
186
+
187
+ net = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(n_classes=20,
188
+ hidden_layers=128,
189
+ source_classes=7, )
190
+ if not opts.loadmodel == '':
191
+ x = torch.load(opts.loadmodel)
192
+ net.load_source_model(x)
193
+ print('load model:', opts.loadmodel)
194
+ else:
195
+ print('no model load !!!!!!!!')
196
+ raise RuntimeError('No model!!!!')
197
+
198
+ if opts.use_gpu >0 :
199
+ net.cuda()
200
+ use_gpu = True
201
+ else:
202
+ use_gpu = False
203
+ raise RuntimeError('must use the gpu!!!!')
204
+
205
+ inference(net=net, img_path=opts.img_path,output_path=opts.output_path , output_name=opts.output_name, use_gpu=use_gpu)
206
+
TryYours-Virtual-Try-On/Graphonomy-master/exp/test/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .test_from_disk import eval_
2
+
3
+ __all__ = ['eval_']
TryYours-Virtual-Try-On/Graphonomy-master/exp/test/eval_show_cihp2pascal.py ADDED
@@ -0,0 +1,268 @@
1
+ import socket
2
+ import timeit
3
+ import numpy as np
4
+ from PIL import Image
5
+ from datetime import datetime
6
+ import os
7
+ import sys
8
+ import glob
9
+ from collections import OrderedDict
10
+ sys.path.append('../../')
11
+ # PyTorch includes
12
+ import torch
13
+ import pdb
14
+ from torch.autograd import Variable
15
+ import torch.optim as optim
16
+ from torchvision import transforms
17
+ from torch.utils.data import DataLoader
18
+ from torchvision.utils import make_grid
19
+ import cv2
20
+
21
+ # Tensorboard include
22
+ # from tensorboardX import SummaryWriter
23
+
24
+ # Custom includes
25
+ from dataloaders import pascal
26
+ from utils import util
27
+ from networks import deeplab_xception_transfer, graph
28
+ from dataloaders import custom_transforms as tr
29
+
30
+ #
31
+ import argparse
32
+ import copy
33
+ import torch.nn.functional as F
34
+ from test_from_disk import eval_
35
+
36
+
37
+ gpu_id = 1
38
+
39
+ label_colours = [(0,0,0)
40
+ # 0=background
41
+ ,(128,0,0), (0,128,0), (128,128,0), (0,0,128), (128,0,128), (0,128,128)]
42
+
43
+
44
+ def flip(x, dim):
45
+ indices = [slice(None)] * x.dim()
46
+ indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
47
+ dtype=torch.long, device=x.device)
48
+ return x[tuple(indices)]
49
+
50
+ # def flip_cihp(tail_list):
51
+ # '''
52
+ #
53
+ # :param tail_list: tail_list size is 1 x n_class x h x w
54
+ # :return:
55
+ # '''
56
+ # # tail_list = tail_list[0]
57
+ # tail_list_rev = [None] * 20
58
+ # for xx in range(14):
59
+ # tail_list_rev[xx] = tail_list[xx].unsqueeze(0)
60
+ # tail_list_rev[14] = tail_list[15].unsqueeze(0)
61
+ # tail_list_rev[15] = tail_list[14].unsqueeze(0)
62
+ # tail_list_rev[16] = tail_list[17].unsqueeze(0)
63
+ # tail_list_rev[17] = tail_list[16].unsqueeze(0)
64
+ # tail_list_rev[18] = tail_list[19].unsqueeze(0)
65
+ # tail_list_rev[19] = tail_list[18].unsqueeze(0)
66
+ # return torch.cat(tail_list_rev,dim=0)
67
+
68
+ def decode_labels(mask, num_images=1, num_classes=20):
69
+ """Decode batch of segmentation masks.
70
+
71
+ Args:
72
+ mask: result of inference after taking argmax.
73
+ num_images: number of images to decode from the batch.
74
+ num_classes: number of classes to predict (including background).
75
+
76
+ Returns:
77
+ A batch with num_images RGB images of the same size as the input.
78
+ """
79
+ n, h, w = mask.shape
80
+ assert(n >= num_images), 'Batch size %d should be greater or equal than number of images to save %d.' % (n, num_images)
81
+ outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8)
82
+ for i in range(num_images):
83
+ img = Image.new('RGB', (len(mask[i, 0]), len(mask[i])))
84
+ pixels = img.load()
85
+ for j_, j in enumerate(mask[i, :, :]):
86
+ for k_, k in enumerate(j):
87
+ if k < num_classes:
88
+ pixels[k_,j_] = label_colours[k]
89
+ outputs[i] = np.array(img)
90
+ return outputs
91
+
92
+ def get_parser():
93
+ '''argparse begin'''
94
+ parser = argparse.ArgumentParser()
95
+ LookupChoices = type('', (argparse.Action,), dict(__call__=lambda a, p, n, v, o: setattr(n, a.dest, a.choices[v])))
96
+
97
+ parser.add_argument('--epochs', default=100, type=int)
98
+ parser.add_argument('--batch', default=16, type=int)
99
+ parser.add_argument('--lr', default=1e-7, type=float)
100
+ parser.add_argument('--numworker', default=12, type=int)
101
+ parser.add_argument('--step', default=30, type=int)
102
+ # parser.add_argument('--loadmodel',default=None,type=str)
103
+ parser.add_argument('--classes', default=7, type=int)
104
+ parser.add_argument('--testepoch', default=10, type=int)
105
+ parser.add_argument('--loadmodel', default='', type=str)
106
+ parser.add_argument('--txt_file', default='', type=str)
107
+ parser.add_argument('--hidden_layers', default=128, type=int)
108
+ parser.add_argument('--gpus', default=4, type=int)
109
+ parser.add_argument('--output_path', default='./results/', type=str)
110
+ parser.add_argument('--gt_path', default='./results/', type=str)
111
+ opts = parser.parse_args()
112
+ return opts
113
+
114
+
115
+ def main(opts):
116
+ adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float()
117
+ adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).cuda()
118
+
119
+ adj1_ = Variable(torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float())
120
+ adj1_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7).cuda()
121
+
122
+ cihp_adj = graph.preprocess_adj(graph.cihp_graph)
123
+ adj3_ = Variable(torch.from_numpy(cihp_adj).float())
124
+ adj3_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20).cuda()
125
+
126
+ p = OrderedDict() # Parameters to include in report
127
+ p['trainBatch'] = opts.batch # Training batch size
128
+ p['nAveGrad'] = 1 # Average the gradient of several iterations
129
+ p['lr'] = opts.lr # Learning rate
130
+ p['lrFtr'] = 1e-5
131
+ p['lraspp'] = 1e-5
132
+ p['lrpro'] = 1e-5
133
+ p['lrdecoder'] = 1e-5
134
+ p['lrother'] = 1e-5
135
+ p['wd'] = 5e-4 # Weight decay
136
+ p['momentum'] = 0.9 # Momentum
137
+ p['epoch_size'] = 10 # How many epochs to change learning rate
138
+ p['num_workers'] = opts.numworker
139
+ backbone = 'xception' # Use xception or resnet as feature extractor,
140
+
141
+ with open(opts.txt_file, 'r') as f:
142
+ img_list = f.readlines()
143
+
144
+ max_id = 0
145
+ save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
146
+ exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]
147
+ runs = glob.glob(os.path.join(save_dir_root, 'run', 'run_*'))
148
+ for r in runs:
149
+ run_id = int(r.split('_')[-1])
150
+ if run_id >= max_id:
151
+ max_id = run_id + 1
152
+ # run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0
153
+
154
+ # Network definition
155
+ if backbone == 'xception':
156
+ net = deeplab_xception_transfer.deeplab_xception_transfer_projection(n_classes=opts.classes, os=16,
157
+ hidden_layers=opts.hidden_layers, source_classes=20,
158
+ )
159
+ elif backbone == 'resnet':
160
+ # net = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True)
161
+ raise NotImplementedError
162
+ else:
163
+ raise NotImplementedError
164
+
165
+ if gpu_id >= 0:
166
+ net.cuda()
167
+
168
+ # net load weights
169
+ if not opts.loadmodel =='':
170
+ x = torch.load(opts.loadmodel)
171
+ net.load_source_model(x)
172
+ print('load model:' ,opts.loadmodel)
173
+ else:
174
+ print('no model load !!!!!!!!')
175
+
176
+ ## multi scale
177
+ scale_list=[1,0.5,0.75,1.25,1.5,1.75]
178
+ testloader_list = []
179
+ testloader_flip_list = []
180
+ for pv in scale_list:
181
+ composed_transforms_ts = transforms.Compose([
182
+ tr.Scale_(pv),
183
+ tr.Normalize_xception_tf(),
184
+ tr.ToTensor_()])
185
+
186
+ composed_transforms_ts_flip = transforms.Compose([
187
+ tr.Scale_(pv),
188
+ tr.HorizontalFlip(),
189
+ tr.Normalize_xception_tf(),
190
+ tr.ToTensor_()])
191
+
192
+ voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts)
193
+ voc_val_f = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts_flip)
194
+
195
+ testloader = DataLoader(voc_val, batch_size=1, shuffle=False, num_workers=p['num_workers'])
196
+ testloader_flip = DataLoader(voc_val_f, batch_size=1, shuffle=False, num_workers=p['num_workers'])
197
+
198
+ testloader_list.append(copy.deepcopy(testloader))
199
+ testloader_flip_list.append(copy.deepcopy(testloader_flip))
200
+
201
+ print("Eval Network")
202
+
203
+ if not os.path.exists(opts.output_path + 'pascal_output_vis/'):
204
+ os.makedirs(opts.output_path + 'pascal_output_vis/')
205
+ if not os.path.exists(opts.output_path + 'pascal_output/'):
206
+ os.makedirs(opts.output_path + 'pascal_output/')
207
+
208
+ start_time = timeit.default_timer()
209
+ # One testing epoch
210
+ total_iou = 0.0
211
+ net.eval()
212
+ for ii, large_sample_batched in enumerate(zip(*testloader_list, *testloader_flip_list)):
213
+ print(ii)
214
+ #1 0.5 0.75 1.25 1.5 1.75 ; flip:
215
+ sample1 = large_sample_batched[:6]
216
+ sample2 = large_sample_batched[6:]
217
+ for iii,sample_batched in enumerate(zip(sample1,sample2)):
218
+ inputs, labels = sample_batched[0]['image'], sample_batched[0]['label']
219
+ inputs_f, _ = sample_batched[1]['image'], sample_batched[1]['label']
220
+ inputs = torch.cat((inputs,inputs_f),dim=0)
221
+ if iii == 0:
222
+ _,_,h,w = inputs.size()
223
+ # assert inputs.size() == inputs_f.size()
224
+
225
+ # Forward pass of the mini-batch
226
+ inputs, labels = Variable(inputs, requires_grad=False), Variable(labels)
227
+
228
+ with torch.no_grad():
229
+ if gpu_id >= 0:
230
+ inputs, labels = inputs.cuda(), labels.cuda()
231
+ # outputs = net.forward(inputs)
232
+ # pdb.set_trace()
233
+ outputs = net.forward(inputs, adj1_test.cuda(), adj3_test.cuda(), adj2_test.cuda())
234
+ outputs = (outputs[0] + flip(outputs[1], dim=-1)) / 2
235
+ outputs = outputs.unsqueeze(0)
236
+
237
+ if iii>0:
238
+ outputs = F.upsample(outputs,size=(h,w),mode='bilinear',align_corners=True)
239
+ outputs_final = outputs_final + outputs
240
+ else:
241
+ outputs_final = outputs.clone()
242
+ ################ plot pic
243
+ predictions = torch.max(outputs_final, 1)[1]
244
+ prob_predictions = torch.max(outputs_final,1)[0]
245
+ results = predictions.cpu().numpy()
246
+ prob_results = prob_predictions.cpu().numpy()
247
+ vis_res = decode_labels(results)
248
+
249
+ parsing_im = Image.fromarray(vis_res[0])
250
+ parsing_im.save(opts.output_path + 'pascal_output_vis/{}.png'.format(img_list[ii][:-1]))
251
+ cv2.imwrite(opts.output_path + 'pascal_output/{}.png'.format(img_list[ii][:-1]), results[0,:,:])
252
+ # np.save('../../cihp_prob_output/{}.npy'.format(img_list[ii][:-1]), prob_results[0, :, :])
253
+ # pred_list.append(predictions.cpu())
254
+ # label_list.append(labels.squeeze(1).cpu())
255
+ # loss = criterion(outputs, labels, batch_average=True)
256
+ # running_loss_ts += loss.item()
257
+
258
+ # total_iou += utils.get_iou(predictions, labels)
259
+ end_time = timeit.default_timer()
260
+ print('time use for '+str(ii) + ' is :' + str(end_time - start_time))
261
+
262
+ # Eval
263
+ pred_path = opts.output_path + 'pascal_output/'
264
+ eval_(pred_path=pred_path, gt_path=opts.gt_path,classes=opts.classes, txt_file=opts.txt_file)
265
+
266
+ if __name__ == '__main__':
267
+ opts = get_parser()
268
+ main(opts)
TryYours-Virtual-Try-On/Graphonomy-master/exp/test/eval_show_pascal2cihp.py ADDED
@@ -0,0 +1,268 @@
1
+ import socket
2
+ import timeit
3
+ import numpy as np
4
+ from PIL import Image
5
+ from datetime import datetime
6
+ import os
7
+ import sys
8
+ import glob
9
+ from collections import OrderedDict
10
+ sys.path.append('../../')
11
+ # PyTorch includes
12
+ import torch
13
+ import pdb
14
+ from torch.autograd import Variable
15
+ import torch.optim as optim
16
+ from torchvision import transforms
17
+ from torch.utils.data import DataLoader
18
+ from torchvision.utils import make_grid
19
+ import cv2
20
+
21
+ # Tensorboard include
22
+ # from tensorboardX import SummaryWriter
23
+
24
+ # Custom includes
25
+ from dataloaders import cihp
26
+ from utils import util
27
+ from networks import deeplab_xception_transfer, graph
28
+ from dataloaders import custom_transforms as tr
29
+
30
+ #
31
+ import argparse
32
+ import copy
33
+ import torch.nn.functional as F
34
+ from test_from_disk import eval_
35
+
36
+
37
+ gpu_id = 1
38
+
39
+ label_colours = [(0,0,0)
40
+ , (128,0,0), (255,0,0), (0,85,0), (170,0,51), (255,85,0), (0,0,85), (0,119,221), (85,85,0), (0,85,85), (85,51,0), (52,86,128), (0,128,0)
41
+ , (0,0,255), (51,170,221), (0,255,255), (85,255,170), (170,255,85), (255,255,0), (255,170,0)]
42
+
43
+
44
+ def flip(x, dim):
45
+ indices = [slice(None)] * x.dim()
46
+ indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
47
+ dtype=torch.long, device=x.device)
48
+ return x[tuple(indices)]
49
+
50
+ def flip_cihp(tail_list):
51
+ '''
52
+
53
+ :param tail_list: tail_list size is 1 x n_class x h x w
54
+ :return:
55
+ '''
56
+ # tail_list = tail_list[0]
57
+ tail_list_rev = [None] * 20
58
+ for xx in range(14):
59
+ tail_list_rev[xx] = tail_list[xx].unsqueeze(0)
60
+ tail_list_rev[14] = tail_list[15].unsqueeze(0)
61
+ tail_list_rev[15] = tail_list[14].unsqueeze(0)
62
+ tail_list_rev[16] = tail_list[17].unsqueeze(0)
63
+ tail_list_rev[17] = tail_list[16].unsqueeze(0)
64
+ tail_list_rev[18] = tail_list[19].unsqueeze(0)
65
+ tail_list_rev[19] = tail_list[18].unsqueeze(0)
66
+ return torch.cat(tail_list_rev,dim=0)
67
+
68
+ def decode_labels(mask, num_images=1, num_classes=20):
69
+ """Decode batch of segmentation masks.
70
+
71
+ Args:
72
+ mask: result of inference after taking argmax.
73
+ num_images: number of images to decode from the batch.
74
+ num_classes: number of classes to predict (including background).
75
+
76
+ Returns:
77
+ A batch with num_images RGB images of the same size as the input.
78
+ """
79
+ n, h, w = mask.shape
80
+ assert(n >= num_images), 'Batch size %d should be greater or equal than number of images to save %d.' % (n, num_images)
81
+ outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8)
82
+ for i in range(num_images):
83
+ img = Image.new('RGB', (len(mask[i, 0]), len(mask[i])))
84
+ pixels = img.load()
85
+ for j_, j in enumerate(mask[i, :, :]):
86
+ for k_, k in enumerate(j):
87
+ if k < num_classes:
88
+ pixels[k_,j_] = label_colours[k]
89
+ outputs[i] = np.array(img)
90
+ return outputs
91
+
92
+ def get_parser():
93
+ '''argparse begin'''
94
+ parser = argparse.ArgumentParser()
95
+ LookupChoices = type('', (argparse.Action,), dict(__call__=lambda a, p, n, v, o: setattr(n, a.dest, a.choices[v])))
96
+
97
+ parser.add_argument('--epochs', default=100, type=int)
98
+ parser.add_argument('--batch', default=16, type=int)
99
+ parser.add_argument('--lr', default=1e-7, type=float)
100
+ parser.add_argument('--numworker', default=12, type=int)
101
+ parser.add_argument('--step', default=30, type=int)
102
+ # parser.add_argument('--loadmodel',default=None,type=str)
103
+ parser.add_argument('--classes', default=7, type=int)
104
+ parser.add_argument('--testepoch', default=10, type=int)
105
+ parser.add_argument('--loadmodel', default='', type=str)
106
+ parser.add_argument('--txt_file', default='', type=str)
107
+ parser.add_argument('--hidden_layers', default=128, type=int)
108
+ parser.add_argument('--gpus', default=4, type=int)
109
+ parser.add_argument('--output_path', default='./results/', type=str)
110
+ parser.add_argument('--gt_path', default='./results/', type=str)
111
+ opts = parser.parse_args()
112
+ return opts
113
+
114
+
115
+ def main(opts):
116
+ adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float()
117
+ adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).cuda().transpose(2, 3)
118
+
119
+ adj1_ = Variable(torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float())
120
+ adj3_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7).cuda()
121
+
122
+ cihp_adj = graph.preprocess_adj(graph.cihp_graph)
123
+ adj3_ = Variable(torch.from_numpy(cihp_adj).float())
124
+ adj1_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20).cuda()
125
+
126
+ p = OrderedDict() # Parameters to include in report
127
+ p['trainBatch'] = opts.batch # Training batch size
128
+ p['nAveGrad'] = 1 # Average the gradient of several iterations
129
+ p['lr'] = opts.lr # Learning rate
130
+ p['lrFtr'] = 1e-5
131
+ p['lraspp'] = 1e-5
132
+ p['lrpro'] = 1e-5
133
+ p['lrdecoder'] = 1e-5
134
+ p['lrother'] = 1e-5
135
+ p['wd'] = 5e-4 # Weight decay
136
+ p['momentum'] = 0.9 # Momentum
137
+ p['epoch_size'] = 10 # How many epochs to change learning rate
138
+ p['num_workers'] = opts.numworker
139
+ backbone = 'xception' # Use xception or resnet as feature extractor,
140
+
141
+ with open(opts.txt_file, 'r') as f:
142
+ img_list = f.readlines()
143
+
144
+ max_id = 0
145
+ save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
146
+ exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]
147
+ runs = glob.glob(os.path.join(save_dir_root, 'run', 'run_*'))
148
+ for r in runs:
149
+ run_id = int(r.split('_')[-1])
150
+ if run_id >= max_id:
151
+ max_id = run_id + 1
152
+ # run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0
153
+
154
+ # Network definition
155
+ if backbone == 'xception':
156
+ net = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(n_classes=opts.classes, os=16,
157
+ hidden_layers=opts.hidden_layers, source_classes=7,
158
+ )
159
+ elif backbone == 'resnet':
160
+ # net = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True)
161
+ raise NotImplementedError
162
+ else:
163
+ raise NotImplementedError
164
+
165
+ if gpu_id >= 0:
166
+ net.cuda()
167
+
168
+ # net load weights
169
+ if not opts.loadmodel =='':
170
+ x = torch.load(opts.loadmodel)
171
+ net.load_source_model(x)
172
+ print('load model:' ,opts.loadmodel)
173
+ else:
174
+ print('no model load !!!!!!!!')
175
+
176
+ ## multi scale
177
+ scale_list=[1,0.5,0.75,1.25,1.5,1.75]
178
+ testloader_list = []
179
+ testloader_flip_list = []
180
+ for pv in scale_list:
181
+ composed_transforms_ts = transforms.Compose([
182
+ tr.Scale_(pv),
183
+ tr.Normalize_xception_tf(),
184
+ tr.ToTensor_()])
185
+
186
+ composed_transforms_ts_flip = transforms.Compose([
187
+ tr.Scale_(pv),
188
+ tr.HorizontalFlip(),
189
+ tr.Normalize_xception_tf(),
190
+ tr.ToTensor_()])
191
+
192
+ voc_val = cihp.VOCSegmentation(split='test', transform=composed_transforms_ts)
193
+ voc_val_f = cihp.VOCSegmentation(split='test', transform=composed_transforms_ts_flip)
194
+
195
+ testloader = DataLoader(voc_val, batch_size=1, shuffle=False, num_workers=p['num_workers'])
196
+ testloader_flip = DataLoader(voc_val_f, batch_size=1, shuffle=False, num_workers=p['num_workers'])
197
+
198
+ testloader_list.append(copy.deepcopy(testloader))
199
+ testloader_flip_list.append(copy.deepcopy(testloader_flip))
200
+
201
+ print("Eval Network")
202
+
203
+ if not os.path.exists(opts.output_path + 'cihp_output_vis/'):
204
+ os.makedirs(opts.output_path + 'cihp_output_vis/')
205
+ if not os.path.exists(opts.output_path + 'cihp_output/'):
206
+ os.makedirs(opts.output_path + 'cihp_output/')
207
+
208
+ start_time = timeit.default_timer()
209
+ # One testing epoch
210
+ total_iou = 0.0
211
+ net.eval()
212
+ for ii, large_sample_batched in enumerate(zip(*testloader_list, *testloader_flip_list)):
213
+ print(ii)
214
+ #1 0.5 0.75 1.25 1.5 1.75 ; flip:
215
+ sample1 = large_sample_batched[:6]
216
+ sample2 = large_sample_batched[6:]
217
+ for iii,sample_batched in enumerate(zip(sample1,sample2)):
218
+ inputs, labels = sample_batched[0]['image'], sample_batched[0]['label']
219
+ inputs_f, _ = sample_batched[1]['image'], sample_batched[1]['label']
220
+ inputs = torch.cat((inputs,inputs_f),dim=0)
221
+ if iii == 0:
222
+ _,_,h,w = inputs.size()
223
+ # assert inputs.size() == inputs_f.size()
224
+
225
+ # Forward pass of the mini-batch
226
+ inputs, labels = Variable(inputs, requires_grad=False), Variable(labels)
227
+
228
+ with torch.no_grad():
229
+ if gpu_id >= 0:
230
+ inputs, labels = inputs.cuda(), labels.cuda()
231
+ # outputs = net.forward(inputs)
232
+ # pdb.set_trace()
233
+ outputs = net.forward(inputs, adj1_test.cuda(), adj3_test.cuda(), adj2_test.cuda())
234
+ outputs = (outputs[0] + flip(flip_cihp(outputs[1]), dim=-1)) / 2
235
+ outputs = outputs.unsqueeze(0)
236
+
237
+ if iii>0:
238
+ outputs = F.upsample(outputs,size=(h,w),mode='bilinear',align_corners=True)
239
+ outputs_final = outputs_final + outputs
240
+ else:
241
+ outputs_final = outputs.clone()
242
+ ################ plot pic
243
+ predictions = torch.max(outputs_final, 1)[1]
244
+ prob_predictions = torch.max(outputs_final,1)[0]
245
+ results = predictions.cpu().numpy()
246
+ prob_results = prob_predictions.cpu().numpy()
247
+ vis_res = decode_labels(results)
248
+
249
+ parsing_im = Image.fromarray(vis_res[0])
250
+ parsing_im.save(opts.output_path + 'cihp_output_vis/{}.png'.format(img_list[ii][:-1]))
251
+ cv2.imwrite(opts.output_path + 'cihp_output/{}.png'.format(img_list[ii][:-1]), results[0,:,:])
252
+ # np.save('../../cihp_prob_output/{}.npy'.format(img_list[ii][:-1]), prob_results[0, :, :])
253
+ # pred_list.append(predictions.cpu())
254
+ # label_list.append(labels.squeeze(1).cpu())
255
+ # loss = criterion(outputs, labels, batch_average=True)
256
+ # running_loss_ts += loss.item()
257
+
258
+ # total_iou += utils.get_iou(predictions, labels)
259
+ end_time = timeit.default_timer()
260
+ print('time use for '+str(ii) + ' is :' + str(end_time - start_time))
261
+
262
+ # Eval
263
+ pred_path = opts.output_path + 'cihp_output/'
264
+ eval_(pred_path=pred_path, gt_path=opts.gt_path,classes=opts.classes, txt_file=opts.txt_file)
265
+
266
+ if __name__ == '__main__':
267
+ opts = get_parser()
268
+ main(opts)
TryYours-Virtual-Try-On/Graphonomy-master/exp/test/test_from_disk.py ADDED
@@ -0,0 +1,65 @@
1
+ import sys
2
+ sys.path.append('./')
3
+ # PyTorch includes
4
+ import torch
5
+ import numpy as np
6
+
7
+ from utils import test_human
8
+ from PIL import Image
9
+
10
+ #
11
+ import argparse
12
+
13
+ def get_parser():
14
+ '''argparse begin'''
15
+ parser = argparse.ArgumentParser()
16
+ LookupChoices = type('', (argparse.Action,), dict(__call__=lambda a, p, n, v, o: setattr(n, a.dest, a.choices[v])))
17
+
18
+ parser.add_argument('--epochs', default=100, type=int)
19
+ parser.add_argument('--batch', default=16, type=int)
20
+ parser.add_argument('--lr', default=1e-7, type=float)
21
+ parser.add_argument('--numworker',default=12,type=int)
22
+ parser.add_argument('--freezeBN', choices=dict(true=True, false=False), default=True, action=LookupChoices)
23
+ parser.add_argument('--step', default=30, type=int)
24
+ parser.add_argument('--txt_file',default=None,type=str)
25
+ parser.add_argument('--pred_path',default=None,type=str)
26
+ parser.add_argument('--gt_path',default=None,type=str)
27
+ parser.add_argument('--classes', default=7, type=int)
28
+ parser.add_argument('--testepoch', default=10, type=int)
29
+ opts = parser.parse_args()
30
+ return opts
31
+
32
+ def eval_(pred_path, gt_path, classes, txt_file):
33
+ pred_path = pred_path
34
+ gt_path = gt_path
35
+
36
+ with open(txt_file,) as f:
37
+ lines = f.readlines()
38
+ lines = [x.strip() for x in lines]
39
+
40
+ output_list = []
41
+ label_list = []
42
+ for i,file in enumerate(lines):
43
+ print(i)
44
+ file_name = file + '.png'
45
+ try:
46
+ predict_pic = np.array(Image.open(pred_path+file_name))
47
+ gt_pic = np.array(Image.open(gt_path+file_name))
48
+ output_list.append(torch.from_numpy(predict_pic))
49
+ label_list.append(torch.from_numpy(gt_pic))
50
+ except:
51
+ print(file_name,flush=True)
52
+ raise RuntimeError('no predict/gt image.')
53
+ # gt_pic = np.array(Image.open(gt_path + file_name))
54
+ # output_list.append(torch.from_numpy(gt_pic))
55
+ # label_list.append(torch.from_numpy(gt_pic))
56
+
57
+
58
+ miou = test_human.get_iou_from_list(output_list, label_list, n_cls=classes)
59
+
60
+ print('Validation:')
61
+ print('MIoU: %f\n' % miou)
62
+
63
+ if __name__ == '__main__':
64
+ opts = get_parser()
65
+ eval_(pred_path=opts.pred_path, gt_path=opts.gt_path, classes=opts.classes, txt_file=opts.txt_file)
TryYours-Virtual-Try-On/Graphonomy-master/exp/transfer/train_cihp_from_pascal.py ADDED
@@ -0,0 +1,331 @@
1
+ import socket
2
+ import timeit
3
+ from datetime import datetime
4
+ import os
5
+ import sys
6
+ import glob
7
+ import numpy as np
8
+ from collections import OrderedDict
9
+ sys.path.append('../../')
10
+ sys.path.append('../../networks/')
11
+ # PyTorch includes
12
+ import torch
13
+ from torch.autograd import Variable
14
+ import torch.optim as optim
15
+ from torchvision import transforms
16
+ from torch.utils.data import DataLoader
17
+ from torchvision.utils import make_grid
18
+
19
+
20
+ # Tensorboard include
21
+ from tensorboardX import SummaryWriter
22
+
23
+ # Custom includes
24
+ from dataloaders import cihp
25
+ from utils import util,get_iou_from_list
26
+ from networks import deeplab_xception_transfer, graph
27
+ from dataloaders import custom_transforms as tr
28
+
29
+ #
30
+ import argparse
31
+
32
+ gpu_id = 0
33
+
34
+ nEpochs = 100 # Number of epochs for training
35
+ resume_epoch = 0 # Default is 0, change if want to resume
36
+
37
+ def flip(x, dim):
38
+ indices = [slice(None)] * x.dim()
39
+ indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
40
+ dtype=torch.long, device=x.device)
41
+ return x[tuple(indices)]
42
+
43
+ def flip_cihp(tail_list):
44
+ '''
45
+
46
+ :param tail_list: tail_list size is 1 x n_class x h x w
47
+ :return:
48
+ '''
49
+ # tail_list = tail_list[0]
50
+ tail_list_rev = [None] * 20
51
+ for xx in range(14):
52
+ tail_list_rev[xx] = tail_list[xx].unsqueeze(0)
53
+ tail_list_rev[14] = tail_list[15].unsqueeze(0)
54
+ tail_list_rev[15] = tail_list[14].unsqueeze(0)
55
+ tail_list_rev[16] = tail_list[17].unsqueeze(0)
56
+ tail_list_rev[17] = tail_list[16].unsqueeze(0)
57
+ tail_list_rev[18] = tail_list[19].unsqueeze(0)
58
+ tail_list_rev[19] = tail_list[18].unsqueeze(0)
59
+ return torch.cat(tail_list_rev,dim=0)
60
+
61
+ def get_parser():
62
+ '''argparse begin'''
63
+ parser = argparse.ArgumentParser()
64
+ LookupChoices = type('', (argparse.Action,), dict(__call__=lambda a, p, n, v, o: setattr(n, a.dest, a.choices[v])))
65
+
66
+ parser.add_argument('--epochs', default=100, type=int)
67
+ parser.add_argument('--batch', default=16, type=int)
68
+ parser.add_argument('--lr', default=1e-7, type=float)
69
+ parser.add_argument('--numworker',default=12,type=int)
70
+ parser.add_argument('--freezeBN', choices=dict(true=True, false=False), default=True, action=LookupChoices)
71
+ parser.add_argument('--step', default=10, type=int)
72
+ parser.add_argument('--classes', default=20, type=int)
73
+ parser.add_argument('--testInterval', default=10, type=int)
74
+ parser.add_argument('--loadmodel',default='',type=str)
75
+ parser.add_argument('--pretrainedModel', default='', type=str)
76
+ parser.add_argument('--hidden_layers',default=128,type=int)
77
+ parser.add_argument('--gpus',default=4, type=int)
78
+
79
+ opts = parser.parse_args()
80
+ return opts
81
+
82
+ def get_graphs(opts):
83
+ adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float()
84
+ adj2 = adj2_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 7, 20).transpose(2, 3).cuda()
85
+ adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).transpose(2, 3)
86
+
87
+ adj1_ = Variable(torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float())
88
+ adj3 = adj1_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 7, 7).cuda()
89
+ adj3_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7)
90
+
91
+ # adj2 = torch.from_numpy(graph.cihp2pascal_adj).float()
92
+ # adj2 = adj2.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 7, 20)
93
+ cihp_adj = graph.preprocess_adj(graph.cihp_graph)
94
+ adj3_ = Variable(torch.from_numpy(cihp_adj).float())
95
+ adj1 = adj3_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 20, 20).cuda()
96
+ adj1_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20)
97
+ train_graph = [adj1, adj2, adj3]
98
+ test_graph = [adj1_test, adj2_test, adj3_test]
99
+ return train_graph, test_graph
100
+
101
+
102
+ def val_cihp(net_, testloader, testloader_flip, test_graph, epoch, writer, criterion, classes=20):
103
+ adj1_test, adj2_test, adj3_test = test_graph
104
+ num_img_ts = len(testloader)
105
+ net_.eval()
106
+ pred_list = []
107
+ label_list = []
108
+ running_loss_ts = 0.0
109
+ miou = 0
110
+ for ii, sample_batched in enumerate(zip(testloader, testloader_flip)):
111
+
112
+ inputs, labels = sample_batched[0]['image'], sample_batched[0]['label']
113
+ inputs_f, _ = sample_batched[1]['image'], sample_batched[1]['label']
114
+ inputs = torch.cat((inputs, inputs_f), dim=0)
115
+ # Forward pass of the mini-batch
116
+ inputs, labels = Variable(inputs, requires_grad=False), Variable(labels)
117
+ if gpu_id >= 0:
118
+ inputs, labels = inputs.cuda(), labels.cuda()
119
+
120
+ with torch.no_grad():
121
+ outputs = net_.forward(inputs, adj1_test.cuda(), adj3_test.cuda(), adj2_test.cuda())
122
+ # pdb.set_trace()
123
+ outputs = (outputs[0] + flip(flip_cihp(outputs[1]), dim=-1)) / 2
124
+ outputs = outputs.unsqueeze(0)
125
+ predictions = torch.max(outputs, 1)[1]
126
+ pred_list.append(predictions.cpu())
127
+ label_list.append(labels.squeeze(1).cpu())
128
+ loss = criterion(outputs, labels, batch_average=True)
129
+ running_loss_ts += loss.item()
130
+ # total_iou += utils.get_iou(predictions, labels)
131
+ # Print stuff
132
+ if ii % num_img_ts == num_img_ts - 1:
133
+ # if ii == 10:
134
+ miou = get_iou_from_list(pred_list, label_list, n_cls=classes)
135
+ running_loss_ts = running_loss_ts / num_img_ts
136
+
137
+ print('Validation:')
138
+ print('[Epoch: %d, numImages: %5d]' % (epoch, ii * 1 + inputs.data.shape[0]))
139
+ writer.add_scalar('data/test_loss_epoch', running_loss_ts, epoch)
140
+ writer.add_scalar('data/test_miour', miou, epoch)
141
+ print('Loss: %f' % running_loss_ts)
142
+ print('MIoU: %f\n' % miou)
143
+
144
+
145
+ def main(opts):
146
+ p = OrderedDict() # Parameters to include in report
147
+ p['trainBatch'] = opts.batch # Training batch size
148
+ testBatch = 1 # Testing batch size
149
+ useTest = True # See evolution of the test set when training
150
+ nTestInterval = opts.testInterval # Run on test set every nTestInterval epochs
151
+ snapshot = 1 # Store a model every snapshot epochs
152
+ p['nAveGrad'] = 1 # Average the gradient of several iterations
153
+ p['lr'] = opts.lr # Learning rate
154
+ p['lrFtr'] = 1e-5
155
+ p['lraspp'] = 1e-5
156
+ p['lrpro'] = 1e-5
157
+ p['lrdecoder'] = 1e-5
158
+ p['lrother'] = 1e-5
159
+ p['wd'] = 5e-4 # Weight decay
160
+ p['momentum'] = 0.9 # Momentum
161
+ p['epoch_size'] = opts.step # How many epochs to change learning rate
162
+ p['num_workers'] = opts.numworker
163
+ model_path = opts.pretrainedModel
164
+ backbone = 'xception' # Use xception or resnet as feature extractor
165
+ nEpochs = opts.epochs
166
+
167
+ max_id = 0
168
+ save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
169
+ exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]
170
+ runs = glob.glob(os.path.join(save_dir_root, 'run_cihp', 'run_*'))
171
+ for r in runs:
172
+ run_id = int(r.split('_')[-1])
173
+ if run_id >= max_id:
174
+ max_id = run_id + 1
175
+ save_dir = os.path.join(save_dir_root, 'run_cihp', 'run_' + str(max_id))
176
+
177
+ # Network definition
178
+ if backbone == 'xception':
179
+ net_ = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(n_classes=opts.classes, os=16,
180
+ hidden_layers=opts.hidden_layers, source_classes=7, )
181
+ elif backbone == 'resnet':
182
+ # net_ = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True)
183
+ raise NotImplementedError
184
+ else:
185
+ raise NotImplementedError
186
+
187
+ modelName = 'deeplabv3plus-' + backbone + '-voc'+datetime.now().strftime('%b%d_%H-%M-%S')
188
+ criterion = util.cross_entropy2d
189
+
190
+ if gpu_id >= 0:
191
+ # torch.cuda.set_device(device=gpu_id)
192
+ net_.cuda()
193
+
194
+ # net load weights
195
+ if not model_path == '':
196
+ x = torch.load(model_path)
197
+ net_.load_state_dict_new(x)
198
+ print('load pretrainedModel:', model_path)
199
+ else:
200
+ print('no pretrainedModel.')
201
+ if not opts.loadmodel =='':
202
+ x = torch.load(opts.loadmodel)
203
+ net_.load_source_model(x)
204
+ print('load model:' ,opts.loadmodel)
205
+ else:
206
+ print('no model load !!!!!!!!')
207
+
208
+ log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
209
+ writer = SummaryWriter(log_dir=log_dir)
210
+ writer.add_text('load model',opts.loadmodel,1)
211
+ writer.add_text('setting',sys.argv[0],1)
212
+
213
+ if opts.freezeBN:
214
+ net_.freeze_bn()
215
+
216
+ # Use the following optimizer
217
+ optimizer = optim.SGD(net_.parameters(), lr=p['lr'], momentum=p['momentum'], weight_decay=p['wd'])
218
+
219
+ composed_transforms_tr = transforms.Compose([
220
+ tr.RandomSized_new(512),
221
+ tr.Normalize_xception_tf(),
222
+ tr.ToTensor_()])
223
+
224
+ composed_transforms_ts = transforms.Compose([
225
+ tr.Normalize_xception_tf(),
226
+ tr.ToTensor_()])
227
+
228
+ composed_transforms_ts_flip = transforms.Compose([
229
+ tr.HorizontalFlip(),
230
+ tr.Normalize_xception_tf(),
231
+ tr.ToTensor_()])
232
+
233
+ voc_train = cihp.VOCSegmentation(split='train', transform=composed_transforms_tr, flip=True)
234
+ voc_val = cihp.VOCSegmentation(split='val', transform=composed_transforms_ts)
235
+ voc_val_flip = cihp.VOCSegmentation(split='val', transform=composed_transforms_ts_flip)
236
+
237
+ trainloader = DataLoader(voc_train, batch_size=p['trainBatch'], shuffle=True, num_workers=p['num_workers'],drop_last=True)
238
+ testloader = DataLoader(voc_val, batch_size=testBatch, shuffle=False, num_workers=p['num_workers'])
239
+ testloader_flip = DataLoader(voc_val_flip, batch_size=testBatch, shuffle=False, num_workers=p['num_workers'])
240
+
241
+ num_img_tr = len(trainloader)
242
+ num_img_ts = len(testloader)
243
+ running_loss_tr = 0.0
244
+ running_loss_ts = 0.0
245
+ aveGrad = 0
246
+ global_step = 0
247
+ print("Training Network")
248
+
249
+ net = torch.nn.DataParallel(net_)
250
+ train_graph, test_graph = get_graphs(opts)
251
+ adj1, adj2, adj3 = train_graph
252
+
253
+
254
+ # Main Training and Testing Loop
255
+ for epoch in range(resume_epoch, nEpochs):
256
+ start_time = timeit.default_timer()
257
+
258
+ if epoch % p['epoch_size'] == p['epoch_size'] - 1:
259
+ lr_ = util.lr_poly(p['lr'], epoch, nEpochs, 0.9)
260
+ optimizer = optim.SGD(net_.parameters(), lr=lr_, momentum=p['momentum'], weight_decay=p['wd'])
261
+ writer.add_scalar('data/lr_', lr_, epoch)
262
+ print('(poly lr policy) learning rate: ', lr_)
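+ # Sketch of the assumed behaviour of util.lr_poly (its definition is not part of this file):
+ #   lr_poly(base_lr, epoch, max_epoch, power) = base_lr * (1 - epoch / max_epoch) ** power
+ # e.g. base_lr=1e-7, epoch=50, nEpochs=100, power=0.9 gives roughly 5.4e-8.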
263
+
264
+ net.train()
265
+ for ii, sample_batched in enumerate(trainloader):
266
+
267
+ inputs, labels = sample_batched['image'], sample_batched['label']
268
+ # Forward-Backward of the mini-batch
269
+ inputs, labels = Variable(inputs, requires_grad=True), Variable(labels)
270
+ global_step += inputs.data.shape[0]
271
+
272
+ if gpu_id >= 0:
273
+ inputs, labels = inputs.cuda(), labels.cuda()
274
+
275
+ outputs = net.forward(inputs, adj1, adj3, adj2)
276
+
277
+ loss = criterion(outputs, labels, batch_average=True)
278
+ running_loss_tr += loss.item()
279
+
280
+ # Print stuff
281
+ if ii % num_img_tr == (num_img_tr - 1):
282
+ running_loss_tr = running_loss_tr / num_img_tr
283
+ writer.add_scalar('data/total_loss_epoch', running_loss_tr, epoch)
284
+ print('[Epoch: %d, numImages: %5d]' % (epoch, ii * p['trainBatch'] + inputs.data.shape[0]))
285
+ print('Loss: %f' % running_loss_tr)
286
+ running_loss_tr = 0
287
+ stop_time = timeit.default_timer()
288
+ print("Execution time: " + str(stop_time - start_time) + "\n")
289
+
290
+ # Backward the averaged gradient
291
+ loss /= p['nAveGrad']
292
+ loss.backward()
293
+ aveGrad += 1
294
+
295
+ # Update the weights once in p['nAveGrad'] forward passes
296
+ if aveGrad % p['nAveGrad'] == 0:
297
+ writer.add_scalar('data/total_loss_iter', loss.item(), ii + num_img_tr * epoch)
298
+ optimizer.step()
299
+ optimizer.zero_grad()
300
+ aveGrad = 0
301
+
302
+ # Show 10 * 3 images results each epoch
303
+ if ii % (num_img_tr // 10) == 0:
304
+ grid_image = make_grid(inputs[:3].clone().cpu().data, 3, normalize=True)
305
+ writer.add_image('Image', grid_image, global_step)
306
+ grid_image = make_grid(util.decode_seg_map_sequence(torch.max(outputs[:3], 1)[1].detach().cpu().numpy()), 3, normalize=False,
307
+ range=(0, 255))
308
+ writer.add_image('Predicted label', grid_image, global_step)
309
+ grid_image = make_grid(util.decode_seg_map_sequence(torch.squeeze(labels[:3], 1).detach().cpu().numpy()), 3, normalize=False, range=(0, 255))
310
+ writer.add_image('Groundtruth label', grid_image, global_step)
311
+ print('loss is ', loss.cpu().item(), flush=True)
312
+
313
+ # Save the model
314
+ if (epoch % snapshot) == snapshot - 1:
315
+ torch.save(net_.state_dict(), os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth'))
316
+ print("Save model at {}\n".format(os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth')))
317
+
318
+ torch.cuda.empty_cache()
319
+
320
+ # One testing epoch
321
+ if useTest and epoch % nTestInterval == (nTestInterval - 1):
322
+ val_cihp(net_,testloader=testloader, testloader_flip=testloader_flip, test_graph=test_graph,
323
+ epoch=epoch,writer=writer,criterion=criterion, classes=opts.classes)
324
+ torch.cuda.empty_cache()
325
+
326
+
327
+
328
+
329
+ if __name__ == '__main__':
330
+ opts = get_parser()
331
+ main(opts)
TryYours-Virtual-Try-On/Graphonomy-master/exp/universal/pascal_atr_cihp_uni.py ADDED
@@ -0,0 +1,493 @@
1
+ import socket
2
+ import timeit
3
+ from datetime import datetime
4
+ import os
5
+ import sys
6
+ import glob
7
+ import numpy as np
8
+ from collections import OrderedDict
9
+ sys.path.append('./')
10
+ sys.path.append('./networks/')
11
+ # PyTorch includes
12
+ import torch
13
+ from torch.autograd import Variable
14
+ import torch.optim as optim
15
+ from torchvision import transforms
16
+ from torch.utils.data import DataLoader
17
+ from torchvision.utils import make_grid
18
+ import random
19
+
20
+ # Tensorboard include
21
+ from tensorboardX import SummaryWriter
22
+
23
+ # Custom includes
24
+ from dataloaders import pascal, cihp_pascal_atr
25
+ from utils import get_iou_from_list
26
+ from utils import util as ut
27
+ from networks import deeplab_xception_universal, graph
28
+ from dataloaders import custom_transforms as tr
29
+ from utils import sampler as sam
30
+ #
31
+ import argparse
32
+
33
+ '''
34
+ source is cihp
35
+ target is pascal
36
+ '''
37
+
38
+ gpu_id = 1
39
+ # print('Using GPU: {} '.format(gpu_id))
40
+
41
+ # nEpochs = 100 # Number of epochs for training
42
+ resume_epoch = 0 # Default is 0, change if you want to resume
43
+
44
+ def flip(x, dim):
45
+ indices = [slice(None)] * x.dim()
46
+ indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
47
+ dtype=torch.long, device=x.device)
48
+ return x[tuple(indices)]
49
+
50
+ def flip_cihp(tail_list):
51
+ '''
52
+
53
+ :param tail_list: tail_list size is 1 x n_class x h x w
54
+ :return:
55
+ '''
56
+ # tail_list = tail_list[0]
57
+ tail_list_rev = [None] * 20
58
+ for xx in range(14):
59
+ tail_list_rev[xx] = tail_list[xx].unsqueeze(0)
60
+ tail_list_rev[14] = tail_list[15].unsqueeze(0)
61
+ tail_list_rev[15] = tail_list[14].unsqueeze(0)
62
+ tail_list_rev[16] = tail_list[17].unsqueeze(0)
63
+ tail_list_rev[17] = tail_list[16].unsqueeze(0)
64
+ tail_list_rev[18] = tail_list[19].unsqueeze(0)
65
+ tail_list_rev[19] = tail_list[18].unsqueeze(0)
66
+ return torch.cat(tail_list_rev,dim=0)
67
+
68
+ def get_parser():
69
+ '''argparse begin'''
70
+ parser = argparse.ArgumentParser()
71
+ LookupChoices = type('', (argparse.Action,), dict(__call__=lambda a, p, n, v, o: setattr(n, a.dest, a.choices[v])))
72
+
73
+ parser.add_argument('--epochs', default=100, type=int)
74
+ parser.add_argument('--batch', default=16, type=int)
75
+ parser.add_argument('--lr', default=1e-7, type=float)
76
+ parser.add_argument('--numworker',default=12,type=int)
77
+ # parser.add_argument('--freezeBN', choices=dict(true=True, false=False), default=True, action=LookupChoices)
78
+ parser.add_argument('--step', default=10, type=int)
79
+ # parser.add_argument('--loadmodel',default=None,type=str)
80
+ parser.add_argument('--classes', default=7, type=int)
81
+ parser.add_argument('--testepoch', default=10, type=int)
82
+ parser.add_argument('--loadmodel',default='',type=str)
83
+ parser.add_argument('--pretrainedModel', default='', type=str)
84
+ parser.add_argument('--hidden_layers',default=128,type=int)
85
+ parser.add_argument('--gpus',default=4, type=int)
86
+ parser.add_argument('--testInterval', default=5, type=int)
87
+ opts = parser.parse_args()
88
+ return opts
89
+
90
+ def get_graphs(opts):
91
+ '''source is pascal; target is cihp; middle is atr'''
92
+ # target 1
93
+ cihp_adj = graph.preprocess_adj(graph.cihp_graph)
94
+ adj1_ = Variable(torch.from_numpy(cihp_adj).float())
95
+ adj1 = adj1_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 20, 20).cuda()
96
+ adj1_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20)
97
+ #source 2
98
+ adj2_ = Variable(torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float())
99
+ adj2 = adj2_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 7, 7).cuda()
100
+ adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7)
101
+ # s to target 3
102
+ adj3_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float()
103
+ adj3 = adj3_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 7, 20).transpose(2,3).cuda()
104
+ adj3_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).transpose(2,3)
105
+ # middle 4
106
+ atr_adj = graph.preprocess_adj(graph.atr_graph)
107
+ adj4_ = Variable(torch.from_numpy(atr_adj).float())
108
+ adj4 = adj4_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 18, 18).cuda()
109
+ adj4_test = adj4_.unsqueeze(0).unsqueeze(0).expand(1, 1, 18, 18)
110
+ # source to middle 5
111
+ adj5_ = torch.from_numpy(graph.pascal2atr_nlp_adj).float()
112
+ adj5 = adj5_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 7, 18).cuda()
113
+ adj5_test = adj5_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 18)
114
+ # target to middle 6
115
+ adj6_ = torch.from_numpy(graph.cihp2atr_nlp_adj).float()
116
+ adj6 = adj6_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 20, 18).cuda()
117
+ adj6_test = adj6_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 18)
118
+ train_graph = [adj1, adj2, adj3, adj4, adj5, adj6]
119
+ test_graph = [adj1_test, adj2_test, adj3_test, adj4_test, adj5_test, adj6_test]
120
+ return train_graph, test_graph
121
+
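+ # Summary of the six adjacency matrices built above: adj1 (20x20, CIHP target graph),
+ # adj2 (7x7, PASCAL source graph), adj3 (20x7, CIHP-to-PASCAL transfer), adj4 (18x18,
+ # ATR middle graph), adj5 (7x18, PASCAL-to-ATR transfer), adj6 (20x18, CIHP-to-ATR transfer);
+ # each is expanded to opts.gpus copies for training and a single copy for testing.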
122
+
123
+ def main(opts):
124
+ # Set parameters
125
+ p = OrderedDict() # Parameters to include in report
126
+ p['trainBatch'] = opts.batch # Training batch size
127
+ testBatch = 1 # Testing batch size
128
+ useTest = True # See evolution of the test set when training
129
+ nTestInterval = opts.testInterval # Run on test set every nTestInterval epochs
130
+ snapshot = 1 # Store a model every snapshot epochs
131
+ p['nAveGrad'] = 1 # Average the gradient of several iterations
132
+ p['lr'] = opts.lr # Learning rate
133
+ p['wd'] = 5e-4 # Weight decay
134
+ p['momentum'] = 0.9 # Momentum
135
+ p['epoch_size'] = opts.step # How many epochs to change learning rate
136
+ p['num_workers'] = opts.numworker
137
+ model_path = opts.pretrainedModel
138
+ backbone = 'xception' # Use xception or resnet as feature extractor
139
+ nEpochs = opts.epochs
140
+
141
+ max_id = 0
142
+ save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
143
+ exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]
144
+ runs = glob.glob(os.path.join(save_dir_root, 'run', 'run_*'))
145
+ for r in runs:
146
+ run_id = int(r.split('_')[-1])
147
+ if run_id >= max_id:
148
+ max_id = run_id + 1
149
+ # run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0
150
+ save_dir = os.path.join(save_dir_root, 'run', 'run_' + str(max_id))
151
+
152
+ # Network definition
153
+ if backbone == 'xception':
154
+ net_ = deeplab_xception_universal.deeplab_xception_end2end_3d(n_classes=20, os=16,
155
+ hidden_layers=opts.hidden_layers,
156
+ source_classes=7,
157
+ middle_classes=18, )
158
+ elif backbone == 'resnet':
159
+ # net_ = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True)
160
+ raise NotImplementedError
161
+ else:
162
+ raise NotImplementedError
163
+
164
+ modelName = 'deeplabv3plus-' + backbone + '-voc'+datetime.now().strftime('%b%d_%H-%M-%S')
165
+ criterion = ut.cross_entropy2d
166
+
167
+ if gpu_id >= 0:
168
+ # torch.cuda.set_device(device=gpu_id)
169
+ net_.cuda()
170
+
171
+ # net load weights
172
+ if not model_path == '':
173
+ x = torch.load(model_path)
174
+ net_.load_state_dict_new(x)
175
+ print('load pretrainedModel.')
176
+ else:
177
+ print('no pretrainedModel.')
178
+
179
+ if not opts.loadmodel =='':
180
+ x = torch.load(opts.loadmodel)
181
+ net_.load_source_model(x)
182
+ print('load model:' ,opts.loadmodel)
183
+ else:
184
+ print('no trained model load !!!!!!!!')
185
+
186
+ log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
187
+ writer = SummaryWriter(log_dir=log_dir)
188
+ writer.add_text('load model',opts.loadmodel,1)
189
+ writer.add_text('setting',sys.argv[0],1)
190
+
191
+ # Use the following optimizer
192
+ optimizer = optim.SGD(net_.parameters(), lr=p['lr'], momentum=p['momentum'], weight_decay=p['wd'])
193
+
194
+ composed_transforms_tr = transforms.Compose([
195
+ tr.RandomSized_new(512),
196
+ tr.Normalize_xception_tf(),
197
+ tr.ToTensor_()])
198
+
199
+ composed_transforms_ts = transforms.Compose([
200
+ tr.Normalize_xception_tf(),
201
+ tr.ToTensor_()])
202
+
203
+ composed_transforms_ts_flip = transforms.Compose([
204
+ tr.HorizontalFlip(),
205
+ tr.Normalize_xception_tf(),
206
+ tr.ToTensor_()])
207
+
208
+ all_train = cihp_pascal_atr.VOCSegmentation(split='train', transform=composed_transforms_tr, flip=True)
209
+ voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts)
210
+ voc_val_flip = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts_flip)
211
+
212
+ num_cihp,num_pascal,num_atr = all_train.get_class_num()
213
+ ss = sam.Sampler_uni(num_cihp,num_pascal,num_atr,opts.batch)
214
+ # balance datasets based pascal
215
+ ss_balanced = sam.Sampler_uni(num_cihp,num_pascal,num_atr,opts.batch, balance_id=1)
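+ # Assumption (utils.sampler.Sampler_uni is not shown here): it is expected to draw indices over
+ # the concatenated CIHP/PASCAL/ATR dataset so that each mini-batch comes from one dataset only,
+ # and with balance_id=1 to re-balance sampling around PASCAL (per the comment above).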
216
+
217
+ trainloader = DataLoader(all_train, batch_size=p['trainBatch'], shuffle=False, num_workers=p['num_workers'],
218
+ sampler=ss, drop_last=True)
219
+ trainloader_balanced = DataLoader(all_train, batch_size=p['trainBatch'], shuffle=False, num_workers=p['num_workers'],
220
+ sampler=ss_balanced, drop_last=True)
221
+ testloader = DataLoader(voc_val, batch_size=testBatch, shuffle=False, num_workers=p['num_workers'])
222
+ testloader_flip = DataLoader(voc_val_flip, batch_size=testBatch, shuffle=False, num_workers=p['num_workers'])
223
+
224
+ num_img_tr = len(trainloader)
225
+ num_img_balanced = len(trainloader_balanced)
226
+ num_img_ts = len(testloader)
227
+ running_loss_tr = 0.0
228
+ running_loss_tr_atr = 0.0
229
+ running_loss_ts = 0.0
230
+ aveGrad = 0
231
+ global_step = 0
232
+ print("Training Network")
233
+ net = torch.nn.DataParallel(net_)
234
+
235
+ id_list = torch.LongTensor(range(opts.batch))
236
+ pascal_iter = int(num_img_tr//opts.batch)
237
+
238
+ # Get graphs
239
+ train_graph, test_graph = get_graphs(opts)
240
+ adj1, adj2, adj3, adj4, adj5, adj6 = train_graph
241
+ adj1_test, adj2_test, adj3_test, adj4_test, adj5_test, adj6_test = test_graph
242
+
243
+ # Main Training and Testing Loop
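+ # Schedule: the first nEpochs epochs iterate over the full mixed loader; the remaining
+ # 0.5 * nEpochs epochs fine-tune on the balanced loader (the epoch < nEpochs branch below),
+ # with the poly learning-rate schedule restarted from epoch nEpochs onwards.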
244
+ for epoch in range(resume_epoch, int(1.5*nEpochs)):
245
+ start_time = timeit.default_timer()
246
+
247
+ if epoch % p['epoch_size'] == p['epoch_size'] - 1 and epoch<nEpochs:
248
+ lr_ = ut.lr_poly(p['lr'], epoch, nEpochs, 0.9)
249
+ optimizer = optim.SGD(net_.parameters(), lr=lr_, momentum=p['momentum'], weight_decay=p['wd'])
250
+ print('(poly lr policy) learning rate: ', lr_)
251
+ writer.add_scalar('data/lr_',lr_,epoch)
252
+ elif epoch % p['epoch_size'] == p['epoch_size'] - 1 and epoch > nEpochs:
253
+ lr_ = ut.lr_poly(p['lr'], epoch-nEpochs, int(0.5*nEpochs), 0.9)
254
+ optimizer = optim.SGD(net_.parameters(), lr=lr_, momentum=p['momentum'], weight_decay=p['wd'])
255
+ print('(poly lr policy) learning rate: ', lr_)
256
+ writer.add_scalar('data/lr_', lr_, epoch)
257
+
258
+ net_.train()
259
+ if epoch < nEpochs:
260
+ for ii, sample_batched in enumerate(trainloader):
261
+ inputs, labels = sample_batched['image'], sample_batched['label']
262
+ dataset_lbl = sample_batched['pascal'][0].item()
263
+ # Forward-Backward of the mini-batch
264
+ inputs, labels = Variable(inputs, requires_grad=True), Variable(labels)
265
+ global_step += 1
266
+
267
+ if gpu_id >= 0:
268
+ inputs, labels = inputs.cuda(), labels.cuda()
269
+
270
+ if dataset_lbl == 0:
271
+ # 0 is cihp -- target
272
+ _, outputs,_ = net.forward(None, input_target=inputs, input_middle=None, adj1_target=adj1, adj2_source=adj2,
273
+ adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2,3), adj4_middle=adj4,adj5_transfer_s2m=adj5.transpose(2, 3),
274
+ adj6_transfer_t2m=adj6.transpose(2, 3),adj5_transfer_m2s=adj5,adj6_transfer_m2t=adj6,)
275
+ elif dataset_lbl == 1:
276
+ # pascal is source
277
+ outputs, _, _ = net.forward(inputs, input_target=None, input_middle=None, adj1_target=adj1,
278
+ adj2_source=adj2,
279
+ adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2, 3),
280
+ adj4_middle=adj4, adj5_transfer_s2m=adj5.transpose(2, 3),
281
+ adj6_transfer_t2m=adj6.transpose(2, 3), adj5_transfer_m2s=adj5,
282
+ adj6_transfer_m2t=adj6, )
283
+ else:
284
+ # atr
285
+ _, _, outputs = net.forward(None, input_target=None, input_middle=inputs, adj1_target=adj1,
286
+ adj2_source=adj2,
287
+ adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2, 3),
288
+ adj4_middle=adj4, adj5_transfer_s2m=adj5.transpose(2, 3),
289
+ adj6_transfer_t2m=adj6.transpose(2, 3), adj5_transfer_m2s=adj5,
290
+ adj6_transfer_m2t=adj6, )
291
+ # print(sample_batched['pascal'])
292
+ # print(outputs.size(),)
293
+ # print(labels)
294
+ loss = criterion(outputs, labels, batch_average=True)
295
+ running_loss_tr += loss.item()
296
+
297
+ # Print stuff
298
+ if ii % num_img_tr == (num_img_tr - 1):
299
+ running_loss_tr = running_loss_tr / num_img_tr
300
+ writer.add_scalar('data/total_loss_epoch', running_loss_tr, epoch)
301
+ print('[Epoch: %d, numImages: %5d]' % (epoch, ii * p['trainBatch'] + inputs.data.shape[0]))
302
+ print('Loss: %f' % running_loss_tr)
303
+ running_loss_tr = 0
304
+ stop_time = timeit.default_timer()
305
+ print("Execution time: " + str(stop_time - start_time) + "\n")
306
+
307
+ # Backward the averaged gradient
308
+ loss /= p['nAveGrad']
309
+ loss.backward()
310
+ aveGrad += 1
311
+
312
+ # Update the weights once in p['nAveGrad'] forward passes
313
+ if aveGrad % p['nAveGrad'] == 0:
314
+ writer.add_scalar('data/total_loss_iter', loss.item(), global_step)
315
+ if dataset_lbl == 0:
316
+ writer.add_scalar('data/total_loss_iter_cihp', loss.item(), global_step)
317
+ if dataset_lbl == 1:
318
+ writer.add_scalar('data/total_loss_iter_pascal', loss.item(), global_step)
319
+ if dataset_lbl == 2:
320
+ writer.add_scalar('data/total_loss_iter_atr', loss.item(), global_step)
321
+ optimizer.step()
322
+ optimizer.zero_grad()
323
+ # optimizer_gcn.step()
324
+ # optimizer_gcn.zero_grad()
325
+ aveGrad = 0
326
+
327
+ # Show 10 * 3 images results each epoch
328
+ if ii % (num_img_tr // 10) == 0:
329
+ grid_image = make_grid(inputs[:3].clone().cpu().data, 3, normalize=True)
330
+ writer.add_image('Image', grid_image, global_step)
331
+ grid_image = make_grid(ut.decode_seg_map_sequence(torch.max(outputs[:3], 1)[1].detach().cpu().numpy()), 3, normalize=False,
332
+ range=(0, 255))
333
+ writer.add_image('Predicted label', grid_image, global_step)
334
+ grid_image = make_grid(ut.decode_seg_map_sequence(torch.squeeze(labels[:3], 1).detach().cpu().numpy()), 3, normalize=False, range=(0, 255))
335
+ writer.add_image('Groundtruth label', grid_image, global_step)
336
+
337
+ print('loss is ',loss.cpu().item(),flush=True)
338
+ else:
339
+ # Balanced the number of datasets
340
+ for ii, sample_batched in enumerate(trainloader_balanced):
341
+ inputs, labels = sample_batched['image'], sample_batched['label']
342
+ dataset_lbl = sample_batched['pascal'][0].item()
343
+ # Forward-Backward of the mini-batch
344
+ inputs, labels = Variable(inputs, requires_grad=True), Variable(labels)
345
+ global_step += 1
346
+
347
+ if gpu_id >= 0:
348
+ inputs, labels = inputs.cuda(), labels.cuda()
349
+
350
+ if dataset_lbl == 0:
351
+ # 0 is cihp -- target
352
+ _, outputs, _ = net.forward(None, input_target=inputs, input_middle=None, adj1_target=adj1,
353
+ adj2_source=adj2,
354
+ adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2, 3),
355
+ adj4_middle=adj4, adj5_transfer_s2m=adj5.transpose(2, 3),
356
+ adj6_transfer_t2m=adj6.transpose(2, 3), adj5_transfer_m2s=adj5,
357
+ adj6_transfer_m2t=adj6, )
358
+ elif dataset_lbl == 1:
359
+ # pascal is source
360
+ outputs, _, _ = net.forward(inputs, input_target=None, input_middle=None, adj1_target=adj1,
361
+ adj2_source=adj2,
362
+ adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2, 3),
363
+ adj4_middle=adj4, adj5_transfer_s2m=adj5.transpose(2, 3),
364
+ adj6_transfer_t2m=adj6.transpose(2, 3), adj5_transfer_m2s=adj5,
365
+ adj6_transfer_m2t=adj6, )
366
+ else:
367
+ # atr
368
+ _, _, outputs = net.forward(None, input_target=None, input_middle=inputs, adj1_target=adj1,
369
+ adj2_source=adj2,
370
+ adj3_transfer_s2t=adj3, adj3_transfer_t2s=adj3.transpose(2, 3),
371
+ adj4_middle=adj4, adj5_transfer_s2m=adj5.transpose(2, 3),
372
+ adj6_transfer_t2m=adj6.transpose(2, 3), adj5_transfer_m2s=adj5,
373
+ adj6_transfer_m2t=adj6, )
374
+ # print(sample_batched['pascal'])
375
+ # print(outputs.size(),)
376
+ # print(labels)
377
+ loss = criterion(outputs, labels, batch_average=True)
378
+ running_loss_tr += loss.item()
379
+
380
+ # Print stuff
381
+ if ii % num_img_balanced == (num_img_balanced - 1):
382
+ running_loss_tr = running_loss_tr / num_img_balanced
383
+ writer.add_scalar('data/total_loss_epoch', running_loss_tr, epoch)
384
+ print('[Epoch: %d, numImages: %5d]' % (epoch, ii * p['trainBatch'] + inputs.data.shape[0]))
385
+ print('Loss: %f' % running_loss_tr)
386
+ running_loss_tr = 0
387
+ stop_time = timeit.default_timer()
388
+ print("Execution time: " + str(stop_time - start_time) + "\n")
389
+
390
+ # Backward the averaged gradient
391
+ loss /= p['nAveGrad']
392
+ loss.backward()
393
+ aveGrad += 1
394
+
395
+ # Update the weights once in p['nAveGrad'] forward passes
396
+ if aveGrad % p['nAveGrad'] == 0:
397
+ writer.add_scalar('data/total_loss_iter', loss.item(), global_step)
398
+ if dataset_lbl == 0:
399
+ writer.add_scalar('data/total_loss_iter_cihp', loss.item(), global_step)
400
+ if dataset_lbl == 1:
401
+ writer.add_scalar('data/total_loss_iter_pascal', loss.item(), global_step)
402
+ if dataset_lbl == 2:
403
+ writer.add_scalar('data/total_loss_iter_atr', loss.item(), global_step)
404
+ optimizer.step()
405
+ optimizer.zero_grad()
406
+
407
+ aveGrad = 0
408
+
409
+ # Show 10 * 3 images results each epoch
410
+ if ii % (num_img_balanced // 10) == 0:
411
+ grid_image = make_grid(inputs[:3].clone().cpu().data, 3, normalize=True)
412
+ writer.add_image('Image', grid_image, global_step)
413
+ grid_image = make_grid(
414
+ ut.decode_seg_map_sequence(torch.max(outputs[:3], 1)[1].detach().cpu().numpy()), 3,
415
+ normalize=False,
416
+ range=(0, 255))
417
+ writer.add_image('Predicted label', grid_image, global_step)
418
+ grid_image = make_grid(
419
+ ut.decode_seg_map_sequence(torch.squeeze(labels[:3], 1).detach().cpu().numpy()), 3,
420
+ normalize=False, range=(0, 255))
421
+ writer.add_image('Groundtruth label', grid_image, global_step)
422
+
423
+ print('loss is ', loss.cpu().item(), flush=True)
424
+
425
+ # Save the model
426
+ if (epoch % snapshot) == snapshot - 1:
427
+ torch.save(net_.state_dict(), os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth'))
428
+ print("Save model at {}\n".format(os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth')))
429
+
430
+ # One testing epoch
431
+ if useTest and epoch % nTestInterval == (nTestInterval - 1):
432
+ val_pascal(net_=net_, testloader=testloader, testloader_flip=testloader_flip, test_graph=test_graph,
433
+ criterion=criterion, epoch=epoch, writer=writer)
434
+
435
+
436
+ def val_pascal(net_, testloader, testloader_flip, test_graph, criterion, epoch, writer, classes=7):
437
+ running_loss_ts = 0.0
438
+ miou = 0
439
+ adj1_test, adj2_test, adj3_test, adj4_test, adj5_test, adj6_test = test_graph
440
+ num_img_ts = len(testloader)
441
+ net_.eval()
442
+ pred_list = []
443
+ label_list = []
444
+ for ii, sample_batched in enumerate(zip(testloader, testloader_flip)):
445
+ # print(ii)
446
+ inputs, labels = sample_batched[0]['image'], sample_batched[0]['label']
447
+ inputs_f, _ = sample_batched[1]['image'], sample_batched[1]['label']
448
+ inputs = torch.cat((inputs, inputs_f), dim=0)
449
+ # Forward pass of the mini-batch
450
+ inputs, labels = Variable(inputs, requires_grad=False), Variable(labels)
451
+
452
+ with torch.no_grad():
453
+ if gpu_id >= 0:
454
+ inputs, labels = inputs.cuda(), labels.cuda()
455
+ outputs, _, _ = net_.forward(inputs, input_target=None, input_middle=None,
456
+ adj1_target=adj1_test.cuda(),
457
+ adj2_source=adj2_test.cuda(),
458
+ adj3_transfer_s2t=adj3_test.cuda(),
459
+ adj3_transfer_t2s=adj3_test.transpose(2, 3).cuda(),
460
+ adj4_middle=adj4_test.cuda(),
461
+ adj5_transfer_s2m=adj5_test.transpose(2, 3).cuda(),
462
+ adj6_transfer_t2m=adj6_test.transpose(2, 3).cuda(),
463
+ adj5_transfer_m2s=adj5_test.cuda(),
464
+ adj6_transfer_m2t=adj6_test.cuda(), )
465
+ # pdb.set_trace()
466
+ outputs = (outputs[0] + flip(outputs[1], dim=-1)) / 2
467
+ outputs = outputs.unsqueeze(0)
468
+ predictions = torch.max(outputs, 1)[1]
469
+ pred_list.append(predictions.cpu())
470
+ label_list.append(labels.squeeze(1).cpu())
471
+ loss = criterion(outputs, labels, batch_average=True)
472
+ running_loss_ts += loss.item()
473
+
474
+ # total_iou += utils.get_iou(predictions, labels)
475
+
476
+ # Print stuff
477
+ if ii % num_img_ts == num_img_ts - 1:
478
+ # if ii == 10:
479
+ miou = get_iou_from_list(pred_list, label_list, n_cls=classes)
480
+ running_loss_ts = running_loss_ts / num_img_ts
481
+
482
+ print('Validation:')
483
+ print('[Epoch: %d, numImages: %5d]' % (epoch, ii * 1 + inputs.data.shape[0]))
484
+ writer.add_scalar('data/test_loss_epoch', running_loss_ts, epoch)
485
+ writer.add_scalar('data/test_miour', miou, epoch)
486
+ print('Loss: %f' % running_loss_ts)
487
+ print('MIoU: %f\n' % miou)
488
+ # return miou
489
+
490
+
491
+ if __name__ == '__main__':
492
+ opts = get_parser()
493
+ main(opts)
TryYours-Virtual-Try-On/Graphonomy-master/inference.sh ADDED
@@ -0,0 +1 @@
1
+ python exp/inference/inference.py --loadmodel ./data/pretrained_model/inference.pth --img_path ./img/messi.jpg --output_path ./img/ --output_name /output_file_name
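+ # Usage sketch (paths follow this repository's default layout): put the pretrained weights at
+ # ./data/pretrained_model/inference.pth and an input photo at ./img/messi.jpg, then run
+ # `bash inference.sh`; the parsing result is written under --output_path with the name given
+ # by --output_name.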
TryYours-Virtual-Try-On/Graphonomy-master/networks/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .deeplab_xception import *
2
+ from .deeplab_xception_transfer import *
3
+ from .deeplab_xception_universal import *
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (282 Bytes). View file
 
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (258 Bytes). View file
 
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception.cpython-310.pyc ADDED
Binary file (17 kB). View file
 
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception.cpython-39.pyc ADDED
Binary file (17.1 kB). View file
 
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_synBN.cpython-310.pyc ADDED
Binary file (15.3 kB). View file
 
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_synBN.cpython-39.pyc ADDED
Binary file (15.4 kB). View file
 
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_transfer.cpython-310.pyc ADDED
Binary file (12.8 kB). View file
 
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_transfer.cpython-39.pyc ADDED
Binary file (19.1 kB). View file
 
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_universal.cpython-310.pyc ADDED
Binary file (15.5 kB). View file
 
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/deeplab_xception_universal.cpython-39.pyc ADDED
Binary file (20.3 kB). View file
 
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/gcn.cpython-310.pyc ADDED
Binary file (8.3 kB). View file
 
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/gcn.cpython-39.pyc ADDED
Binary file (8.46 kB). View file
 
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/graph.cpython-310.pyc ADDED
Binary file (9.09 kB). View file
 
TryYours-Virtual-Try-On/Graphonomy-master/networks/__pycache__/graph.cpython-39.pyc ADDED
Binary file (9 kB). View file
 
TryYours-Virtual-Try-On/Graphonomy-master/networks/deeplab_xception.py ADDED
@@ -0,0 +1,684 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.model_zoo as model_zoo
6
+ from torch.nn.parameter import Parameter
7
+ from collections import OrderedDict
8
+
9
+ class SeparableConv2d(nn.Module):
10
+ def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=0, dilation=1, bias=False):
11
+ super(SeparableConv2d, self).__init__()
12
+
13
+ self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation,
14
+ groups=inplanes, bias=bias)
15
+ self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
16
+
17
+ def forward(self, x):
18
+ x = self.conv1(x)
19
+ x = self.pointwise(x)
20
+ return x
21
+
22
+
23
+ def fixed_padding(inputs, kernel_size, rate):
24
+ kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
25
+ pad_total = kernel_size_effective - 1
26
+ pad_beg = pad_total // 2
27
+ pad_end = pad_total - pad_beg
28
+ padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
29
+ return padded_inputs
30
+
31
+
32
+ class SeparableConv2d_aspp(nn.Module):
33
+ def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, padding=0):
34
+ super(SeparableConv2d_aspp, self).__init__()
35
+
36
+ self.depthwise = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation,
37
+ groups=inplanes, bias=bias)
38
+ self.depthwise_bn = nn.BatchNorm2d(inplanes)
39
+ self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
40
+ self.pointwise_bn = nn.BatchNorm2d(planes)
41
+ self.relu = nn.ReLU()
42
+
43
+ def forward(self, x):
44
+ # x = fixed_padding(x, self.depthwise.kernel_size[0], rate=self.depthwise.dilation[0])
45
+ x = self.depthwise(x)
46
+ x = self.depthwise_bn(x)
47
+ x = self.relu(x)
48
+ x = self.pointwise(x)
49
+ x = self.pointwise_bn(x)
50
+ x = self.relu(x)
51
+ return x
52
+
53
+ class Decoder_module(nn.Module):
54
+ def __init__(self, inplanes, planes, rate=1):
55
+ super(Decoder_module, self).__init__()
56
+ self.atrous_convolution = SeparableConv2d_aspp(inplanes, planes, 3, stride=1, dilation=rate,padding=1)
57
+
58
+ def forward(self, x):
59
+ x = self.atrous_convolution(x)
60
+ return x
61
+
62
+ class ASPP_module(nn.Module):
63
+ def __init__(self, inplanes, planes, rate):
64
+ super(ASPP_module, self).__init__()
65
+ if rate == 1:
66
+ raise RuntimeError()
67
+ else:
68
+ kernel_size = 3
69
+ padding = rate
70
+ self.atrous_convolution = SeparableConv2d_aspp(inplanes, planes, 3, stride=1, dilation=rate,
71
+ padding=padding)
72
+
73
+ def forward(self, x):
74
+ x = self.atrous_convolution(x)
75
+ return x
76
+
77
+ class ASPP_module_rate0(nn.Module):
78
+ def __init__(self, inplanes, planes, rate=1):
79
+ super(ASPP_module_rate0, self).__init__()
80
+ if rate == 1:
81
+ kernel_size = 1
82
+ padding = 0
83
+ self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
84
+ stride=1, padding=padding, dilation=rate, bias=False)
85
+ self.bn = nn.BatchNorm2d(planes, eps=1e-5, affine=True)
86
+ self.relu = nn.ReLU()
87
+ else:
88
+ raise RuntimeError()
89
+
90
+ def forward(self, x):
91
+ x = self.atrous_convolution(x)
92
+ x = self.bn(x)
93
+ return self.relu(x)
94
+
95
+ class SeparableConv2d_same(nn.Module):
96
+ def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, padding=0):
97
+ super(SeparableConv2d_same, self).__init__()
98
+
99
+ self.depthwise = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation,
100
+ groups=inplanes, bias=bias)
101
+ self.depthwise_bn = nn.BatchNorm2d(inplanes)
102
+ self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
103
+ self.pointwise_bn = nn.BatchNorm2d(planes)
104
+
105
+ def forward(self, x):
106
+ x = fixed_padding(x, self.depthwise.kernel_size[0], rate=self.depthwise.dilation[0])
107
+ x = self.depthwise(x)
108
+ x = self.depthwise_bn(x)
109
+ x = self.pointwise(x)
110
+ x = self.pointwise_bn(x)
111
+ return x
112
+
113
+ class Block(nn.Module):
114
+ def __init__(self, inplanes, planes, reps, stride=1, dilation=1, start_with_relu=True, grow_first=True, is_last=False):
115
+ super(Block, self).__init__()
116
+
117
+ if planes != inplanes or stride != 1:
118
+ self.skip = nn.Conv2d(inplanes, planes, 1, stride=2, bias=False)
119
+ if is_last:
120
+ self.skip = nn.Conv2d(inplanes, planes, 1, stride=1, bias=False)
121
+ self.skipbn = nn.BatchNorm2d(planes)
122
+ else:
123
+ self.skip = None
124
+
125
+ self.relu = nn.ReLU(inplace=True)
126
+ rep = []
127
+
128
+ filters = inplanes
129
+ if grow_first:
130
+ rep.append(self.relu)
131
+ rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
132
+ # rep.append(nn.BatchNorm2d(planes))
133
+ filters = planes
134
+
135
+ for i in range(reps - 1):
136
+ rep.append(self.relu)
137
+ rep.append(SeparableConv2d_same(filters, filters, 3, stride=1, dilation=dilation))
138
+ # rep.append(nn.BatchNorm2d(filters))
139
+
140
+ if not grow_first:
141
+ rep.append(self.relu)
142
+ rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
143
+ # rep.append(nn.BatchNorm2d(planes))
144
+
145
+ if not start_with_relu:
146
+ rep = rep[1:]
147
+
148
+ if stride != 1:
149
+ rep.append(self.relu)
150
+ rep.append(SeparableConv2d_same(planes, planes, 3, stride=2,dilation=dilation))
151
+
152
+ if is_last:
153
+ rep.append(self.relu)
154
+ rep.append(SeparableConv2d_same(planes, planes, 3, stride=1,dilation=dilation))
155
+
156
+
157
+ self.rep = nn.Sequential(*rep)
158
+
159
+ def forward(self, inp):
160
+ x = self.rep(inp)
161
+
162
+ if self.skip is not None:
163
+ skip = self.skip(inp)
164
+ skip = self.skipbn(skip)
165
+ else:
166
+ skip = inp
167
+ # print(x.size(),skip.size())
168
+ x += skip
169
+
170
+ return x
171
+
172
+ class Block2(nn.Module):
173
+ def __init__(self, inplanes, planes, reps, stride=1, dilation=1, start_with_relu=True, grow_first=True, is_last=False):
174
+ super(Block2, self).__init__()
175
+
176
+ if planes != inplanes or stride != 1:
177
+ self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False)
178
+ self.skipbn = nn.BatchNorm2d(planes)
179
+ else:
180
+ self.skip = None
181
+
182
+ self.relu = nn.ReLU(inplace=True)
183
+ rep = []
184
+
185
+ filters = inplanes
186
+ if grow_first:
187
+ rep.append(self.relu)
188
+ rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
189
+ # rep.append(nn.BatchNorm2d(planes))
190
+ filters = planes
191
+
192
+ for i in range(reps - 1):
193
+ rep.append(self.relu)
194
+ rep.append(SeparableConv2d_same(filters, filters, 3, stride=1, dilation=dilation))
195
+ # rep.append(nn.BatchNorm2d(filters))
196
+
197
+ if not grow_first:
198
+ rep.append(self.relu)
199
+ rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
200
+ # rep.append(nn.BatchNorm2d(planes))
201
+
202
+ if not start_with_relu:
203
+ rep = rep[1:]
204
+
205
+ if stride != 1:
206
+ self.block2_lastconv = nn.Sequential(*[self.relu,SeparableConv2d_same(planes, planes, 3, stride=2,dilation=dilation)])
207
+
208
+ if is_last:
209
+ rep.append(SeparableConv2d_same(planes, planes, 3, stride=1))
210
+
211
+
212
+ self.rep = nn.Sequential(*rep)
213
+
214
+ def forward(self, inp):
215
+ x = self.rep(inp)
216
+ low_middle = x.clone()
217
+ x1 = x
218
+ x1 = self.block2_lastconv(x1)
219
+ if self.skip is not None:
220
+ skip = self.skip(inp)
221
+ skip = self.skipbn(skip)
222
+ else:
223
+ skip = inp
224
+
225
+ x1 += skip
226
+
227
+ return x1,low_middle
228
+
229
+ class Xception(nn.Module):
230
+ """
231
+ Modified Alighed Xception
232
+ """
233
+ def __init__(self, inplanes=3, os=16, pretrained=False):
234
+ super(Xception, self).__init__()
235
+
236
+ if os == 16:
237
+ entry_block3_stride = 2
238
+ middle_block_rate = 1
239
+ exit_block_rates = (1, 2)
240
+ elif os == 8:
241
+ entry_block3_stride = 1
242
+ middle_block_rate = 2
243
+ exit_block_rates = (2, 4)
244
+ else:
245
+ raise NotImplementedError
246
+
247
+
248
+ # Entry flow
249
+ self.conv1 = nn.Conv2d(inplanes, 32, 3, stride=2, padding=1, bias=False)
250
+ self.bn1 = nn.BatchNorm2d(32)
251
+ self.relu = nn.ReLU(inplace=True)
252
+
253
+ self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
254
+ self.bn2 = nn.BatchNorm2d(64)
255
+
256
+ self.block1 = Block(64, 128, reps=2, stride=2, start_with_relu=False)
257
+ self.block2 = Block2(128, 256, reps=2, stride=2, start_with_relu=True, grow_first=True)
258
+ self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, start_with_relu=True, grow_first=True)
259
+
260
+ # Middle flow
261
+ self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
262
+ self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
263
+ self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
264
+ self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
265
+ self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
266
+ self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
267
+ self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
268
+ self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
269
+ self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
270
+ self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
271
+ self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
272
+ self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
273
+ self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
274
+ self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
275
+ self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
276
+ self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
277
+
278
+ # Exit flow
279
+ self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_rates[0],
280
+ start_with_relu=True, grow_first=False, is_last=True)
281
+
282
+ self.conv3 = SeparableConv2d_aspp(1024, 1536, 3, stride=1, dilation=exit_block_rates[1],padding=exit_block_rates[1])
283
+ # self.bn3 = nn.BatchNorm2d(1536)
284
+
285
+ self.conv4 = SeparableConv2d_aspp(1536, 1536, 3, stride=1, dilation=exit_block_rates[1],padding=exit_block_rates[1])
286
+ # self.bn4 = nn.BatchNorm2d(1536)
287
+
288
+ self.conv5 = SeparableConv2d_aspp(1536, 2048, 3, stride=1, dilation=exit_block_rates[1],padding=exit_block_rates[1])
289
+ # self.bn5 = nn.BatchNorm2d(2048)
290
+
291
+ # Init weights
292
+ # self.__init_weight()
293
+
294
+ # Load pretrained model
295
+ if pretrained:
296
+ self.__load_xception_pretrained()
297
+
298
+ def forward(self, x):
299
+ # Entry flow
300
+ x = self.conv1(x)
301
+ x = self.bn1(x)
302
+ x = self.relu(x)
303
+ # print('conv1 ',x.size())
304
+ x = self.conv2(x)
305
+ x = self.bn2(x)
306
+ x = self.relu(x)
307
+
308
+ x = self.block1(x)
309
+ # print('block1',x.size())
310
+ # low_level_feat = x
311
+ x,low_level_feat = self.block2(x)
312
+ # print('block2',x.size())
313
+ x = self.block3(x)
314
+ # print('xception block3 ',x.size())
315
+
316
+ # Middle flow
317
+ x = self.block4(x)
318
+ x = self.block5(x)
319
+ x = self.block6(x)
320
+ x = self.block7(x)
321
+ x = self.block8(x)
322
+ x = self.block9(x)
323
+ x = self.block10(x)
324
+ x = self.block11(x)
325
+ x = self.block12(x)
326
+ x = self.block13(x)
327
+ x = self.block14(x)
328
+ x = self.block15(x)
329
+ x = self.block16(x)
330
+ x = self.block17(x)
331
+ x = self.block18(x)
332
+ x = self.block19(x)
333
+
334
+ # Exit flow
335
+ x = self.block20(x)
336
+ x = self.conv3(x)
337
+ # x = self.bn3(x)
338
+ x = self.relu(x)
339
+
340
+ x = self.conv4(x)
341
+ # x = self.bn4(x)
342
+ x = self.relu(x)
343
+
344
+ x = self.conv5(x)
345
+ # x = self.bn5(x)
346
+ x = self.relu(x)
347
+
348
+ return x, low_level_feat
349
+
350
+ def __init_weight(self):
351
+ for m in self.modules():
352
+ if isinstance(m, nn.Conv2d):
353
+ # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
354
+ # m.weight.data.normal_(0, math.sqrt(2. / n))
355
+ torch.nn.init.kaiming_normal_(m.weight)
356
+ elif isinstance(m, nn.BatchNorm2d):
357
+ m.weight.data.fill_(1)
358
+ m.bias.data.zero_()
359
+
360
+ def __load_xception_pretrained(self):
361
+ pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
362
+ model_dict = {}
363
+ state_dict = self.state_dict()
364
+
365
+ for k, v in pretrain_dict.items():
366
+ if k in state_dict:
367
+ if 'pointwise' in k:
368
+ v = v.unsqueeze(-1).unsqueeze(-1)
369
+ if k.startswith('block12'):
370
+ model_dict[k.replace('block12', 'block20')] = v
371
+ elif k.startswith('block11'):
372
+ model_dict[k.replace('block11', 'block12')] = v
373
+ model_dict[k.replace('block11', 'block13')] = v
374
+ model_dict[k.replace('block11', 'block14')] = v
375
+ model_dict[k.replace('block11', 'block15')] = v
376
+ model_dict[k.replace('block11', 'block16')] = v
377
+ model_dict[k.replace('block11', 'block17')] = v
378
+ model_dict[k.replace('block11', 'block18')] = v
379
+ model_dict[k.replace('block11', 'block19')] = v
380
+ elif k.startswith('conv3'):
381
+ model_dict[k] = v
382
+ elif k.startswith('bn3'):
383
+ model_dict[k] = v
384
+ model_dict[k.replace('bn3', 'bn4')] = v
385
+ elif k.startswith('conv4'):
386
+ model_dict[k.replace('conv4', 'conv5')] = v
387
+ elif k.startswith('bn4'):
388
+ model_dict[k.replace('bn4', 'bn5')] = v
389
+ else:
390
+ model_dict[k] = v
391
+ state_dict.update(model_dict)
392
+ self.load_state_dict(state_dict)
393
+
394
+ class DeepLabv3_plus(nn.Module):
395
+ def __init__(self, nInputChannels=3, n_classes=21, os=16, pretrained=False, _print=True):
396
+ if _print:
397
+ print("Constructing DeepLabv3+ model...")
398
+ print("Number of classes: {}".format(n_classes))
399
+ print("Output stride: {}".format(os))
400
+ print("Number of Input Channels: {}".format(nInputChannels))
401
+ super(DeepLabv3_plus, self).__init__()
402
+
403
+ # Atrous Conv
404
+ self.xception_features = Xception(nInputChannels, os, pretrained)
405
+
406
+ # ASPP
407
+ if os == 16:
408
+ rates = [1, 6, 12, 18]
409
+ elif os == 8:
410
+ rates = [1, 12, 24, 36]
411
+ raise NotImplementedError
412
+ else:
413
+ raise NotImplementedError
414
+
415
+ self.aspp1 = ASPP_module_rate0(2048, 256, rate=rates[0])
416
+ self.aspp2 = ASPP_module(2048, 256, rate=rates[1])
417
+ self.aspp3 = ASPP_module(2048, 256, rate=rates[2])
418
+ self.aspp4 = ASPP_module(2048, 256, rate=rates[3])
419
+
420
+ self.relu = nn.ReLU()
421
+
422
+ self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
423
+ nn.Conv2d(2048, 256, 1, stride=1, bias=False),
424
+ nn.BatchNorm2d(256),
425
+ nn.ReLU()
426
+ )
427
+
428
+ self.concat_projection_conv1 = nn.Conv2d(1280, 256, 1, bias=False)
429
+ self.concat_projection_bn1 = nn.BatchNorm2d(256)
430
+
431
+ # adopt [1x1, 48] for channel reduction.
432
+ self.feature_projection_conv1 = nn.Conv2d(256, 48, 1, bias=False)
433
+ self.feature_projection_bn1 = nn.BatchNorm2d(48)
434
+
435
+ self.decoder = nn.Sequential(Decoder_module(304, 256),
436
+ Decoder_module(256, 256)
437
+ )
438
+ self.semantic = nn.Conv2d(256, n_classes, kernel_size=1, stride=1)
439
+
440
+ def forward(self, input):
441
+ x, low_level_features = self.xception_features(input)
442
+ # print(x.size())
443
+ x1 = self.aspp1(x)
444
+ x2 = self.aspp2(x)
445
+ x3 = self.aspp3(x)
446
+ x4 = self.aspp4(x)
447
+ x5 = self.global_avg_pool(x)
448
+ x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
449
+
450
+ x = torch.cat((x1, x2, x3, x4, x5), dim=1)
451
+
452
+ x = self.concat_projection_conv1(x)
453
+ x = self.concat_projection_bn1(x)
454
+ x = self.relu(x)
455
+ # print(x.size())
456
+
457
+ low_level_features = self.feature_projection_conv1(low_level_features)
458
+ low_level_features = self.feature_projection_bn1(low_level_features)
459
+ low_level_features = self.relu(low_level_features)
460
+
461
+ x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
462
+ # print(low_level_features.size())
463
+ # print(x.size())
464
+ x = torch.cat((x, low_level_features), dim=1)
465
+ x = self.decoder(x)
466
+ x = self.semantic(x)
467
+ x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)
468
+
469
+ return x
470
+
471
+ def freeze_bn(self):
472
+ for m in self.xception_features.modules():
473
+ if isinstance(m, nn.BatchNorm2d):
474
+ m.eval()
475
+
476
+ def freeze_totally_bn(self):
477
+ for m in self.modules():
478
+ if isinstance(m, nn.BatchNorm2d):
479
+ m.eval()
480
+
481
+ def freeze_aspp_bn(self):
482
+ for m in self.aspp1.modules():
483
+ if isinstance(m, nn.BatchNorm2d):
484
+ m.eval()
485
+ for m in self.aspp2.modules():
486
+ if isinstance(m, nn.BatchNorm2d):
487
+ m.eval()
488
+ for m in self.aspp3.modules():
489
+ if isinstance(m, nn.BatchNorm2d):
490
+ m.eval()
491
+ for m in self.aspp4.modules():
492
+ if isinstance(m, nn.BatchNorm2d):
493
+ m.eval()
494
+
495
+ def learnable_parameters(self):
496
+ layer_features_BN = []
497
+ layer_features = []
498
+ layer_aspp = []
499
+ layer_projection =[]
500
+ layer_decoder = []
501
+ layer_other = []
502
+ model_para = list(self.named_parameters())
503
+ for name,para in model_para:
504
+ if 'xception' in name:
505
+ if 'bn' in name or 'downsample.1.weight' in name or 'downsample.1.bias' in name:
506
+ layer_features_BN.append(para)
507
+ else:
508
+ layer_features.append(para)
509
+ # print (name)
510
+ elif 'aspp' in name:
511
+ layer_aspp.append(para)
512
+ elif 'projection' in name:
513
+ layer_projection.append(para)
514
+ elif 'decode' in name:
515
+ layer_decoder.append(para)
516
+ elif 'global' not in name:
517
+ layer_other.append(para)
518
+ return layer_features_BN,layer_features,layer_aspp,layer_projection,layer_decoder,layer_other
519
+
520
+ def get_backbone_para(self):
521
+ layer_features = []
522
+ other_features = []
523
+ model_para = list(self.named_parameters())
524
+ for name, para in model_para:
525
+ if 'xception' in name:
526
+ layer_features.append(para)
527
+ else:
528
+ other_features.append(para)
529
+
530
+ return layer_features, other_features
531
+
532
+ def train_fixbn(self, mode=True, freeze_bn=True, freeze_bn_affine=False):
533
+ r"""Sets the module in training mode.
534
+
535
+ This has any effect only on certain modules. See documentations of
536
+ particular modules for details of their behaviors in training/evaluation
537
+ mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
538
+ etc.
539
+
540
+ Returns:
541
+ Module: self
542
+ """
543
+ super(DeepLabv3_plus, self).train(mode)
544
+ if freeze_bn:
545
+ print("Freezing Mean/Var of BatchNorm2D.")
546
+ if freeze_bn_affine:
547
+ print("Freezing Weight/Bias of BatchNorm2D.")
548
+ if freeze_bn:
549
+ for m in self.xception_features.modules():
550
+ if isinstance(m, nn.BatchNorm2d):
551
+ m.eval()
552
+ if freeze_bn_affine:
553
+ m.weight.requires_grad = False
554
+ m.bias.requires_grad = False
555
+ # for m in self.aspp1.modules():
556
+ # if isinstance(m, nn.BatchNorm2d):
557
+ # m.eval()
558
+ # if freeze_bn_affine:
559
+ # m.weight.requires_grad = False
560
+ # m.bias.requires_grad = False
561
+ # for m in self.aspp2.modules():
562
+ # if isinstance(m, nn.BatchNorm2d):
563
+ # m.eval()
564
+ # if freeze_bn_affine:
565
+ # m.weight.requires_grad = False
566
+ # m.bias.requires_grad = False
567
+ # for m in self.aspp3.modules():
568
+ # if isinstance(m, nn.BatchNorm2d):
569
+ # m.eval()
570
+ # if freeze_bn_affine:
571
+ # m.weight.requires_grad = False
572
+ # m.bias.requires_grad = False
573
+ # for m in self.aspp4.modules():
574
+ # if isinstance(m, nn.BatchNorm2d):
575
+ # m.eval()
576
+ # if freeze_bn_affine:
577
+ # m.weight.requires_grad = False
578
+ # m.bias.requires_grad = False
579
+ # for m in self.global_avg_pool.modules():
580
+ # if isinstance(m, nn.BatchNorm2d):
581
+ # m.eval()
582
+ # if freeze_bn_affine:
583
+ # m.weight.requires_grad = False
584
+ # m.bias.requires_grad = False
585
+ # for m in self.concat_projection_bn1.modules():
586
+ # if isinstance(m, nn.BatchNorm2d):
587
+ # m.eval()
588
+ # if freeze_bn_affine:
589
+ # m.weight.requires_grad = False
590
+ # m.bias.requires_grad = False
591
+ # for m in self.feature_projection_bn1.modules():
592
+ # if isinstance(m, nn.BatchNorm2d):
593
+ # m.eval()
594
+ # if freeze_bn_affine:
595
+ # m.weight.requires_grad = False
596
+ # m.bias.requires_grad = False
597
+
598
+ def __init_weight(self):
599
+ for m in self.modules():
600
+ if isinstance(m, nn.Conv2d):
601
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
602
+ m.weight.data.normal_(0, math.sqrt(2. / n))
603
+ # torch.nn.init.kaiming_normal_(m.weight)
604
+ elif isinstance(m, nn.BatchNorm2d):
605
+ m.weight.data.fill_(1)
606
+ m.bias.data.zero_()
607
+
608
+ def load_state_dict_new(self, state_dict):
609
+ own_state = self.state_dict()
610
+ # for name in own_state:
611
+ # print name
612
+ new_state_dict = OrderedDict()
613
+ for name, param in state_dict.items():
614
+ name = name.replace('module.','')
615
+ new_state_dict[name] = 0
616
+ if name not in own_state:
617
+ if 'num_batch' in name:
618
+ continue
619
+ print ('unexpected key "{}" in state_dict'
620
+ .format(name))
621
+ continue
622
+ # if isinstance(param, own_state):
623
+ if isinstance(param, Parameter):
624
+ # backwards compatibility for serialized parameters
625
+ param = param.data
626
+ try:
627
+ own_state[name].copy_(param)
628
+ except:
629
+ print('While copying the parameter named {}, whose dimensions in the model are'
630
+ ' {} and whose dimensions in the checkpoint are {}, ...'.format(
631
+ name, own_state[name].size(), param.size()))
632
+ continue # added 2018/02/01
633
+ # raise
634
+ # print 'copying %s' %name
635
+ # if isinstance(param, own_state):
636
+ # backwards compatibility for serialized parameters
637
+ own_state[name].copy_(param)
638
+ # print 'copying %s' %name
639
+
640
+ missing = set(own_state.keys()) - set(new_state_dict.keys())
641
+ if len(missing) > 0:
642
+ print('missing keys in state_dict: "{}"'.format(missing))
643
+
644
+
645
+ def get_1x_lr_params(model):
646
+ """
647
+ This generator returns all the parameters of the net except for
648
+ the last classification layer. Note that for each batchnorm layer,
649
+ requires_grad is set to False in deeplab_resnet.py, therefore this function does not return
650
+ any batchnorm parameter
651
+ """
652
+ b = [model.xception_features]
653
+ for i in range(len(b)):
654
+ for k in b[i].parameters():
655
+ if k.requires_grad:
656
+ yield k
657
+
658
+
659
+ def get_10x_lr_params(model):
660
+ """
661
+ This generator returns all the parameters for the last layer of the net,
662
+ which does the classification of pixels into classes
663
+ """
664
+ b = [model.aspp1, model.aspp2, model.aspp3, model.aspp4, model.conv1, model.conv2, model.last_conv]
665
+ for j in range(len(b)):
666
+ for k in b[j].parameters():
667
+ if k.requires_grad:
668
+ yield k
669
+
670
+
671
+ if __name__ == "__main__":
672
+ model = DeepLabv3_plus(nInputChannels=3, n_classes=21, os=16, pretrained=False, _print=True)
673
+ model.eval()
674
+ image = torch.randn(1, 3, 512, 512)*255
675
+ with torch.no_grad():
676
+ output = model.forward(image)
677
+ print(output.size())
678
+ # print(output)
679
+
680
+
681
+
682
+
683
+
684
+
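A note on the two module-level generators above: get_1x_lr_params yields the Xception backbone parameters, while get_10x_lr_params is meant to yield the head, but it references model.conv1, model.conv2 and model.last_conv, attribute names that do not appear among the attributes visible on this DeepLabv3_plus class. A safer way to build per-group learning rates is the class's own learnable_parameters() method. The sketch below is a hedged example of that pattern; the base learning rate and the SGD hyper-parameters are assumptions, not values taken from this repository.

import torch
from networks.deeplab_xception import DeepLabv3_plus

model = DeepLabv3_plus(nInputChannels=3, n_classes=20, os=16, pretrained=False)
# learnable_parameters() returns six lists: backbone BN, backbone, ASPP, projection, decoder, other
bn_params, backbone_params, aspp_params, proj_params, dec_params, other_params = model.learnable_parameters()

base_lr = 0.007  # assumed value, not from this repo
optimizer = torch.optim.SGD(
    [
        {'params': backbone_params, 'lr': base_lr},
        {'params': bn_params, 'lr': base_lr},
        {'params': aspp_params + proj_params + dec_params + other_params, 'lr': 10 * base_lr},
    ],
    momentum=0.9, weight_decay=5e-4,
)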
TryYours-Virtual-Try-On/Graphonomy-master/networks/deeplab_xception_synBN.py ADDED
@@ -0,0 +1,596 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.model_zoo as model_zoo
6
+ from torch.nn.parameter import Parameter
7
+ from collections import OrderedDict
8
+ from sync_batchnorm import SynchronizedBatchNorm1d, DataParallelWithCallback, SynchronizedBatchNorm2d
9
+
10
+
11
+ def fixed_padding(inputs, kernel_size, rate):
12
+ kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
13
+ pad_total = kernel_size_effective - 1
14
+ pad_beg = pad_total // 2
15
+ pad_end = pad_total - pad_beg
16
+ padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
17
+ return padded_inputs
18
+
19
+ class SeparableConv2d_aspp(nn.Module):
20
+ def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, padding=0):
21
+ super(SeparableConv2d_aspp, self).__init__()
22
+
23
+ self.depthwise = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation,
24
+ groups=inplanes, bias=bias)
25
+ self.depthwise_bn = SynchronizedBatchNorm2d(inplanes)
26
+ self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
27
+ self.pointwise_bn = SynchronizedBatchNorm2d(planes)
28
+ self.relu = nn.ReLU()
29
+
30
+ def forward(self, x):
31
+ # x = fixed_padding(x, self.depthwise.kernel_size[0], rate=self.depthwise.dilation[0])
32
+ x = self.depthwise(x)
33
+ x = self.depthwise_bn(x)
34
+ x = self.relu(x)
35
+ x = self.pointwise(x)
36
+ x = self.pointwise_bn(x)
37
+ x = self.relu(x)
38
+ return x
39
+
40
+ class Decoder_module(nn.Module):
41
+ def __init__(self, inplanes, planes, rate=1):
42
+ super(Decoder_module, self).__init__()
43
+ self.atrous_convolution = SeparableConv2d_aspp(inplanes, planes, 3, stride=1, dilation=rate,padding=1)
44
+
45
+ def forward(self, x):
46
+ x = self.atrous_convolution(x)
47
+ return x
48
+
49
+ class ASPP_module(nn.Module):
50
+ def __init__(self, inplanes, planes, rate):
51
+ super(ASPP_module, self).__init__()
52
+ if rate == 1:
53
+ raise RuntimeError()
54
+ else:
55
+ kernel_size = 3
56
+ padding = rate
57
+ self.atrous_convolution = SeparableConv2d_aspp(inplanes, planes, 3, stride=1, dilation=rate,
58
+ padding=padding)
59
+
60
+ def forward(self, x):
61
+ x = self.atrous_convolution(x)
62
+ return x
63
+
64
+
65
+ class ASPP_module_rate0(nn.Module):
66
+ def __init__(self, inplanes, planes, rate=1):
67
+ super(ASPP_module_rate0, self).__init__()
68
+ if rate == 1:
69
+ kernel_size = 1
70
+ padding = 0
71
+ self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
72
+ stride=1, padding=padding, dilation=rate, bias=False)
73
+ self.bn = SynchronizedBatchNorm2d(planes, eps=1e-5, affine=True)
74
+ self.relu = nn.ReLU()
75
+ else:
76
+ raise RuntimeError()
77
+
78
+ def forward(self, x):
79
+ x = self.atrous_convolution(x)
80
+ x = self.bn(x)
81
+ return self.relu(x)
82
+
83
+
84
+ class SeparableConv2d_same(nn.Module):
85
+ def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, padding=0):
86
+ super(SeparableConv2d_same, self).__init__()
87
+
88
+ self.depthwise = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation,
89
+ groups=inplanes, bias=bias)
90
+ self.depthwise_bn = SynchronizedBatchNorm2d(inplanes)
91
+ self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
92
+ self.pointwise_bn = SynchronizedBatchNorm2d(planes)
93
+
94
+ def forward(self, x):
95
+ x = fixed_padding(x, self.depthwise.kernel_size[0], rate=self.depthwise.dilation[0])
96
+ x = self.depthwise(x)
97
+ x = self.depthwise_bn(x)
98
+ x = self.pointwise(x)
99
+ x = self.pointwise_bn(x)
100
+ return x
101
+
102
+
103
+ class Block(nn.Module):
104
+ def __init__(self, inplanes, planes, reps, stride=1, dilation=1, start_with_relu=True, grow_first=True, is_last=False):
105
+ super(Block, self).__init__()
106
+
107
+ if planes != inplanes or stride != 1:
108
+ self.skip = nn.Conv2d(inplanes, planes, 1, stride=2, bias=False)
109
+ if is_last:
110
+ self.skip = nn.Conv2d(inplanes, planes, 1, stride=1, bias=False)
111
+ self.skipbn = SynchronizedBatchNorm2d(planes)
112
+ else:
113
+ self.skip = None
114
+
115
+ self.relu = nn.ReLU(inplace=True)
116
+ rep = []
117
+
118
+ filters = inplanes
119
+ if grow_first:
120
+ rep.append(self.relu)
121
+ rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
122
+ # rep.append(nn.BatchNorm2d(planes))
123
+ filters = planes
124
+
125
+ for i in range(reps - 1):
126
+ rep.append(self.relu)
127
+ rep.append(SeparableConv2d_same(filters, filters, 3, stride=1, dilation=dilation))
128
+ # rep.append(nn.BatchNorm2d(filters))
129
+
130
+ if not grow_first:
131
+ rep.append(self.relu)
132
+ rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
133
+ # rep.append(nn.BatchNorm2d(planes))
134
+
135
+ if not start_with_relu:
136
+ rep = rep[1:]
137
+
138
+ if stride != 1:
139
+ rep.append(self.relu)
140
+ rep.append(SeparableConv2d_same(planes, planes, 3, stride=2,dilation=dilation))
141
+
142
+ if is_last:
143
+ rep.append(self.relu)
144
+ rep.append(SeparableConv2d_same(planes, planes, 3, stride=1,dilation=dilation))
145
+
146
+
147
+ self.rep = nn.Sequential(*rep)
148
+
149
+ def forward(self, inp):
150
+ x = self.rep(inp)
151
+
152
+ if self.skip is not None:
153
+ skip = self.skip(inp)
154
+ skip = self.skipbn(skip)
155
+ else:
156
+ skip = inp
157
+ # print(x.size(),skip.size())
158
+ x += skip
159
+
160
+ return x
161
+
162
+ class Block2(nn.Module):
163
+ def __init__(self, inplanes, planes, reps, stride=1, dilation=1, start_with_relu=True, grow_first=True, is_last=False):
164
+ super(Block2, self).__init__()
165
+
166
+ if planes != inplanes or stride != 1:
167
+ self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False)
168
+ self.skipbn = SynchronizedBatchNorm2d(planes)
169
+ else:
170
+ self.skip = None
171
+
172
+ self.relu = nn.ReLU(inplace=True)
173
+ rep = []
174
+
175
+ filters = inplanes
176
+ if grow_first:
177
+ rep.append(self.relu)
178
+ rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
179
+ # rep.append(nn.BatchNorm2d(planes))
180
+ filters = planes
181
+
182
+ for i in range(reps - 1):
183
+ rep.append(self.relu)
184
+ rep.append(SeparableConv2d_same(filters, filters, 3, stride=1, dilation=dilation))
185
+ # rep.append(nn.BatchNorm2d(filters))
186
+
187
+ if not grow_first:
188
+ rep.append(self.relu)
189
+ rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
190
+ # rep.append(nn.BatchNorm2d(planes))
191
+
192
+ if not start_with_relu:
193
+ rep = rep[1:]
194
+
195
+ if stride != 1:
196
+ self.block2_lastconv = nn.Sequential(*[self.relu,SeparableConv2d_same(planes, planes, 3, stride=2,dilation=dilation)])
197
+
198
+ if is_last:
199
+ rep.append(SeparableConv2d_same(planes, planes, 3, stride=1))
200
+
201
+
202
+ self.rep = nn.Sequential(*rep)
203
+
204
+ def forward(self, inp):
205
+ x = self.rep(inp)
206
+ low_middle = x.clone()
207
+ x1 = x
208
+ x1 = self.block2_lastconv(x1)
209
+ if self.skip is not None:
210
+ skip = self.skip(inp)
211
+ skip = self.skipbn(skip)
212
+ else:
213
+ skip = inp
214
+
215
+ x1 += skip
216
+
217
+ return x1,low_middle
218
+
219
+ class Xception(nn.Module):
220
+ """
221
+ Modified Aligned Xception
222
+ """
223
+ def __init__(self, inplanes=3, os=16, pretrained=False):
224
+ super(Xception, self).__init__()
225
+
226
+ if os == 16:
227
+ entry_block3_stride = 2
228
+ middle_block_rate = 1
229
+ exit_block_rates = (1, 2)
230
+ elif os == 8:
231
+ entry_block3_stride = 1
232
+ middle_block_rate = 2
233
+ exit_block_rates = (2, 4)
234
+ else:
235
+ raise NotImplementedError
236
+
237
+
238
+ # Entry flow
239
+ self.conv1 = nn.Conv2d(inplanes, 32, 3, stride=2, padding=1, bias=False)
240
+ self.bn1 = SynchronizedBatchNorm2d(32)
241
+ self.relu = nn.ReLU(inplace=True)
242
+
243
+ self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
244
+ self.bn2 = SynchronizedBatchNorm2d(64)
245
+
246
+ self.block1 = Block(64, 128, reps=2, stride=2, start_with_relu=False)
247
+ self.block2 = Block2(128, 256, reps=2, stride=2, start_with_relu=True, grow_first=True)
248
+ self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, start_with_relu=True, grow_first=True)
249
+
250
+ # Middle flow
251
+ self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
252
+ self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
253
+ self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
254
+ self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
255
+ self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
256
+ self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
257
+ self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
258
+ self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
259
+ self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
260
+ self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
261
+ self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
262
+ self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
263
+ self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
264
+ self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
265
+ self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
266
+ self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
267
+
268
+ # Exit flow
269
+ self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_rates[0],
270
+ start_with_relu=True, grow_first=False, is_last=True)
271
+
272
+ self.conv3 = SeparableConv2d_aspp(1024, 1536, 3, stride=1, dilation=exit_block_rates[1],padding=exit_block_rates[1])
273
+ # self.bn3 = nn.BatchNorm2d(1536)
274
+
275
+ self.conv4 = SeparableConv2d_aspp(1536, 1536, 3, stride=1, dilation=exit_block_rates[1],padding=exit_block_rates[1])
276
+ # self.bn4 = nn.BatchNorm2d(1536)
277
+
278
+ self.conv5 = SeparableConv2d_aspp(1536, 2048, 3, stride=1, dilation=exit_block_rates[1],padding=exit_block_rates[1])
279
+ # self.bn5 = nn.BatchNorm2d(2048)
280
+
281
+ # Init weights
282
+ # self.__init_weight()
283
+
284
+ # Load pretrained model
285
+ if pretrained:
286
+ self.__load_xception_pretrained()
287
+
288
+ def forward(self, x):
289
+ # Entry flow
290
+ x = self.conv1(x)
291
+ x = self.bn1(x)
292
+ x = self.relu(x)
293
+ # print('conv1 ',x.size())
294
+ x = self.conv2(x)
295
+ x = self.bn2(x)
296
+ x = self.relu(x)
297
+
298
+ x = self.block1(x)
299
+ # print('block1',x.size())
300
+ # low_level_feat = x
301
+ x,low_level_feat = self.block2(x)
302
+ # print('block2',x.size())
303
+ x = self.block3(x)
304
+ # print('xception block3 ',x.size())
305
+
306
+ # Middle flow
307
+ x = self.block4(x)
308
+ x = self.block5(x)
309
+ x = self.block6(x)
310
+ x = self.block7(x)
311
+ x = self.block8(x)
312
+ x = self.block9(x)
313
+ x = self.block10(x)
314
+ x = self.block11(x)
315
+ x = self.block12(x)
316
+ x = self.block13(x)
317
+ x = self.block14(x)
318
+ x = self.block15(x)
319
+ x = self.block16(x)
320
+ x = self.block17(x)
321
+ x = self.block18(x)
322
+ x = self.block19(x)
323
+
324
+ # Exit flow
325
+ x = self.block20(x)
326
+ x = self.conv3(x)
327
+ # x = self.bn3(x)
328
+ x = self.relu(x)
329
+
330
+ x = self.conv4(x)
331
+ # x = self.bn4(x)
332
+ x = self.relu(x)
333
+
334
+ x = self.conv5(x)
335
+ # x = self.bn5(x)
336
+ x = self.relu(x)
337
+
338
+ return x, low_level_feat
339
+
340
+ def __init_weight(self):
341
+ for m in self.modules():
342
+ if isinstance(m, nn.Conv2d):
343
+ # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
344
+ # m.weight.data.normal_(0, math.sqrt(2. / n))
345
+ torch.nn.init.kaiming_normal_(m.weight)
346
+ elif isinstance(m, nn.BatchNorm2d):
347
+ m.weight.data.fill_(1)
348
+ m.bias.data.zero_()
349
+
350
+ def __load_xception_pretrained(self):
351
+ pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
352
+ model_dict = {}
353
+ state_dict = self.state_dict()
354
+
355
+ for k, v in pretrain_dict.items():
356
+ if k in state_dict:
357
+ if 'pointwise' in k:
358
+ v = v.unsqueeze(-1).unsqueeze(-1)
359
+ if k.startswith('block12'):
360
+ model_dict[k.replace('block12', 'block20')] = v
361
+ elif k.startswith('block11'):
362
+ model_dict[k.replace('block11', 'block12')] = v
363
+ model_dict[k.replace('block11', 'block13')] = v
364
+ model_dict[k.replace('block11', 'block14')] = v
365
+ model_dict[k.replace('block11', 'block15')] = v
366
+ model_dict[k.replace('block11', 'block16')] = v
367
+ model_dict[k.replace('block11', 'block17')] = v
368
+ model_dict[k.replace('block11', 'block18')] = v
369
+ model_dict[k.replace('block11', 'block19')] = v
370
+ elif k.startswith('conv3'):
371
+ model_dict[k] = v
372
+ elif k.startswith('bn3'):
373
+ model_dict[k] = v
374
+ model_dict[k.replace('bn3', 'bn4')] = v
375
+ elif k.startswith('conv4'):
376
+ model_dict[k.replace('conv4', 'conv5')] = v
377
+ elif k.startswith('bn4'):
378
+ model_dict[k.replace('bn4', 'bn5')] = v
379
+ else:
380
+ model_dict[k] = v
381
+ state_dict.update(model_dict)
382
+ self.load_state_dict(state_dict)
383
+
384
+ class DeepLabv3_plus(nn.Module):
385
+ def __init__(self, nInputChannels=3, n_classes=21, os=16, pretrained=False, _print=True):
386
+ if _print:
387
+ print("Constructing DeepLabv3+ model...")
388
+ print("Number of classes: {}".format(n_classes))
389
+ print("Output stride: {}".format(os))
390
+ print("Number of Input Channels: {}".format(nInputChannels))
391
+ super(DeepLabv3_plus, self).__init__()
392
+
393
+ # Atrous Conv
394
+ self.xception_features = Xception(nInputChannels, os, pretrained)
395
+
396
+ # ASPP
397
+ if os == 16:
398
+ rates = [1, 6, 12, 18]
399
+ elif os == 8:
400
+ rates = [1, 12, 24, 36]
401
+ else:
402
+ raise NotImplementedError
403
+
404
+ self.aspp1 = ASPP_module_rate0(2048, 256, rate=rates[0])
405
+ self.aspp2 = ASPP_module(2048, 256, rate=rates[1])
406
+ self.aspp3 = ASPP_module(2048, 256, rate=rates[2])
407
+ self.aspp4 = ASPP_module(2048, 256, rate=rates[3])
408
+
409
+ self.relu = nn.ReLU()
410
+
411
+ self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
412
+ nn.Conv2d(2048, 256, 1, stride=1, bias=False),
413
+ SynchronizedBatchNorm2d(256),
414
+ nn.ReLU()
415
+ )
416
+
417
+ self.concat_projection_conv1 = nn.Conv2d(1280, 256, 1, bias=False)
418
+ self.concat_projection_bn1 = SynchronizedBatchNorm2d(256)
419
+
420
+ # adopt [1x1, 48] for channel reduction.
421
+ self.feature_projection_conv1 = nn.Conv2d(256, 48, 1, bias=False)
422
+ self.feature_projection_bn1 = SynchronizedBatchNorm2d(48)
423
+
424
+ self.decoder = nn.Sequential(Decoder_module(304, 256),
425
+ Decoder_module(256, 256)
426
+ )
427
+ self.semantic = nn.Conv2d(256, n_classes, kernel_size=1, stride=1)
428
+
429
+ def forward(self, input):
430
+ x, low_level_features = self.xception_features(input)
431
+ # print(x.size())
432
+ x1 = self.aspp1(x)
433
+ x2 = self.aspp2(x)
434
+ x3 = self.aspp3(x)
435
+ x4 = self.aspp4(x)
436
+ x5 = self.global_avg_pool(x)
437
+ x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
438
+
439
+ x = torch.cat((x1, x2, x3, x4, x5), dim=1)
440
+
441
+ x = self.concat_projection_conv1(x)
442
+ x = self.concat_projection_bn1(x)
443
+ x = self.relu(x)
444
+ # print(x.size())
445
+ x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
446
+
447
+ low_level_features = self.feature_projection_conv1(low_level_features)
448
+ low_level_features = self.feature_projection_bn1(low_level_features)
449
+ low_level_features = self.relu(low_level_features)
450
+ # print(low_level_features.size())
451
+ # print(x.size())
452
+ x = torch.cat((x, low_level_features), dim=1)
453
+ x = self.decoder(x)
454
+ x = self.semantic(x)
455
+ x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)
456
+
457
+ return x
458
+
459
+ def freeze_bn(self):
460
+ for m in self.xception_features.modules():
461
+ if isinstance(m, nn.BatchNorm2d) or isinstance(m,SynchronizedBatchNorm2d):
462
+ m.eval()
463
+
464
+ def freeze_aspp_bn(self):
465
+ for m in self.aspp1.modules():
466
+ if isinstance(m, nn.BatchNorm2d):
467
+ m.eval()
468
+ for m in self.aspp2.modules():
469
+ if isinstance(m, nn.BatchNorm2d):
470
+ m.eval()
471
+ for m in self.aspp3.modules():
472
+ if isinstance(m, nn.BatchNorm2d):
473
+ m.eval()
474
+ for m in self.aspp4.modules():
475
+ if isinstance(m, nn.BatchNorm2d):
476
+ m.eval()
477
+
478
+ def learnable_parameters(self):
479
+ layer_features_BN = []
480
+ layer_features = []
481
+ layer_aspp = []
482
+ layer_projection =[]
483
+ layer_decoder = []
484
+ layer_other = []
485
+ model_para = list(self.named_parameters())
486
+ for name,para in model_para:
487
+ if 'xception' in name:
488
+ if 'bn' in name or 'downsample.1.weight' in name or 'downsample.1.bias' in name:
489
+ layer_features_BN.append(para)
490
+ else:
491
+ layer_features.append(para)
492
+ # print (name)
493
+ elif 'aspp' in name:
494
+ layer_aspp.append(para)
495
+ elif 'projection' in name:
496
+ layer_projection.append(para)
497
+ elif 'decode' in name:
498
+ layer_decoder.append(para)
499
+ else:
500
+ layer_other.append(para)
501
+ return layer_features_BN,layer_features,layer_aspp,layer_projection,layer_decoder,layer_other
502
+
503
+
504
+ def __init_weight(self):
505
+ for m in self.modules():
506
+ if isinstance(m, nn.Conv2d):
507
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
508
+ m.weight.data.normal_(0, math.sqrt(2. / n))
509
+ # torch.nn.init.kaiming_normal_(m.weight)
510
+ elif isinstance(m, nn.BatchNorm2d):
511
+ m.weight.data.fill_(1)
512
+ m.bias.data.zero_()
513
+
514
+ def load_state_dict_new(self, state_dict):
515
+ own_state = self.state_dict()
516
+ # for name in own_state:
517
+ # print name
518
+ new_state_dict = OrderedDict()
519
+ for name, param in state_dict.items():
520
+ name = name.replace('module.','')
521
+ new_state_dict[name] = 0
522
+ if name not in own_state:
523
+ if 'num_batch' in name:
524
+ continue
525
+ print ('unexpected key "{}" in state_dict'
526
+ .format(name))
527
+ continue
528
+ # if isinstance(param, own_state):
529
+ if isinstance(param, Parameter):
530
+ # backwards compatibility for serialized parameters
531
+ param = param.data
532
+ try:
533
+ own_state[name].copy_(param)
534
+ except:
535
+ print('While copying the parameter named {}, whose dimensions in the model are'
536
+ ' {} and whose dimensions in the checkpoint are {}, ...'.format(
537
+ name, own_state[name].size(), param.size()))
538
+ continue # added 2018/02/01
539
+ # raise
540
+ # print 'copying %s' %name
541
+ # if isinstance(param, own_state):
542
+ # backwards compatibility for serialized parameters
543
+ own_state[name].copy_(param)
544
+ # print 'copying %s' %name
545
+
546
+ missing = set(own_state.keys()) - set(new_state_dict.keys())
547
+ if len(missing) > 0:
548
+ print('missing keys in state_dict: "{}"'.format(missing))
549
+
550
+
551
+
552
+
553
+ def get_1x_lr_params(model):
554
+ """
555
+ This generator returns all the parameters of the net except for
556
+ the last classification layer. Note that for each batchnorm layer,
557
+ requires_grad is set to False in deeplab_resnet.py, therefore this function does not return
558
+ any batchnorm parameter
559
+ """
560
+ b = [model.xception_features]
561
+ for i in range(len(b)):
562
+ for k in b[i].parameters():
563
+ if k.requires_grad:
564
+ yield k
565
+
566
+
567
+ def get_10x_lr_params(model):
568
+ """
569
+ This generator returns all the parameters for the last layer of the net,
570
+ which does the classification of pixel into classes
571
+ """
572
+ b = [model.aspp1, model.aspp2, model.aspp3, model.aspp4, model.conv1, model.conv2, model.last_conv]
573
+ for j in range(len(b)):
574
+ for k in b[j].parameters():
575
+ if k.requires_grad:
576
+ yield k
577
+
578
+
579
+ if __name__ == "__main__":
580
+ model = DeepLabv3_plus(nInputChannels=3, n_classes=21, os=16, pretrained=False, _print=True)
581
+ model.eval()
582
+ # ckt = torch.load('C:\\Users\gaoyi\code_python\deeplab_v3plus.pth')
583
+ # model.load_state_dict_new(ckt)
584
+
585
+
586
+ image = torch.randn(1, 3, 512, 512)*255
587
+ with torch.no_grad():
588
+ output = model.forward(image)
589
+ print(output.size())
590
+ # print(output)
591
+
592
+
593
+
594
+
595
+
596
+
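deeplab_xception_synBN.py below is the same DeepLabv3+ architecture with the plain BatchNorm layers replaced by SynchronizedBatchNorm2d from the bundled sync_batchnorm package. That substitution only pays off when the model is wrapped with the matching DataParallelWithCallback, so that batch statistics are reduced across GPU replicas during training. The snippet below is a minimal sketch of that wiring, assuming a two-GPU machine; the batch size, device ids and input resolution are illustrative only.

import torch
from networks.deeplab_xception_synBN import DeepLabv3_plus
from sync_batchnorm import DataParallelWithCallback

model = DeepLabv3_plus(nInputChannels=3, n_classes=20, os=16, pretrained=False)
model = DataParallelWithCallback(model, device_ids=[0, 1]).cuda()  # assumed 2-GPU setup
model.train()

images = torch.randn(4, 3, 512, 512).cuda()
logits = model(images)   # SynchronizedBatchNorm2d statistics are synchronised across replicas
print(logits.size())     # expected: torch.Size([4, 20, 512, 512])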
TryYours-Virtual-Try-On/Graphonomy-master/networks/deeplab_xception_transfer.py ADDED
@@ -0,0 +1,1003 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.model_zoo as model_zoo
6
+ from torch.nn.parameter import Parameter
7
+ import numpy as np
8
+ from collections import OrderedDict
9
+ from torch.nn import Parameter
10
+ from networks import deeplab_xception,gcn, deeplab_xception_synBN
11
+ import pdb
12
+
13
+ #######################
14
+ # base model
15
+ #######################
16
+
17
+ class deeplab_xception_transfer_basemodel(deeplab_xception.DeepLabv3_plus):
18
+ def __init__(self,nInputChannels=3, n_classes=7, os=16,input_channels=256,hidden_layers=128,out_channels=256):
19
+ super(deeplab_xception_transfer_basemodel, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes,
20
+ os=os,)
21
+ ### source graph
22
+ # self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
23
+ # nodes=n_classes)
24
+ # self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
25
+ # self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
26
+ # self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
27
+ #
28
+ # self.source_graph_2_fea = gcn.Graph_to_Featuremaps(input_channels=input_channels, output_channels=out_channels,
29
+ # hidden_layers=hidden_layers, nodes=n_classes
30
+ # )
31
+ # self.source_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
32
+ # nn.ReLU(True)])
33
+
34
+ ### target graph
35
+ self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
36
+ nodes=n_classes)
37
+ self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
38
+ self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
39
+ self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
40
+
41
+ self.target_graph_2_fea = gcn.Graph_to_Featuremaps(input_channels=input_channels, output_channels=out_channels,
42
+ hidden_layers=hidden_layers, nodes=n_classes
43
+ )
44
+ self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
45
+ nn.ReLU(True)])
46
+
47
+ def load_source_model(self,state_dict):
48
+ own_state = self.state_dict()
49
+ # for name in own_state:
50
+ # print name
51
+ new_state_dict = OrderedDict()
52
+ for name, param in state_dict.items():
53
+ name = name.replace('module.', '')
54
+ if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_graph' not in name and 'transpose_graph' not in name:
55
+ if 'featuremap_2_graph' in name:
56
+ name = name.replace('featuremap_2_graph','source_featuremap_2_graph')
57
+ else:
58
+ name = name.replace('graph','source_graph')
59
+ new_state_dict[name] = 0
60
+ if name not in own_state:
61
+ if 'num_batch' in name:
62
+ continue
63
+ print('unexpected key "{}" in state_dict'
64
+ .format(name))
65
+ continue
66
+ # if isinstance(param, own_state):
67
+ if isinstance(param, Parameter):
68
+ # backwards compatibility for serialized parameters
69
+ param = param.data
70
+ try:
71
+ own_state[name].copy_(param)
72
+ except:
73
+ print('While copying the parameter named {}, whose dimensions in the model are'
74
+ ' {} and whose dimensions in the checkpoint are {}, ...'.format(
75
+ name, own_state[name].size(), param.size()))
76
+ continue # added 2018/02/01
77
+ own_state[name].copy_(param)
78
+ # print 'copying %s' %name
79
+
80
+ missing = set(own_state.keys()) - set(new_state_dict.keys())
81
+ if len(missing) > 0:
82
+ print('missing keys in state_dict: "{}"'.format(missing))
83
+
84
+ def get_target_parameter(self):
85
+ l = []
86
+ other = []
87
+ for name, k in self.named_parameters():
88
+ if 'target' in name or 'semantic' in name:
89
+ l.append(k)
90
+ else:
91
+ other.append(k)
92
+ return l, other
93
+
94
+ def get_semantic_parameter(self):
95
+ l = []
96
+ for name, k in self.named_parameters():
97
+ if 'semantic' in name:
98
+ l.append(k)
99
+ return l
100
+
101
+ def get_source_parameter(self):
102
+ l = []
103
+ for name, k in self.named_parameters():
104
+ if 'source' in name:
105
+ l.append(k)
106
+ return l
107
+
108
+ def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ):
109
+ x, low_level_features = self.xception_features(input)
110
+ # print(x.size())
111
+ x1 = self.aspp1(x)
112
+ x2 = self.aspp2(x)
113
+ x3 = self.aspp3(x)
114
+ x4 = self.aspp4(x)
115
+ x5 = self.global_avg_pool(x)
116
+ x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
117
+
118
+ x = torch.cat((x1, x2, x3, x4, x5), dim=1)
119
+
120
+ x = self.concat_projection_conv1(x)
121
+ x = self.concat_projection_bn1(x)
122
+ x = self.relu(x)
123
+ # print(x.size())
124
+ x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
125
+
126
+ low_level_features = self.feature_projection_conv1(low_level_features)
127
+ low_level_features = self.feature_projection_bn1(low_level_features)
128
+ low_level_features = self.relu(low_level_features)
129
+ # print(low_level_features.size())
130
+ # print(x.size())
131
+ x = torch.cat((x, low_level_features), dim=1)
132
+ x = self.decoder(x)
133
+
134
+ ### add graph
135
+
136
+
137
+ # target graph
138
+ # print('x size',x.size(),adj1.size())
139
+ graph = self.target_featuremap_2_graph(x)
140
+
141
+ # graph combine
142
+ # print(graph.size(),source_2_target_graph.size())
143
+ # graph = self.fc_graph.forward(graph,relu=True)
144
+ # print(graph.size())
145
+
146
+ graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
147
+ graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
148
+ graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
149
+ # print(graph.size(),x.size())
150
+ # graph = self.gcn_encode.forward(graph,relu=True)
151
+ # graph = self.graph_conv2.forward(graph,adj=adj2,relu=True)
152
+ # graph = self.gcn_decode.forward(graph,relu=True)
153
+ graph = self.target_graph_2_fea.forward(graph, x)
154
+ x = self.target_skip_conv(x)
155
+ x = x + graph
156
+
157
+ ###
158
+ x = self.semantic(x)
159
+ x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)
160
+
161
+ return x
162
+
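# Editor's sketch (not part of the original file): the forward() above projects
# the decoder features onto n_classes graph nodes (Featuremaps_to_Graph), runs
# three GraphConvolution layers under the target adjacency adj1_target, maps
# the nodes back to a feature map, and fuses the result with a skip branch
# before the semantic head. A hedged way to drive the class is shown below;
# the uniform 7x7 adjacency is a placeholder, since the real adjacency matrices
# come from the dataset-specific graphs defined elsewhere in this repository,
# and their exact shape/normalisation here is an assumption.
import torch
from networks.deeplab_xception_transfer import deeplab_xception_transfer_basemodel

model = deeplab_xception_transfer_basemodel(n_classes=7)
model.eval()
adj = torch.ones(7, 7) / 7.0           # placeholder adjacency over the 7 target nodes
image = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    out = model(image, adj1_target=adj)
print(out.size())                       # expected: torch.Size([1, 7, 512, 512])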
163
+ class deeplab_xception_transfer_basemodel_savememory(deeplab_xception.DeepLabv3_plus):
164
+ def __init__(self,nInputChannels=3, n_classes=7, os=16,input_channels=256,hidden_layers=128,out_channels=256):
165
+ super(deeplab_xception_transfer_basemodel_savememory, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes,
166
+ os=os,)
167
+ ### source graph
168
+
169
+ ### target graph
170
+ self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
171
+ nodes=n_classes)
172
+ self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
173
+ self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
174
+ self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
175
+
176
+ self.target_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels, output_channels=out_channels,
177
+ hidden_layers=hidden_layers, nodes=n_classes
178
+ )
179
+ self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
180
+ nn.ReLU(True)])
181
+
182
+ def load_source_model(self,state_dict):
183
+ own_state = self.state_dict()
184
+ # for name inshop_cos own_state:
185
+ # print name
186
+ new_state_dict = OrderedDict()
187
+ for name, param in state_dict.items():
188
+ name = name.replace('module.', '')
189
+ if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_graph' not in name and 'transpose_graph' not in name:
190
+ if 'featuremap_2_graph' in name:
191
+ name = name.replace('featuremap_2_graph','source_featuremap_2_graph')
192
+ else:
193
+ name = name.replace('graph','source_graph')
194
+ new_state_dict[name] = 0
195
+ if name not in own_state:
196
+ if 'num_batch' in name:
197
+ continue
198
+ print('unexpected key "{}" in state_dict'
199
+ .format(name))
200
+ continue
201
+ # if isinstance(param, own_state):
202
+ if isinstance(param, Parameter):
203
+ # backwards compatibility for serialized parameters
204
+ param = param.data
205
+ try:
206
+ own_state[name].copy_(param)
207
+ except:
208
+ print('While copying the parameter named {}, whose dimensions in the model are'
209
+ ' {} and whose dimensions in the checkpoint are {}, ...'.format(
210
+ name, own_state[name].size(), param.size()))
211
+ continue # added 2018/02/01
212
+ own_state[name].copy_(param)
213
+ # print 'copying %s' %name
214
+
215
+ missing = set(own_state.keys()) - set(new_state_dict.keys())
216
+ if len(missing) > 0:
217
+ print('missing keys in state_dict: "{}"'.format(missing))
218
+
219
+ def get_target_parameter(self):
220
+ l = []
221
+ other = []
222
+ for name, k in self.named_parameters():
223
+ if 'target' in name or 'semantic' in name:
224
+ l.append(k)
225
+ else:
226
+ other.append(k)
227
+ return l, other
228
+
229
+ def get_semantic_parameter(self):
230
+ l = []
231
+ for name, k in self.named_parameters():
232
+ if 'semantic' in name:
233
+ l.append(k)
234
+ return l
235
+
236
+ def get_source_parameter(self):
237
+ l = []
238
+ for name, k in self.named_parameters():
239
+ if 'source' in name:
240
+ l.append(k)
241
+ return l
242
+
243
+ def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ):
244
+ x, low_level_features = self.xception_features(input)
245
+ # print(x.size())
246
+ x1 = self.aspp1(x)
247
+ x2 = self.aspp2(x)
248
+ x3 = self.aspp3(x)
249
+ x4 = self.aspp4(x)
250
+ x5 = self.global_avg_pool(x)
251
+ x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
252
+
253
+ x = torch.cat((x1, x2, x3, x4, x5), dim=1)
254
+
255
+ x = self.concat_projection_conv1(x)
256
+ x = self.concat_projection_bn1(x)
257
+ x = self.relu(x)
258
+ # print(x.size())
259
+ x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
260
+
261
+ low_level_features = self.feature_projection_conv1(low_level_features)
262
+ low_level_features = self.feature_projection_bn1(low_level_features)
263
+ low_level_features = self.relu(low_level_features)
264
+ # print(low_level_features.size())
265
+ # print(x.size())
266
+ x = torch.cat((x, low_level_features), dim=1)
267
+ x = self.decoder(x)
268
+
269
+ ### add graph
270
+
271
+
272
+ # target graph
273
+ # print('x size',x.size(),adj1.size())
274
+ graph = self.target_featuremap_2_graph(x)
275
+
276
+ # graph combine
277
+ # print(graph.size(),source_2_target_graph.size())
278
+ # graph = self.fc_graph.forward(graph,relu=True)
279
+ # print(graph.size())
280
+
281
+ graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
282
+ graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
283
+ graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
284
+ # print(graph.size(),x.size())
285
+ # graph = self.gcn_encode.forward(graph,relu=True)
286
+ # graph = self.graph_conv2.forward(graph,adj=adj2,relu=True)
287
+ # graph = self.gcn_decode.forward(graph,relu=True)
288
+ graph = self.target_graph_2_fea.forward(graph, x)
289
+ x = self.target_skip_conv(x)
290
+ x = x + graph
291
+
292
+ ###
293
+ x = self.semantic(x)
294
+ x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)
295
+
296
+ return x
297
+
298
+ class deeplab_xception_transfer_basemodel_synBN(deeplab_xception_synBN.DeepLabv3_plus):
299
+ def __init__(self,nInputChannels=3, n_classes=7, os=16,input_channels=256,hidden_layers=128,out_channels=256):
300
+ super(deeplab_xception_transfer_basemodel_synBN, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes,
301
+ os=os,)
302
+ ### source graph
303
+ # self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
304
+ # nodes=n_classes)
305
+ # self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
306
+ # self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
307
+ # self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
308
+ #
309
+ # self.source_graph_2_fea = gcn.Graph_to_Featuremaps(input_channels=input_channels, output_channels=out_channels,
310
+ # hidden_layers=hidden_layers, nodes=n_classes
311
+ # )
312
+ # self.source_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
313
+ # nn.ReLU(True)])
314
+
315
+ ### target graph
316
+ self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
317
+ nodes=n_classes)
318
+ self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
319
+ self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
320
+ self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
321
+
322
+ self.target_graph_2_fea = gcn.Graph_to_Featuremaps(input_channels=input_channels, output_channels=out_channels,
323
+ hidden_layers=hidden_layers, nodes=n_classes
324
+ )
325
+ self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
326
+ nn.ReLU(True)])
327
+
328
+ def load_source_model(self,state_dict):
329
+ own_state = self.state_dict()
330
+ # for name in own_state:
331
+ # print name
332
+ new_state_dict = OrderedDict()
333
+ for name, param in state_dict.items():
334
+ name = name.replace('module.', '')
335
+
336
+ if 'graph' in name and 'source' not in name and 'target' not in name:
337
+ if 'featuremap_2_graph' in name:
338
+ name = name.replace('featuremap_2_graph','source_featuremap_2_graph')
339
+ else:
340
+ name = name.replace('graph','source_graph')
341
+ new_state_dict[name] = 0
342
+ if name not in own_state:
343
+ if 'num_batch' in name:
344
+ continue
345
+ print('unexpected key "{}" in state_dict'
346
+ .format(name))
347
+ continue
348
+ # if isinstance(param, own_state):
349
+ if isinstance(param, Parameter):
350
+ # backwards compatibility for serialized parameters
351
+ param = param.data
352
+ try:
353
+ own_state[name].copy_(param)
354
+ except:
355
+ print('While copying the parameter named {}, whose dimensions in the model are'
356
+ ' {} and whose dimensions in the checkpoint are {}, ...'.format(
357
+ name, own_state[name].size(), param.size()))
358
+ continue # added 2018/02/01
359
+ own_state[name].copy_(param)
360
+ # print 'copying %s' %name
361
+
362
+ missing = set(own_state.keys()) - set(new_state_dict.keys())
363
+ if len(missing) > 0:
364
+ print('missing keys in state_dict: "{}"'.format(missing))
365
+
366
+ def get_target_parameter(self):
367
+ l = []
368
+ other = []
369
+ for name, k in self.named_parameters():
370
+ if 'target' in name or 'semantic' in name:
371
+ l.append(k)
372
+ else:
373
+ other.append(k)
374
+ return l, other
375
+
376
+ def get_semantic_parameter(self):
377
+ l = []
378
+ for name, k in self.named_parameters():
379
+ if 'semantic' in name:
380
+ l.append(k)
381
+ return l
382
+
383
+ def get_source_parameter(self):
384
+ l = []
385
+ for name, k in self.named_parameters():
386
+ if 'source' in name:
387
+ l.append(k)
388
+ return l
389
+
390
+ def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ):
391
+ x, low_level_features = self.xception_features(input)
392
+ # print(x.size())
393
+ x1 = self.aspp1(x)
394
+ x2 = self.aspp2(x)
395
+ x3 = self.aspp3(x)
396
+ x4 = self.aspp4(x)
397
+ x5 = self.global_avg_pool(x)
398
+ x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
399
+
400
+ x = torch.cat((x1, x2, x3, x4, x5), dim=1)
401
+
402
+ x = self.concat_projection_conv1(x)
403
+ x = self.concat_projection_bn1(x)
404
+ x = self.relu(x)
405
+ # print(x.size())
406
+ x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
407
+
408
+ low_level_features = self.feature_projection_conv1(low_level_features)
409
+ low_level_features = self.feature_projection_bn1(low_level_features)
410
+ low_level_features = self.relu(low_level_features)
411
+ # print(low_level_features.size())
412
+ # print(x.size())
413
+ x = torch.cat((x, low_level_features), dim=1)
414
+ x = self.decoder(x)
415
+
416
+ ### add graph
417
+
418
+
419
+ # target graph
420
+ # print('x size',x.size(),adj1.size())
421
+ graph = self.target_featuremap_2_graph(x)
422
+
423
+ # graph combine
424
+ # print(graph.size(),source_2_target_graph.size())
425
+ # graph = self.fc_graph.forward(graph,relu=True)
426
+ # print(graph.size())
427
+
428
+ graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
429
+ graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
430
+ graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
431
+ # print(graph.size(),x.size())
432
+ # graph = self.gcn_encode.forward(graph,relu=True)
433
+ # graph = self.graph_conv2.forward(graph,adj=adj2,relu=True)
434
+ # graph = self.gcn_decode.forward(graph,relu=True)
435
+ graph = self.target_graph_2_fea.forward(graph, x)
436
+ x = self.target_skip_conv(x)
437
+ x = x + graph
438
+
439
+ ###
440
+ x = self.semantic(x)
441
+ x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)
442
+
443
+ return x
444
+
445
+ class deeplab_xception_transfer_basemodel_synBN_savememory(deeplab_xception_synBN.DeepLabv3_plus):
446
+ def __init__(self,nInputChannels=3, n_classes=7, os=16,input_channels=256,hidden_layers=128,out_channels=256):
447
+ super(deeplab_xception_transfer_basemodel_synBN_savememory, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes,
448
+ os=os, )
449
+ ### source graph
450
+ # self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
451
+ # nodes=n_classes)
452
+ # self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
453
+ # self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
454
+ # self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
455
+ #
456
+ # self.source_graph_2_fea = gcn.Graph_to_Featuremaps(input_channels=input_channels, output_channels=out_channels,
457
+ # hidden_layers=hidden_layers, nodes=n_classes
458
+ # )
459
+ # self.source_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
460
+ # nn.ReLU(True)])
461
+
462
+ ### target graph
463
+ self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
464
+ nodes=n_classes)
465
+ self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
466
+ self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
467
+ self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
468
+
469
+ self.target_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels, output_channels=out_channels,
470
+ hidden_layers=hidden_layers, nodes=n_classes
471
+ )
472
+ self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
473
+ nn.BatchNorm2d(input_channels),
474
+ nn.ReLU(True)])
475
+
476
+ def load_source_model(self,state_dict):
477
+ own_state = self.state_dict()
478
+ # for name in own_state:
479
+ # print name
480
+ new_state_dict = OrderedDict()
481
+ for name, param in state_dict.items():
482
+ name = name.replace('module.', '')
483
+
484
+ if 'graph' in name and 'source' not in name and 'target' not in name:
485
+ if 'featuremap_2_graph' in name:
486
+ name = name.replace('featuremap_2_graph','source_featuremap_2_graph')
487
+ else:
488
+ name = name.replace('graph','source_graph')
489
+ new_state_dict[name] = 0
490
+ if name not in own_state:
491
+ if 'num_batch' in name:
492
+ continue
493
+ print('unexpected key "{}" in state_dict'
494
+ .format(name))
495
+ continue
496
+ # if isinstance(param, own_state):
497
+ if isinstance(param, Parameter):
498
+ # backwards compatibility for serialized parameters
499
+ param = param.data
500
+ try:
501
+ own_state[name].copy_(param)
502
+ except:
503
+ print('While copying the parameter named {}, whose dimensions in the model are'
504
+ ' {} and whose dimensions in the checkpoint are {}, ...'.format(
505
+ name, own_state[name].size(), param.size()))
506
+ continue # added 2018/02/01
507
+ own_state[name].copy_(param)
508
+ # print 'copying %s' %name
509
+
510
+ missing = set(own_state.keys()) - set(new_state_dict.keys())
511
+ if len(missing) > 0:
512
+ print('missing keys in state_dict: "{}"'.format(missing))
513
+
514
+ def get_target_parameter(self):
515
+ l = []
516
+ other = []
517
+ for name, k in self.named_parameters():
518
+ if 'target' in name or 'semantic' in name:
519
+ l.append(k)
520
+ else:
521
+ other.append(k)
522
+ return l, other
523
+
524
+ def get_semantic_parameter(self):
525
+ l = []
526
+ for name, k in self.named_parameters():
527
+ if 'semantic' in name:
528
+ l.append(k)
529
+ return l
530
+
531
+ def get_source_parameter(self):
532
+ l = []
533
+ for name, k in self.named_parameters():
534
+ if 'source' in name:
535
+ l.append(k)
536
+ return l
537
+
538
+ def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ):
539
+ x, low_level_features = self.xception_features(input)
540
+ # print(x.size())
541
+ x1 = self.aspp1(x)
542
+ x2 = self.aspp2(x)
543
+ x3 = self.aspp3(x)
544
+ x4 = self.aspp4(x)
545
+ x5 = self.global_avg_pool(x)
546
+ x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
547
+
548
+ x = torch.cat((x1, x2, x3, x4, x5), dim=1)
549
+
550
+ x = self.concat_projection_conv1(x)
551
+ x = self.concat_projection_bn1(x)
552
+ x = self.relu(x)
553
+ # print(x.size())
554
+ x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
555
+
556
+ low_level_features = self.feature_projection_conv1(low_level_features)
557
+ low_level_features = self.feature_projection_bn1(low_level_features)
558
+ low_level_features = self.relu(low_level_features)
559
+ # print(low_level_features.size())
560
+ # print(x.size())
561
+ x = torch.cat((x, low_level_features), dim=1)
562
+ x = self.decoder(x)
563
+
564
+ ### add graph
565
+
566
+
567
+ # target graph
568
+ # print('x size',x.size(),adj1.size())
569
+ graph = self.target_featuremap_2_graph(x)
570
+
571
+ # graph combine
572
+ # print(graph.size(),source_2_target_graph.size())
573
+ # graph = self.fc_graph.forward(graph,relu=True)
574
+ # print(graph.size())
575
+
576
+ graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
577
+ graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
578
+ graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
579
+ # print(graph.size(),x.size())
580
+ # graph = self.gcn_encode.forward(graph,relu=True)
581
+ # graph = self.graph_conv2.forward(graph,adj=adj2,relu=True)
582
+ # graph = self.gcn_decode.forward(graph,relu=True)
583
+ graph = self.target_graph_2_fea.forward(graph, x)
584
+ x = self.target_skip_conv(x)
585
+ x = x + graph
586
+
587
+ ###
588
+ x = self.semantic(x)
589
+ x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)
590
+
591
+ return x
592
+
593
+ #######################
594
+ # transfer model
595
+ #######################
596
+
597
+ class deeplab_xception_transfer_projection(deeplab_xception_transfer_basemodel):
598
+ def __init__(self, nInputChannels=3, n_classes=7, os=16,input_channels=256,hidden_layers=128,out_channels=256,
599
+ transfer_graph=None, source_classes=20):
600
+ super(deeplab_xception_transfer_projection, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes,
601
+ os=os, input_channels=input_channels,
602
+ hidden_layers=hidden_layers, out_channels=out_channels, )
603
+ self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
604
+ nodes=source_classes)
605
+ self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
606
+ self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
607
+ self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
608
+ self.transpose_graph = gcn.Graph_trans(in_features=hidden_layers,out_features=hidden_layers,adj=transfer_graph,
609
+ begin_nodes=source_classes,end_nodes=n_classes)
610
+ self.fc_graph = gcn.GraphConvolution(hidden_layers*3, hidden_layers)
611
+
612
+ def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ):
613
+ x, low_level_features = self.xception_features(input)
614
+ # print(x.size())
615
+ x1 = self.aspp1(x)
616
+ x2 = self.aspp2(x)
617
+ x3 = self.aspp3(x)
618
+ x4 = self.aspp4(x)
619
+ x5 = self.global_avg_pool(x)
620
+ x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
621
+
622
+ x = torch.cat((x1, x2, x3, x4, x5), dim=1)
623
+
624
+ x = self.concat_projection_conv1(x)
625
+ x = self.concat_projection_bn1(x)
626
+ x = self.relu(x)
627
+ # print(x.size())
628
+ x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
629
+
630
+ low_level_features = self.feature_projection_conv1(low_level_features)
631
+ low_level_features = self.feature_projection_bn1(low_level_features)
632
+ low_level_features = self.relu(low_level_features)
633
+ # print(low_level_features.size())
634
+ # print(x.size())
635
+ x = torch.cat((x, low_level_features), dim=1)
636
+ x = self.decoder(x)
637
+
638
+ ### add graph
639
+ # source graph
640
+ source_graph = self.source_featuremap_2_graph(x)
641
+ source_graph1 = self.source_graph_conv1.forward(source_graph,adj=adj2_source, relu=True)
642
+ source_graph2 = self.source_graph_conv2.forward(source_graph1, adj=adj2_source, relu=True)
643
+ source_graph3 = self.source_graph_conv2.forward(source_graph2, adj=adj2_source, relu=True)
644
+
645
+ source_2_target_graph1_v5 = self.transpose_graph.forward(source_graph1, adj=adj3_transfer, relu=True)
646
+ source_2_target_graph2_v5 = self.transpose_graph.forward(source_graph2, adj=adj3_transfer, relu=True)
647
+ source_2_target_graph3_v5 = self.transpose_graph.forward(source_graph3, adj=adj3_transfer, relu=True)
648
+
649
+ # target graph
650
+ # print('x size',x.size(),adj1.size())
651
+ graph = self.target_featuremap_2_graph(x)
652
+
653
+ source_2_target_graph1 = self.similarity_trans(source_graph1, graph)
654
+ # graph combine 1
655
+ # print(graph.size())
656
+ # print(source_2_target_graph1.size())
657
+ # print(source_2_target_graph1_v5.size())
658
+ graph = torch.cat((graph,source_2_target_graph1.squeeze(0), source_2_target_graph1_v5.squeeze(0)),dim=-1)
659
+ graph = self.fc_graph.forward(graph,relu=True)
660
+
661
+ graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
662
+
663
+ source_2_target_graph2 = self.similarity_trans(source_graph2, graph)
664
+ # graph combine 2
665
+ graph = torch.cat((graph, source_2_target_graph2, source_2_target_graph2_v5), dim=-1)
666
+ graph = self.fc_graph.forward(graph, relu=True)
667
+
668
+ graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
669
+
670
+ source_2_target_graph3 = self.similarity_trans(source_graph3, graph)
671
+ # graph combine 3
672
+ graph = torch.cat((graph, source_2_target_graph3, source_2_target_graph3_v5), dim=-1)
673
+ graph = self.fc_graph.forward(graph, relu=True)
674
+
675
+ graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
676
+
677
+ # print(graph.size(),x.size())
678
+
679
+ graph = self.target_graph_2_fea.forward(graph, x)
680
+ x = self.target_skip_conv(x)
681
+ x = x + graph
682
+
683
+ ###
684
+ x = self.semantic(x)
685
+ x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)
686
+
687
+ return x
688
+
689
+ def similarity_trans(self,source,target):
690
+ sim = torch.matmul(F.normalize(target, p=2, dim=-1), F.normalize(source, p=2, dim=-1).transpose(-1, -2))
691
+ sim = F.softmax(sim, dim=-1)
692
+ return torch.matmul(sim, source)
693
+
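# Editor's sketch (not part of the original file): similarity_trans() above
# computes a row-normalised similarity between target and source graph nodes
# (L2-normalised dot products followed by a softmax over the source nodes) and
# uses it to pull source node features onto the target graph. The same
# computation on toy tensors, with node counts and feature width chosen
# arbitrarily for illustration:
import torch
import torch.nn.functional as F

source = torch.randn(1, 20, 128)   # 20 source-graph nodes, 128-d features
target = torch.randn(1, 7, 128)    # 7 target-graph nodes
sim = torch.matmul(F.normalize(target, p=2, dim=-1),
                   F.normalize(source, p=2, dim=-1).transpose(-1, -2))  # (1, 7, 20)
sim = F.softmax(sim, dim=-1)           # each target node's weights over source nodes sum to 1
projected = torch.matmul(sim, source)  # (1, 7, 128): source features mapped onto target nodes
print(projected.size())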
694
+ def load_source_model(self,state_dict):
695
+ own_state = self.state_dict()
696
+ # for name in own_state:
697
+ # print name
698
+ new_state_dict = OrderedDict()
699
+ for name, param in state_dict.items():
700
+ name = name.replace('module.', '')
701
+
702
+ if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_' not in name and 'transpose_graph' not in name:
703
+ if 'featuremap_2_graph' in name:
704
+ name = name.replace('featuremap_2_graph','source_featuremap_2_graph')
705
+ else:
706
+ name = name.replace('graph','source_graph')
707
+ new_state_dict[name] = 0
708
+ if name not in own_state:
709
+ if 'num_batch' in name:
710
+ continue
711
+ print('unexpected key "{}" in state_dict'
712
+ .format(name))
713
+ continue
714
+ # if isinstance(param, own_state):
715
+ if isinstance(param, Parameter):
716
+ # backwards compatibility for serialized parameters
717
+ param = param.data
718
+ try:
719
+ own_state[name].copy_(param)
720
+ except:
721
+ print('While copying the parameter named {}, whose dimensions in the model are'
722
+ ' {} and whose dimensions in the checkpoint are {}, ...'.format(
723
+ name, own_state[name].size(), param.size()))
724
+ continue # i add inshop_cos 2018/02/01
725
+ own_state[name].copy_(param)
726
+ # print 'copying %s' %name
727
+
728
+ missing = set(own_state.keys()) - set(new_state_dict.keys())
729
+ if len(missing) > 0:
730
+ print('missing keys in state_dict: "{}"'.format(missing))
731
+
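load_source_model above re-keys a pretrained single-branch checkpoint so that its plain `graph` modules populate the `source_*` branch of the transfer model before the tensors are copied into the model's own state_dict. A simplified, self-contained sketch of just that renaming rule (`remap_source_keys` is an illustrative helper, not part of this repository):

```python
from collections import OrderedDict

import torch


def remap_source_keys(state_dict):
    # Illustrative only: mirrors the renaming rule used by load_source_model above.
    # Keys from a single-branch checkpoint ('featuremap_2_graph...', 'graph...')
    # are mapped onto the 'source_*' branch of the transfer model.
    remapped = OrderedDict()
    for name, param in state_dict.items():
        name = name.replace('module.', '')  # strip a possible DataParallel prefix
        if ('graph' in name and 'source' not in name and 'target' not in name
                and 'fc_' not in name and 'transpose_graph' not in name):
            if 'featuremap_2_graph' in name:
                name = name.replace('featuremap_2_graph', 'source_featuremap_2_graph')
            else:
                name = name.replace('graph', 'source_graph')
        remapped[name] = param
    return remapped


# Example with dummy keys
dummy = OrderedDict([
    ('module.featuremap_2_graph.weight', torch.zeros(1)),
    ('module.graph_conv1.weight', torch.zeros(1)),
    ('module.xception_features.conv1.weight', torch.zeros(1)),
])
print(list(remap_source_keys(dummy).keys()))
# ['source_featuremap_2_graph.weight', 'source_graph_conv1.weight', 'xception_features.conv1.weight']
```

The real method additionally skips keys missing from the model and copies the surviving tensors in place; the sketch only shows the renaming step.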
732
+ class deeplab_xception_transfer_projection_savemem(deeplab_xception_transfer_basemodel_savememory):
733
+ def __init__(self, nInputChannels=3, n_classes=7, os=16,input_channels=256,hidden_layers=128,out_channels=256,
734
+ transfer_graph=None, source_classes=20):
735
+ super(deeplab_xception_transfer_projection_savemem, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes,
736
+ os=os, input_channels=input_channels,
737
+ hidden_layers=hidden_layers, out_channels=out_channels, )
738
+ self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
739
+ nodes=source_classes)
740
+ self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
741
+ self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
742
+ self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
743
+ self.transpose_graph = gcn.Graph_trans(in_features=hidden_layers,out_features=hidden_layers,adj=transfer_graph,
744
+ begin_nodes=source_classes,end_nodes=n_classes)
745
+ self.fc_graph = gcn.GraphConvolution(hidden_layers*3, hidden_layers)
746
+
747
+ def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ):
748
+ x, low_level_features = self.xception_features(input)
749
+ # print(x.size())
750
+ x1 = self.aspp1(x)
751
+ x2 = self.aspp2(x)
752
+ x3 = self.aspp3(x)
753
+ x4 = self.aspp4(x)
754
+ x5 = self.global_avg_pool(x)
755
+ x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
756
+
757
+ x = torch.cat((x1, x2, x3, x4, x5), dim=1)
758
+
759
+ x = self.concat_projection_conv1(x)
760
+ x = self.concat_projection_bn1(x)
761
+ x = self.relu(x)
762
+ # print(x.size())
763
+ x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
764
+
765
+ low_level_features = self.feature_projection_conv1(low_level_features)
766
+ low_level_features = self.feature_projection_bn1(low_level_features)
767
+ low_level_features = self.relu(low_level_features)
768
+ # print(low_level_features.size())
769
+ # print(x.size())
770
+ x = torch.cat((x, low_level_features), dim=1)
771
+ x = self.decoder(x)
772
+
773
+ ### add graph
774
+ # source graph
775
+ source_graph = self.source_featuremap_2_graph(x)
776
+ source_graph1 = self.source_graph_conv1.forward(source_graph,adj=adj2_source, relu=True)
777
+ source_graph2 = self.source_graph_conv2.forward(source_graph1, adj=adj2_source, relu=True)
778
+ source_graph3 = self.source_graph_conv2.forward(source_graph2, adj=adj2_source, relu=True)
779
+
780
+ source_2_target_graph1_v5 = self.transpose_graph.forward(source_graph1, adj=adj3_transfer, relu=True)
781
+ source_2_target_graph2_v5 = self.transpose_graph.forward(source_graph2, adj=adj3_transfer, relu=True)
782
+ source_2_target_graph3_v5 = self.transpose_graph.forward(source_graph3, adj=adj3_transfer, relu=True)
783
+
784
+ # target graph
785
+ # print('x size',x.size(),adj1.size())
786
+ graph = self.target_featuremap_2_graph(x)
787
+
788
+ source_2_target_graph1 = self.similarity_trans(source_graph1, graph)
789
+ # graph combine 1
790
+ graph = torch.cat((graph,source_2_target_graph1.squeeze(0), source_2_target_graph1_v5.squeeze(0)),dim=-1)
791
+ graph = self.fc_graph.forward(graph,relu=True)
792
+
793
+ graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
794
+
795
+ source_2_target_graph2 = self.similarity_trans(source_graph2, graph)
796
+ # graph combine 2
797
+ graph = torch.cat((graph, source_2_target_graph2, source_2_target_graph2_v5), dim=-1)
798
+ graph = self.fc_graph.forward(graph, relu=True)
799
+
800
+ graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
801
+
802
+ source_2_target_graph3 = self.similarity_trans(source_graph3, graph)
803
+ # graph combine 3
804
+ graph = torch.cat((graph, source_2_target_graph3, source_2_target_graph3_v5), dim=-1)
805
+ graph = self.fc_graph.forward(graph, relu=True)
806
+
807
+ graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
808
+
809
+ # print(graph.size(),x.size())
810
+
811
+ graph = self.target_graph_2_fea.forward(graph, x)
812
+ x = self.target_skip_conv(x)
813
+ x = x + graph
814
+
815
+ ###
816
+ x = self.semantic(x)
817
+ x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)
818
+
819
+ return x
820
+
821
+ def similarity_trans(self,source,target):
822
+ sim = torch.matmul(F.normalize(target, p=2, dim=-1), F.normalize(source, p=2, dim=-1).transpose(-1, -2))
823
+ sim = F.softmax(sim, dim=-1)
824
+ return torch.matmul(sim, source)
825
+
826
+ def load_source_model(self,state_dict):
827
+ own_state = self.state_dict()
828
+ # for name in own_state:
829
+ # print name
830
+ new_state_dict = OrderedDict()
831
+ for name, param in state_dict.items():
832
+ name = name.replace('module.', '')
833
+
834
+ if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_' not in name and 'transpose_graph' not in name:
835
+ if 'featuremap_2_graph' in name:
836
+ name = name.replace('featuremap_2_graph','source_featuremap_2_graph')
837
+ else:
838
+ name = name.replace('graph','source_graph')
839
+ new_state_dict[name] = 0
840
+ if name not in own_state:
841
+ if 'num_batch' in name:
842
+ continue
843
+ print('unexpected key "{}" in state_dict'
844
+ .format(name))
845
+ continue
846
+ # if isinstance(param, own_state):
847
+ if isinstance(param, Parameter):
848
+ # backwards compatibility for serialized parameters
849
+ param = param.data
850
+ try:
851
+ own_state[name].copy_(param)
852
+ except:
853
+ print('While copying the parameter named {}, whose dimensions in the model are'
854
+ ' {} and whose dimensions in the checkpoint are {}, ...'.format(
855
+ name, own_state[name].size(), param.size()))
856
+ continue # i add inshop_cos 2018/02/01
857
+ own_state[name].copy_(param)
858
+ # print 'copying %s' %name
859
+
860
+ missing = set(own_state.keys()) - set(new_state_dict.keys())
861
+ if len(missing) > 0:
862
+ print('missing keys in state_dict: "{}"'.format(missing))
863
+
864
+
865
+ class deeplab_xception_transfer_projection_synBN_savemem(deeplab_xception_transfer_basemodel_synBN_savememory):
866
+ def __init__(self, nInputChannels=3, n_classes=7, os=16,input_channels=256,hidden_layers=128,out_channels=256,
867
+ transfer_graph=None, source_classes=20):
868
+ super(deeplab_xception_transfer_projection_synBN_savemem, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes,
869
+ os=os, input_channels=input_channels,
870
+ hidden_layers=hidden_layers, out_channels=out_channels, )
871
+ self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
872
+ nodes=source_classes)
873
+ self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
874
+ self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
875
+ self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
876
+ self.transpose_graph = gcn.Graph_trans(in_features=hidden_layers,out_features=hidden_layers,adj=transfer_graph,
877
+ begin_nodes=source_classes,end_nodes=n_classes)
878
+ self.fc_graph = gcn.GraphConvolution(hidden_layers*3 ,hidden_layers)
879
+
880
+ def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ):
881
+ x, low_level_features = self.xception_features(input)
882
+ # print(x.size())
883
+ x1 = self.aspp1(x)
884
+ x2 = self.aspp2(x)
885
+ x3 = self.aspp3(x)
886
+ x4 = self.aspp4(x)
887
+ x5 = self.global_avg_pool(x)
888
+ x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
889
+
890
+ x = torch.cat((x1, x2, x3, x4, x5), dim=1)
891
+
892
+ x = self.concat_projection_conv1(x)
893
+ x = self.concat_projection_bn1(x)
894
+ x = self.relu(x)
895
+ # print(x.size())
896
+ x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
897
+
898
+ low_level_features = self.feature_projection_conv1(low_level_features)
899
+ low_level_features = self.feature_projection_bn1(low_level_features)
900
+ low_level_features = self.relu(low_level_features)
901
+ # print(low_level_features.size())
902
+ # print(x.size())
903
+ x = torch.cat((x, low_level_features), dim=1)
904
+ x = self.decoder(x)
905
+
906
+ ### add graph
907
+ # source graph
908
+ source_graph = self.source_featuremap_2_graph(x)
909
+ source_graph1 = self.source_graph_conv1.forward(source_graph,adj=adj2_source, relu=True)
910
+ source_graph2 = self.source_graph_conv2.forward(source_graph1, adj=adj2_source, relu=True)
911
+ source_graph3 = self.source_graph_conv2.forward(source_graph2, adj=adj2_source, relu=True)
912
+
913
+ source_2_target_graph1_v5 = self.transpose_graph.forward(source_graph1, adj=adj3_transfer, relu=True)
914
+ source_2_target_graph2_v5 = self.transpose_graph.forward(source_graph2, adj=adj3_transfer, relu=True)
915
+ source_2_target_graph3_v5 = self.transpose_graph.forward(source_graph3, adj=adj3_transfer, relu=True)
916
+
917
+ # target graph
918
+ # print('x size',x.size(),adj1.size())
919
+ graph = self.target_featuremap_2_graph(x)
920
+
921
+ source_2_target_graph1 = self.similarity_trans(source_graph1, graph)
922
+ # graph combine 1
923
+ graph = torch.cat((graph,source_2_target_graph1.squeeze(0), source_2_target_graph1_v5.squeeze(0)),dim=-1)
924
+ graph = self.fc_graph.forward(graph,relu=True)
925
+
926
+ graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
927
+
928
+ source_2_target_graph2 = self.similarity_trans(source_graph2, graph)
929
+ # graph combine 2
930
+ graph = torch.cat((graph, source_2_target_graph2, source_2_target_graph2_v5), dim=-1)
931
+ graph = self.fc_graph.forward(graph, relu=True)
932
+
933
+ graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
934
+
935
+ source_2_target_graph3 = self.similarity_trans(source_graph3, graph)
936
+ # graph combine 3
937
+ graph = torch.cat((graph, source_2_target_graph3, source_2_target_graph3_v5), dim=-1)
938
+ graph = self.fc_graph.forward(graph, relu=True)
939
+
940
+ graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
941
+
942
+ # print(graph.size(),x.size())
943
+
944
+ graph = self.target_graph_2_fea.forward(graph, x)
945
+ x = self.target_skip_conv(x)
946
+ x = x + graph
947
+
948
+ ###
949
+ x = self.semantic(x)
950
+ x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)
951
+
952
+ return x
953
+
954
+ def similarity_trans(self,source,target):
955
+ sim = torch.matmul(F.normalize(target, p=2, dim=-1), F.normalize(source, p=2, dim=-1).transpose(-1, -2))
956
+ sim = F.softmax(sim, dim=-1)
957
+ return torch.matmul(sim, source)
958
+
959
+ def load_source_model(self,state_dict):
960
+ own_state = self.state_dict()
961
+ # for name in own_state:
962
+ # print name
963
+ new_state_dict = OrderedDict()
964
+ for name, param in state_dict.items():
965
+ name = name.replace('module.', '')
966
+
967
+ if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_' not in name and 'transpose_graph' not in name:
968
+ if 'featuremap_2_graph' in name:
969
+ name = name.replace('featuremap_2_graph','source_featuremap_2_graph')
970
+ else:
971
+ name = name.replace('graph','source_graph')
972
+ new_state_dict[name] = 0
973
+ if name not in own_state:
974
+ if 'num_batch' in name:
975
+ continue
976
+ print('unexpected key "{}" in state_dict'
977
+ .format(name))
978
+ continue
979
+ # if isinstance(param, own_state):
980
+ if isinstance(param, Parameter):
981
+ # backwards compatibility for serialized parameters
982
+ param = param.data
983
+ try:
984
+ own_state[name].copy_(param)
985
+ except:
986
+ print('While copying the parameter named {}, whose dimensions in the model are'
987
+ ' {} and whose dimensions in the checkpoint are {}, ...'.format(
988
+ name, own_state[name].size(), param.size()))
989
+ continue # i add inshop_cos 2018/02/01
990
+ own_state[name].copy_(param)
991
+ # print 'copying %s' %name
992
+
993
+ missing = set(own_state.keys()) - set(new_state_dict.keys())
994
+ if len(missing) > 0:
995
+ print('missing keys in state_dict: "{}"'.format(missing))
996
+
997
+
998
+ # if __name__ == '__main__':
999
+ # net = deeplab_xception_transfer_projection_v3v5_more_savemem()
1000
+ # img = torch.rand((2,3,128,128))
1001
+ # net.eval()
1002
+ # a = torch.rand((1,1,7,7))
1003
+ # net.forward(img, adj1_target=a)
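similarity_trans above is a soft cross-graph projection: target node features attend over L2-normalized source node features and pull back a convex combination of them, after which fc_graph reduces the three concatenated views (hidden_layers*3) back to hidden_layers. A self-contained sketch with dummy tensors; the nn.Linear stand-in for gcn.GraphConvolution and the shapes are assumptions for illustration only:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def similarity_trans(source, target):
    # Cosine-similarity attention: rows index target nodes, columns index source nodes.
    sim = torch.matmul(F.normalize(target, p=2, dim=-1),
                       F.normalize(source, p=2, dim=-1).transpose(-1, -2))
    sim = F.softmax(sim, dim=-1)
    return torch.matmul(sim, source)  # each target node becomes a mixture of source nodes


hidden = 128                                   # hidden_layers default above
source_nodes, target_nodes = 20, 7             # source_classes=20, n_classes=7 defaults above
source_graph = torch.rand(1, source_nodes, hidden)
target_graph = torch.rand(1, target_nodes, hidden)

projected = similarity_trans(source_graph, target_graph)   # (1, 7, 128)
transferred = torch.rand(1, target_nodes, hidden)           # stand-in for the Graph_trans (_v5) output

# Three-way fusion reduced back to `hidden`, as fc_graph does (plain Linear as a stand-in here).
fuse = nn.Linear(hidden * 3, hidden)
fused = F.relu(fuse(torch.cat((target_graph, projected, transferred), dim=-1)))
print(fused.shape)  # torch.Size([1, 7, 128])
```

The same projection/concatenation/reduction pattern repeats after each of the three target graph convolutions in the forward passes above.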
TryYours-Virtual-Try-On/Graphonomy-master/networks/deeplab_xception_universal.py ADDED
@@ -0,0 +1,1077 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from collections import OrderedDict
7
+ from torch.nn import Parameter
8
+ from networks import deeplab_xception, gcn, deeplab_xception_synBN
9
+
10
+
11
+
12
+ class deeplab_xception_transfer_basemodel_savememory(deeplab_xception.DeepLabv3_plus):
13
+ def __init__(self, nInputChannels=3, n_classes=7, os=16, input_channels=256, hidden_layers=128, out_channels=256,
14
+ source_classes=20, transfer_graph=None):
15
+ super(deeplab_xception_transfer_basemodel_savememory, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes,
16
+ os=os,)
17
+
18
+ def load_source_model(self,state_dict):
19
+ own_state = self.state_dict()
20
+ # for name in own_state:
21
+ # print name
22
+ new_state_dict = OrderedDict()
23
+ for name, param in state_dict.items():
24
+ name = name.replace('module.', '')
25
+ if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_graph' not in name \
26
+ and 'transpose_graph' not in name and 'middle' not in name:
27
+ if 'featuremap_2_graph' in name:
28
+ name = name.replace('featuremap_2_graph','source_featuremap_2_graph')
29
+ else:
30
+ name = name.replace('graph','source_graph')
31
+ new_state_dict[name] = 0
32
+ if name not in own_state:
33
+ if 'num_batch' in name:
34
+ continue
35
+ print('unexpected key "{}" in state_dict'
36
+ .format(name))
37
+ continue
38
+ # if isinstance(param, own_state):
39
+ if isinstance(param, Parameter):
40
+ # backwards compatibility for serialized parameters
41
+ param = param.data
42
+ try:
43
+ own_state[name].copy_(param)
44
+ except:
45
+ print('While copying the parameter named {}, whose dimensions in the model are'
46
+ ' {} and whose dimensions in the checkpoint are {}, ...'.format(
47
+ name, own_state[name].size(), param.size()))
48
+ continue # i add inshop_cos 2018/02/01
49
+ own_state[name].copy_(param)
50
+ # print 'copying %s' %name
51
+
52
+ missing = set(own_state.keys()) - set(new_state_dict.keys())
53
+ if len(missing) > 0:
54
+ print('missing keys in state_dict: "{}"'.format(missing))
55
+
56
+ def get_target_parameter(self):
57
+ l = []
58
+ other = []
59
+ for name, k in self.named_parameters():
60
+ if 'target' in name or 'semantic' in name:
61
+ l.append(k)
62
+ else:
63
+ other.append(k)
64
+ return l, other
65
+
66
+ def get_semantic_parameter(self):
67
+ l = []
68
+ for name, k in self.named_parameters():
69
+ if 'semantic' in name:
70
+ l.append(k)
71
+ return l
72
+
73
+ def get_source_parameter(self):
74
+ l = []
75
+ for name, k in self.named_parameters():
76
+ if 'source' in name:
77
+ l.append(k)
78
+ return l
79
+
80
+ def top_forward(self, input, adj1_target=None, adj2_source=None,adj3_transfer=None ):
81
+ x, low_level_features = self.xception_features(input)
82
+ # print(x.size())
83
+ x1 = self.aspp1(x)
84
+ x2 = self.aspp2(x)
85
+ x3 = self.aspp3(x)
86
+ x4 = self.aspp4(x)
87
+ x5 = self.global_avg_pool(x)
88
+ x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
89
+
90
+ x = torch.cat((x1, x2, x3, x4, x5), dim=1)
91
+
92
+ x = self.concat_projection_conv1(x)
93
+ x = self.concat_projection_bn1(x)
94
+ x = self.relu(x)
95
+ # print(x.size())
96
+ x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
97
+
98
+ low_level_features = self.feature_projection_conv1(low_level_features)
99
+ low_level_features = self.feature_projection_bn1(low_level_features)
100
+ low_level_features = self.relu(low_level_features)
101
+ # print(low_level_features.size())
102
+ # print(x.size())
103
+ x = torch.cat((x, low_level_features), dim=1)
104
+ x = self.decoder(x)
105
+
106
+ ### source graph
107
+ source_graph = self.source_featuremap_2_graph(x)
108
+
109
+ source_graph1 = self.source_graph_conv1.forward(source_graph, adj=adj2_source, relu=True)
110
+ source_graph2 = self.source_graph_conv2.forward(source_graph1, adj=adj2_source, relu=True)
111
+ source_graph3 = self.source_graph_conv2.forward(source_graph2, adj=adj2_source, relu=True)
112
+
113
+ ### target graph
114
+ graph = self.target_featuremap_2_graph(x)
115
+
116
+ # graph combine
117
+ # print(graph.size(),source_2_target_graph.size())
118
+ # graph = self.fc_graph.forward(graph,relu=True)
119
+ # print(graph.size())
120
+
121
+ graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
122
+ graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
123
+ graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
124
+
125
+
126
+ def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ):
127
+ x, low_level_features = self.xception_features(input)
128
+ # print(x.size())
129
+ x1 = self.aspp1(x)
130
+ x2 = self.aspp2(x)
131
+ x3 = self.aspp3(x)
132
+ x4 = self.aspp4(x)
133
+ x5 = self.global_avg_pool(x)
134
+ x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
135
+
136
+ x = torch.cat((x1, x2, x3, x4, x5), dim=1)
137
+
138
+ x = self.concat_projection_conv1(x)
139
+ x = self.concat_projection_bn1(x)
140
+ x = self.relu(x)
141
+ # print(x.size())
142
+ x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
143
+
144
+ low_level_features = self.feature_projection_conv1(low_level_features)
145
+ low_level_features = self.feature_projection_bn1(low_level_features)
146
+ low_level_features = self.relu(low_level_features)
147
+ # print(low_level_features.size())
148
+ # print(x.size())
149
+ x = torch.cat((x, low_level_features), dim=1)
150
+ x = self.decoder(x)
151
+
152
+ ### add graph
153
+
154
+
155
+ # target graph
156
+ # print('x size',x.size(),adj1.size())
157
+ graph = self.target_featuremap_2_graph(x)
158
+
159
+ # graph combine
160
+ # print(graph.size(),source_2_target_graph.size())
161
+ # graph = self.fc_graph.forward(graph,relu=True)
162
+ # print(graph.size())
163
+
164
+ graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
165
+ graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
166
+ graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
167
+ # print(graph.size(),x.size())
168
+ # graph = self.gcn_encode.forward(graph,relu=True)
169
+ # graph = self.graph_conv2.forward(graph,adj=adj2,relu=True)
170
+ # graph = self.gcn_decode.forward(graph,relu=True)
171
+ graph = self.target_graph_2_fea.forward(graph, x)
172
+ x = self.target_skip_conv(x)
173
+ x = x + graph
174
+
175
+ ###
176
+ x = self.semantic(x)
177
+ x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)
178
+
179
+ return x
180
+
181
+
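get_target_parameter / get_source_parameter above split the model's parameters purely by name, so that the newly added target and semantic branches can be optimized separately from the rest of the network. A self-contained sketch of the same split on a dummy module (the module and the learning rates are illustrative assumptions, not values from this repository):

```python
import torch
import torch.nn as nn


class Dummy(nn.Module):
    # Stand-in with the same naming convention as the models above.
    def __init__(self):
        super().__init__()
        self.source_graph_conv1 = nn.Linear(128, 128)
        self.target_graph_conv1 = nn.Linear(128, 128)
        self.semantic = nn.Conv2d(256, 7, 1)
        self.decoder = nn.Conv2d(256, 256, 3, padding=1)


net = Dummy()
target_and_semantic, other = [], []
for name, p in net.named_parameters():
    (target_and_semantic if ('target' in name or 'semantic' in name) else other).append(p)

# One possible use: a larger learning rate on the new target/semantic branch than on the rest.
optimizer = torch.optim.SGD([
    {'params': target_and_semantic, 'lr': 1e-3},
    {'params': other, 'lr': 1e-4},
], momentum=0.9)
print(len(target_and_semantic), len(other))  # 4 4
```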
182
+ class deeplab_xception_transfer_basemodel_savememory_synbn(deeplab_xception_synBN.DeepLabv3_plus):
183
+ def __init__(self, nInputChannels=3, n_classes=7, os=16, input_channels=256, hidden_layers=128, out_channels=256,
184
+ source_classes=20, transfer_graph=None):
185
+ super(deeplab_xception_transfer_basemodel_savememory_synbn, self).__init__(nInputChannels=nInputChannels, n_classes=n_classes,
186
+ os=os,)
187
+
188
+
189
+ def load_source_model(self,state_dict):
190
+ own_state = self.state_dict()
191
+ # for name in own_state:
192
+ # print name
193
+ new_state_dict = OrderedDict()
194
+ for name, param in state_dict.items():
195
+ name = name.replace('module.', '')
196
+ if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_graph' not in name \
197
+ and 'transpose_graph' not in name and 'middle' not in name:
198
+ if 'featuremap_2_graph' in name:
199
+ name = name.replace('featuremap_2_graph','source_featuremap_2_graph')
200
+ else:
201
+ name = name.replace('graph','source_graph')
202
+ new_state_dict[name] = 0
203
+ if name not in own_state:
204
+ if 'num_batch' in name:
205
+ continue
206
+ print('unexpected key "{}" in state_dict'
207
+ .format(name))
208
+ continue
209
+ # if isinstance(param, own_state):
210
+ if isinstance(param, Parameter):
211
+ # backwards compatibility for serialized parameters
212
+ param = param.data
213
+ try:
214
+ own_state[name].copy_(param)
215
+ except:
216
+ print('While copying the parameter named {}, whose dimensions in the model are'
217
+ ' {} and whose dimensions in the checkpoint are {}, ...'.format(
218
+ name, own_state[name].size(), param.size()))
219
+ continue # i add inshop_cos 2018/02/01
220
+ own_state[name].copy_(param)
221
+ # print 'copying %s' %name
222
+
223
+ missing = set(own_state.keys()) - set(new_state_dict.keys())
224
+ if len(missing) > 0:
225
+ print('missing keys in state_dict: "{}"'.format(missing))
226
+
227
+ def get_target_parameter(self):
228
+ l = []
229
+ other = []
230
+ for name, k in self.named_parameters():
231
+ if 'target' in name or 'semantic' in name:
232
+ l.append(k)
233
+ else:
234
+ other.append(k)
235
+ return l, other
236
+
237
+ def get_semantic_parameter(self):
238
+ l = []
239
+ for name, k in self.named_parameters():
240
+ if 'semantic' in name:
241
+ l.append(k)
242
+ return l
243
+
244
+ def get_source_parameter(self):
245
+ l = []
246
+ for name, k in self.named_parameters():
247
+ if 'source' in name:
248
+ l.append(k)
249
+ return l
250
+
251
+ def top_forward(self, input, adj1_target=None, adj2_source=None,adj3_transfer=None ):
252
+ x, low_level_features = self.xception_features(input)
253
+ # print(x.size())
254
+ x1 = self.aspp1(x)
255
+ x2 = self.aspp2(x)
256
+ x3 = self.aspp3(x)
257
+ x4 = self.aspp4(x)
258
+ x5 = self.global_avg_pool(x)
259
+ x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
260
+
261
+ x = torch.cat((x1, x2, x3, x4, x5), dim=1)
262
+
263
+ x = self.concat_projection_conv1(x)
264
+ x = self.concat_projection_bn1(x)
265
+ x = self.relu(x)
266
+ # print(x.size())
267
+ x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
268
+
269
+ low_level_features = self.feature_projection_conv1(low_level_features)
270
+ low_level_features = self.feature_projection_bn1(low_level_features)
271
+ low_level_features = self.relu(low_level_features)
272
+ # print(low_level_features.size())
273
+ # print(x.size())
274
+ x = torch.cat((x, low_level_features), dim=1)
275
+ x = self.decoder(x)
276
+
277
+ ### source graph
278
+ source_graph = self.source_featuremap_2_graph(x)
279
+
280
+ source_graph1 = self.source_graph_conv1.forward(source_graph, adj=adj2_source, relu=True)
281
+ source_graph2 = self.source_graph_conv2.forward(source_graph1, adj=adj2_source, relu=True)
282
+ source_graph3 = self.source_graph_conv2.forward(source_graph2, adj=adj2_source, relu=True)
283
+
284
+ ### target graph
285
+ graph = self.target_featuremap_2_graph(x)
286
+
287
+ # graph combine
288
+ # print(graph.size(),source_2_target_graph.size())
289
+ # graph = self.fc_graph.forward(graph,relu=True)
290
+ # print(graph.size())
291
+
292
+ graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
293
+ graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
294
+ graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
295
+
296
+
297
+ def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ):
298
+ x, low_level_features = self.xception_features(input)
299
+ # print(x.size())
300
+ x1 = self.aspp1(x)
301
+ x2 = self.aspp2(x)
302
+ x3 = self.aspp3(x)
303
+ x4 = self.aspp4(x)
304
+ x5 = self.global_avg_pool(x)
305
+ x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
306
+
307
+ x = torch.cat((x1, x2, x3, x4, x5), dim=1)
308
+
309
+ x = self.concat_projection_conv1(x)
310
+ x = self.concat_projection_bn1(x)
311
+ x = self.relu(x)
312
+ # print(x.size())
313
+ x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
314
+
315
+ low_level_features = self.feature_projection_conv1(low_level_features)
316
+ low_level_features = self.feature_projection_bn1(low_level_features)
317
+ low_level_features = self.relu(low_level_features)
318
+ # print(low_level_features.size())
319
+ # print(x.size())
320
+ x = torch.cat((x, low_level_features), dim=1)
321
+ x = self.decoder(x)
322
+
323
+ ### add graph
324
+
325
+
326
+ # target graph
327
+ # print('x size',x.size(),adj1.size())
328
+ graph = self.target_featuremap_2_graph(x)
329
+
330
+ # graph combine
331
+ # print(graph.size(),source_2_target_graph.size())
332
+ # graph = self.fc_graph.forward(graph,relu=True)
333
+ # print(graph.size())
334
+
335
+ graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
336
+ graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
337
+ graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
338
+ # print(graph.size(),x.size())
339
+ # graph = self.gcn_encode.forward(graph,relu=True)
340
+ # graph = self.graph_conv2.forward(graph,adj=adj2,relu=True)
341
+ # graph = self.gcn_decode.forward(graph,relu=True)
342
+ graph = self.target_graph_2_fea.forward(graph, x)
343
+ x = self.target_skip_conv(x)
344
+ x = x + graph
345
+
346
+ ###
347
+ x = self.semantic(x)
348
+ x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)
349
+
350
+ return x
351
+
352
+
353
+ class deeplab_xception_end2end_3d(deeplab_xception_transfer_basemodel_savememory):
354
+ def __init__(self, nInputChannels=3, n_classes=20, os=16, input_channels=256, hidden_layers=128, out_channels=256,
355
+ source_classes=7, middle_classes=18, transfer_graph=None):
356
+ super(deeplab_xception_end2end_3d, self).__init__(nInputChannels=nInputChannels,
357
+ n_classes=n_classes,
358
+ os=os, )
359
+ ### source graph
360
+ self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels,
361
+ hidden_layers=hidden_layers,
362
+ nodes=source_classes)
363
+ self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
364
+ self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
365
+ self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
366
+
367
+ self.source_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels,
368
+ output_channels=out_channels,
369
+ hidden_layers=hidden_layers, nodes=source_classes
370
+ )
371
+ self.source_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
372
+ nn.ReLU(True)])
373
+ self.source_semantic = nn.Conv2d(out_channels,source_classes,1)
374
+ self.middle_semantic = nn.Conv2d(out_channels, middle_classes, 1)
375
+
376
+ ### target graph 1
377
+ self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels,
378
+ hidden_layers=hidden_layers,
379
+ nodes=n_classes)
380
+ self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
381
+ self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
382
+ self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
383
+
384
+ self.target_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels,
385
+ output_channels=out_channels,
386
+ hidden_layers=hidden_layers, nodes=n_classes
387
+ )
388
+ self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
389
+ nn.ReLU(True)])
390
+
391
+ ### middle
392
+ self.middle_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels,
393
+ hidden_layers=hidden_layers,
394
+ nodes=middle_classes)
395
+ self.middle_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
396
+ self.middle_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
397
+ self.middle_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
398
+
399
+ self.middle_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels,
400
+ output_channels=out_channels,
401
+ hidden_layers=hidden_layers, nodes=n_classes
402
+ )
403
+ self.middle_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
404
+ nn.ReLU(True)])
405
+
406
+ ### multi transpose
407
+ self.transpose_graph_source2target = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
408
+ adj=transfer_graph,
409
+ begin_nodes=source_classes, end_nodes=n_classes)
410
+ self.transpose_graph_target2source = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
411
+ adj=transfer_graph,
412
+ begin_nodes=n_classes, end_nodes=source_classes)
413
+
414
+ self.transpose_graph_middle2source = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
415
+ adj=transfer_graph,
416
+ begin_nodes=middle_classes, end_nodes=source_classes)
417
+ self.transpose_graph_middle2target = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
418
+ adj=transfer_graph,
419
+ begin_nodes=middle_classes, end_nodes=source_classes)
420
+
421
+ self.transpose_graph_source2middle = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
422
+ adj=transfer_graph,
423
+ begin_nodes=source_classes, end_nodes=middle_classes)
424
+ self.transpose_graph_target2middle = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
425
+ adj=transfer_graph,
426
+ begin_nodes=n_classes, end_nodes=middle_classes)
427
+
428
+
429
+ self.fc_graph_source = gcn.GraphConvolution(hidden_layers * 5, hidden_layers)
430
+ self.fc_graph_target = gcn.GraphConvolution(hidden_layers * 5, hidden_layers)
431
+ self.fc_graph_middle = gcn.GraphConvolution(hidden_layers * 5, hidden_layers)
432
+
433
+ def freeze_totally_bn(self):
434
+ for m in self.modules():
435
+ if isinstance(m, nn.BatchNorm2d):
436
+ m.eval()
437
+ m.weight.requires_grad = False
438
+ m.bias.requires_grad = False
439
+
440
+ def freeze_backbone_bn(self):
441
+ for m in self.xception_features.modules():
442
+ if isinstance(m, nn.BatchNorm2d):
443
+ m.eval()
444
+ m.weight.requires_grad = False
445
+ m.bias.requires_grad = False
446
+
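freeze_totally_bn / freeze_backbone_bn above put BatchNorm layers into eval mode (so running statistics stop updating) and disable gradients on their affine parameters. A minimal, self-contained sketch of that pattern on a stand-in backbone (the `backbone` here is a dummy, not self.xception_features):

```python
import torch.nn as nn

backbone = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())

for m in backbone.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()                       # freeze running mean / variance updates
        m.weight.requires_grad = False  # freeze the affine scale
        m.bias.requires_grad = False    # freeze the affine shift
```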
447
+ def top_forward(self, input, adj1_target=None, adj2_source=None, adj3_transfer_s2t=None, adj3_transfer_t2s=None,
448
+ adj4_middle=None,adj5_transfer_s2m=None,adj6_transfer_t2m=None,adj5_transfer_m2s=None,adj6_transfer_m2t=None,):
449
+ x, low_level_features = self.xception_features(input)
450
+ # print(x.size())
451
+ x1 = self.aspp1(x)
452
+ x2 = self.aspp2(x)
453
+ x3 = self.aspp3(x)
454
+ x4 = self.aspp4(x)
455
+ x5 = self.global_avg_pool(x)
456
+ x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
457
+
458
+ x = torch.cat((x1, x2, x3, x4, x5), dim=1)
459
+
460
+ x = self.concat_projection_conv1(x)
461
+ x = self.concat_projection_bn1(x)
462
+ x = self.relu(x)
463
+ # print(x.size())
464
+ x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
465
+
466
+ low_level_features = self.feature_projection_conv1(low_level_features)
467
+ low_level_features = self.feature_projection_bn1(low_level_features)
468
+ low_level_features = self.relu(low_level_features)
469
+ # print(low_level_features.size())
470
+ # print(x.size())
471
+ x = torch.cat((x, low_level_features), dim=1)
472
+ x = self.decoder(x)
473
+
474
+ ### source graph
475
+ source_graph = self.source_featuremap_2_graph(x)
476
+ ### target graph
477
+ target_graph = self.target_featuremap_2_graph(x)
478
+ ### middle graph
479
+ middle_graph = self.middle_featuremap_2_graph(x)
480
+
481
+ ##### end2end multi task
482
+
483
+ ### first task
484
+ # print(source_graph.size(),target_graph.size())
485
+ source_graph1 = self.source_graph_conv1.forward(source_graph, adj=adj2_source, relu=True)
486
+ target_graph1 = self.target_graph_conv1.forward(target_graph, adj=adj1_target, relu=True)
487
+ middle_graph1 = self.target_graph_conv1.forward(middle_graph, adj=adj4_middle, relu=True)
488
+
489
+ # source 2 target & middle
490
+ source_2_target_graph1_v5 = self.transpose_graph_source2target.forward(source_graph1, adj=adj3_transfer_s2t,
491
+ relu=True)
492
+ source_2_middle_graph1_v5 = self.transpose_graph_source2middle.forward(source_graph1,adj=adj5_transfer_s2m,
493
+ relu=True)
494
+ # target 2 source & middle
495
+ target_2_source_graph1_v5 = self.transpose_graph_target2source.forward(target_graph1, adj=adj3_transfer_t2s,
496
+ relu=True)
497
+ target_2_middle_graph1_v5 = self.transpose_graph_target2middle.forward(target_graph1, adj=adj6_transfer_t2m,
498
+ relu=True)
499
+ # middle 2 source & target
500
+ middle_2_source_graph1_v5 = self.transpose_graph_middle2source.forward(middle_graph1, adj=adj5_transfer_m2s,
501
+ relu=True)
502
+ middle_2_target_graph1_v5 = self.transpose_graph_middle2target.forward(middle_graph1, adj=adj6_transfer_m2t,
503
+ relu=True)
504
+ # source 2 middle target
505
+ source_2_target_graph1 = self.similarity_trans(source_graph1, target_graph1)
506
+ source_2_middle_graph1 = self.similarity_trans(source_graph1, middle_graph1)
507
+ # target 2 source middle
508
+ target_2_source_graph1 = self.similarity_trans(target_graph1, source_graph1)
509
+ target_2_middle_graph1 = self.similarity_trans(target_graph1, middle_graph1)
510
+ # middle 2 source target
511
+ middle_2_source_graph1 = self.similarity_trans(middle_graph1, source_graph1)
512
+ middle_2_target_graph1 = self.similarity_trans(middle_graph1, target_graph1)
513
+
514
+ ## concat
515
+ # print(source_graph1.size(), target_2_source_graph1.size(), )
516
+ source_graph1 = torch.cat(
517
+ (source_graph1, target_2_source_graph1, target_2_source_graph1_v5,
518
+ middle_2_source_graph1, middle_2_source_graph1_v5), dim=-1)
519
+ source_graph1 = self.fc_graph_source.forward(source_graph1, relu=True)
520
+ # target
521
+ target_graph1 = torch.cat(
522
+ (target_graph1, source_2_target_graph1, source_2_target_graph1_v5,
523
+ middle_2_target_graph1, middle_2_target_graph1_v5), dim=-1)
524
+ target_graph1 = self.fc_graph_target.forward(target_graph1, relu=True)
525
+ # middle
526
+ middle_graph1 = torch.cat((middle_graph1, source_2_middle_graph1, source_2_middle_graph1_v5,
527
+ target_2_middle_graph1, target_2_middle_graph1_v5), dim=-1)
528
+ middle_graph1 = self.fc_graph_middle.forward(middle_graph1, relu=True)
529
+
530
+
531
+ ### second task
532
+ source_graph2 = self.source_graph_conv1.forward(source_graph1, adj=adj2_source, relu=True)
533
+ target_graph2 = self.target_graph_conv1.forward(target_graph1, adj=adj1_target, relu=True)
534
+ middle_graph2 = self.target_graph_conv1.forward(middle_graph1, adj=adj4_middle, relu=True)
535
+
536
+ # source 2 target & middle
537
+ source_2_target_graph2_v5 = self.transpose_graph_source2target.forward(source_graph2, adj=adj3_transfer_s2t,
538
+ relu=True)
539
+ source_2_middle_graph2_v5 = self.transpose_graph_source2middle.forward(source_graph2, adj=adj5_transfer_s2m,
540
+ relu=True)
541
+ # target 2 source & middle
542
+ target_2_source_graph2_v5 = self.transpose_graph_target2source.forward(target_graph2, adj=adj3_transfer_t2s,
543
+ relu=True)
544
+ target_2_middle_graph2_v5 = self.transpose_graph_target2middle.forward(target_graph2, adj=adj6_transfer_t2m,
545
+ relu=True)
546
+ # middle 2 source & target
547
+ middle_2_source_graph2_v5 = self.transpose_graph_middle2source.forward(middle_graph2, adj=adj5_transfer_m2s,
548
+ relu=True)
549
+ middle_2_target_graph2_v5 = self.transpose_graph_middle2target.forward(middle_graph2, adj=adj6_transfer_m2t,
550
+ relu=True)
551
+ # source 2 middle target
552
+ source_2_target_graph2 = self.similarity_trans(source_graph2, target_graph2)
553
+ source_2_middle_graph2 = self.similarity_trans(source_graph2, middle_graph2)
554
+ # target 2 source middle
555
+ target_2_source_graph2 = self.similarity_trans(target_graph2, source_graph2)
556
+ target_2_middle_graph2 = self.similarity_trans(target_graph2, middle_graph2)
557
+ # middle 2 source target
558
+ middle_2_source_graph2 = self.similarity_trans(middle_graph2, source_graph2)
559
+ middle_2_target_graph2 = self.similarity_trans(middle_graph2, target_graph2)
560
+
561
+ ## concat
562
+ # print(source_graph1.size(), target_2_source_graph1.size(), )
563
+ source_graph2 = torch.cat(
564
+ (source_graph2, target_2_source_graph2, target_2_source_graph2_v5,
565
+ middle_2_source_graph2, middle_2_source_graph2_v5), dim=-1)
566
+ source_graph2 = self.fc_graph_source.forward(source_graph2, relu=True)
567
+ # target
568
+ target_graph2 = torch.cat(
569
+ (target_graph2, source_2_target_graph2, source_2_target_graph2_v5,
570
+ middle_2_target_graph2, middle_2_target_graph2_v5), dim=-1)
571
+ target_graph2 = self.fc_graph_target.forward(target_graph2, relu=True)
572
+ # middle
573
+ middle_graph2 = torch.cat((middle_graph2, source_2_middle_graph2, source_2_middle_graph2_v5,
574
+ target_2_middle_graph2, target_2_middle_graph2_v5), dim=-1)
575
+ middle_graph2 = self.fc_graph_middle.forward(middle_graph2, relu=True)
576
+
577
+
578
+ ### third task
579
+ source_graph3 = self.source_graph_conv1.forward(source_graph2, adj=adj2_source, relu=True)
580
+ target_graph3 = self.target_graph_conv1.forward(target_graph2, adj=adj1_target, relu=True)
581
+ middle_graph3 = self.target_graph_conv1.forward(middle_graph2, adj=adj4_middle, relu=True)
582
+
583
+ # source 2 target & middle
584
+ source_2_target_graph3_v5 = self.transpose_graph_source2target.forward(source_graph3, adj=adj3_transfer_s2t,
585
+ relu=True)
586
+ source_2_middle_graph3_v5 = self.transpose_graph_source2middle.forward(source_graph3, adj=adj5_transfer_s2m,
587
+ relu=True)
588
+ # target 2 source & middle
589
+ target_2_source_graph3_v5 = self.transpose_graph_target2source.forward(target_graph3, adj=adj3_transfer_t2s,
590
+ relu=True)
591
+ target_2_middle_graph3_v5 = self.transpose_graph_target2middle.forward(target_graph3, adj=adj6_transfer_t2m,
592
+ relu=True)
593
+ # middle 2 source & target
594
+ middle_2_source_graph3_v5 = self.transpose_graph_middle2source.forward(middle_graph3, adj=adj5_transfer_m2s,
595
+ relu=True)
596
+ middle_2_target_graph3_v5 = self.transpose_graph_middle2target.forward(middle_graph3, adj=adj6_transfer_m2t,
597
+ relu=True)
598
+ # source 2 middle target
599
+ source_2_target_graph3 = self.similarity_trans(source_graph3, target_graph3)
600
+ source_2_middle_graph3 = self.similarity_trans(source_graph3, middle_graph3)
601
+ # target 2 source middle
602
+ target_2_source_graph3 = self.similarity_trans(target_graph3, source_graph3)
603
+ target_2_middle_graph3 = self.similarity_trans(target_graph3, middle_graph3)
604
+ # middle 2 source target
605
+ middle_2_source_graph3 = self.similarity_trans(middle_graph3, source_graph3)
606
+ middle_2_target_graph3 = self.similarity_trans(middle_graph3, target_graph3)
607
+
608
+ ## concat
609
+ # print(source_graph1.size(), target_2_source_graph1.size(), )
610
+ source_graph3 = torch.cat(
611
+ (source_graph3, target_2_source_graph3, target_2_source_graph3_v5,
612
+ middle_2_source_graph3, middle_2_source_graph3_v5), dim=-1)
613
+ source_graph3 = self.fc_graph_source.forward(source_graph3, relu=True)
614
+ # target
615
+ target_graph3 = torch.cat(
616
+ (target_graph3, source_2_target_graph3, source_2_target_graph3_v5,
617
+ middle_2_target_graph3, middle_2_target_graph3_v5), dim=-1)
618
+ target_graph3 = self.fc_graph_target.forward(target_graph3, relu=True)
619
+ # middle
620
+ middle_graph3 = torch.cat((middle_graph3, source_2_middle_graph3, source_2_middle_graph3_v5,
621
+ target_2_middle_graph3, target_2_middle_graph3_v5), dim=-1)
622
+ middle_graph3 = self.fc_graph_middle.forward(middle_graph3, relu=True)
623
+
624
+ return source_graph3, target_graph3, middle_graph3, x
625
+
626
+ def similarity_trans(self,source,target):
627
+ sim = torch.matmul(F.normalize(target, p=2, dim=-1), F.normalize(source, p=2, dim=-1).transpose(-1, -2))
628
+ sim = F.softmax(sim, dim=-1)
629
+ return torch.matmul(sim, source)
630
+
631
+ def bottom_forward_source(self, input, source_graph):
632
+ # print('input size')
633
+ # print(input.size())
634
+ # print(source_graph.size())
635
+ graph = self.source_graph_2_fea.forward(source_graph, input)
636
+ x = self.source_skip_conv(input)
637
+ x = x + graph
638
+ x = self.source_semantic(x)
639
+ return x
640
+
641
+ def bottom_forward_target(self, input, target_graph):
642
+ graph = self.target_graph_2_fea.forward(target_graph, input)
643
+ x = self.target_skip_conv(input)
644
+ x = x + graph
645
+ x = self.semantic(x)
646
+ return x
647
+
648
+ def bottom_forward_middle(self, input, target_graph):
649
+ graph = self.middle_graph_2_fea.forward(target_graph, input)
650
+ x = self.middle_skip_conv(input)
651
+ x = x + graph
652
+ x = self.middle_semantic(x)
653
+ return x
654
+
655
+ def forward(self, input_source, input_target=None, input_middle=None, adj1_target=None, adj2_source=None,
656
+ adj3_transfer_s2t=None, adj3_transfer_t2s=None, adj4_middle=None,adj5_transfer_s2m=None,
657
+ adj6_transfer_t2m=None,adj5_transfer_m2s=None,adj6_transfer_m2t=None,):
658
+ if input_source is None and input_target is not None and input_middle is None:
659
+ # target
660
+ target_batch = input_target.size(0)
661
+ input = input_target
662
+
663
+ source_graph, target_graph, middle_graph, x = self.top_forward(input, adj1_target=adj1_target, adj2_source=adj2_source,
664
+ adj3_transfer_s2t=adj3_transfer_s2t,
665
+ adj3_transfer_t2s=adj3_transfer_t2s,
666
+ adj4_middle=adj4_middle,
667
+ adj5_transfer_s2m=adj5_transfer_s2m,
668
+ adj6_transfer_t2m=adj6_transfer_t2m,
669
+ adj5_transfer_m2s=adj5_transfer_m2s,
670
+ adj6_transfer_m2t=adj6_transfer_m2t)
671
+
672
+ # source_x = self.bottom_forward_source(source_x, source_graph)
673
+ target_x = self.bottom_forward_target(x, target_graph)
674
+
675
+ target_x = F.upsample(target_x, size=input.size()[2:], mode='bilinear', align_corners=True)
676
+ return None, target_x, None
677
+
678
+ if input_source is not None and input_target is None and input_middle is None:
679
+ # source
680
+ source_batch = input_source.size(0)
681
+ source_list = range(source_batch)
682
+ input = input_source
683
+
684
+ source_graph, target_graph, middle_graph, x = self.top_forward(input, adj1_target=adj1_target,
685
+ adj2_source=adj2_source,
686
+ adj3_transfer_s2t=adj3_transfer_s2t,
687
+ adj3_transfer_t2s=adj3_transfer_t2s,
688
+ adj4_middle=adj4_middle,
689
+ adj5_transfer_s2m=adj5_transfer_s2m,
690
+ adj6_transfer_t2m=adj6_transfer_t2m,
691
+ adj5_transfer_m2s=adj5_transfer_m2s,
692
+ adj6_transfer_m2t=adj6_transfer_m2t)
693
+
694
+ source_x = self.bottom_forward_source(x, source_graph)
695
+ source_x = F.upsample(source_x, size=input.size()[2:], mode='bilinear', align_corners=True)
696
+ return source_x, None, None
697
+
698
+ if input_middle is not None and input_source is None and input_target is None:
699
+ # middle
700
+ input = input_middle
701
+
702
+ source_graph, target_graph, middle_graph, x = self.top_forward(input, adj1_target=adj1_target,
703
+ adj2_source=adj2_source,
704
+ adj3_transfer_s2t=adj3_transfer_s2t,
705
+ adj3_transfer_t2s=adj3_transfer_t2s,
706
+ adj4_middle=adj4_middle,
707
+ adj5_transfer_s2m=adj5_transfer_s2m,
708
+ adj6_transfer_t2m=adj6_transfer_t2m,
709
+ adj5_transfer_m2s=adj5_transfer_m2s,
710
+ adj6_transfer_m2t=adj6_transfer_m2t)
711
+
712
+ middle_x = self.bottom_forward_middle(x, source_graph)
713
+ middle_x = F.upsample(middle_x, size=input.size()[2:], mode='bilinear', align_corners=True)
714
+ return None, None, middle_x
715
+
716
+
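The forward method above runs exactly one branch per call, chosen by which of input_source / input_target / input_middle is not None, and returns a (source, target, middle) triple with None in the slots that did not run; with the constructor defaults the source branch predicts 7 channels, the target branch 20, and the middle branch 18. A tiny sketch of that dispatch contract only (not the network computation):

```python
def dispatch(input_source=None, input_target=None, input_middle=None):
    # Mirrors the branch selection in deeplab_xception_end2end_3d.forward:
    # exactly one input is expected, and the unused output slots stay None.
    if input_source is None and input_target is not None and input_middle is None:
        return None, 'target branch output', None
    if input_source is not None and input_target is None and input_middle is None:
        return 'source branch output', None, None
    if input_middle is not None and input_source is None and input_target is None:
        return None, None, 'middle branch output'
    return None  # the original forward also returns nothing useful when no branch matches


print(dispatch(input_target='target batch'))
# (None, 'target branch output', None)
```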
717
+ class deeplab_xception_end2end_3d_synbn(deeplab_xception_transfer_basemodel_savememory_synbn):
718
+ def __init__(self, nInputChannels=3, n_classes=20, os=16, input_channels=256, hidden_layers=128, out_channels=256,
719
+ source_classes=7, middle_classes=18, transfer_graph=None):
720
+ super(deeplab_xception_end2end_3d_synbn, self).__init__(nInputChannels=nInputChannels,
721
+ n_classes=n_classes,
722
+ os=os, )
723
+ ### source graph
724
+ self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels,
725
+ hidden_layers=hidden_layers,
726
+ nodes=source_classes)
727
+ self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
728
+ self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
729
+ self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
730
+
731
+ self.source_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels,
732
+ output_channels=out_channels,
733
+ hidden_layers=hidden_layers, nodes=source_classes
734
+ )
735
+ self.source_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
736
+ nn.ReLU(True)])
737
+ self.source_semantic = nn.Conv2d(out_channels,source_classes,1)
738
+ self.middle_semantic = nn.Conv2d(out_channels, middle_classes, 1)
739
+
740
+ ### target graph 1
741
+ self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels,
742
+ hidden_layers=hidden_layers,
743
+ nodes=n_classes)
744
+ self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
745
+ self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
746
+ self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
747
+
748
+ self.target_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels,
749
+ output_channels=out_channels,
750
+ hidden_layers=hidden_layers, nodes=n_classes
751
+ )
752
+ self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
753
+ nn.ReLU(True)])
754
+
755
+ ### middle
756
+ self.middle_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels,
757
+ hidden_layers=hidden_layers,
758
+ nodes=middle_classes)
759
+ self.middle_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
760
+ self.middle_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
761
+ self.middle_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
762
+
763
+ self.middle_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels,
764
+ output_channels=out_channels,
765
+ hidden_layers=hidden_layers, nodes=n_classes
766
+ )
767
+ self.middle_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
768
+ nn.ReLU(True)])
769
+
770
+ ### multi transpose
771
+ self.transpose_graph_source2target = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
772
+ adj=transfer_graph,
773
+ begin_nodes=source_classes, end_nodes=n_classes)
774
+ self.transpose_graph_target2source = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
775
+ adj=transfer_graph,
776
+ begin_nodes=n_classes, end_nodes=source_classes)
777
+
778
+ self.transpose_graph_middle2source = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
779
+ adj=transfer_graph,
780
+ begin_nodes=middle_classes, end_nodes=source_classes)
781
+ self.transpose_graph_middle2target = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
782
+ adj=transfer_graph,
783
+ begin_nodes=middle_classes, end_nodes=source_classes)
784
+
785
+ self.transpose_graph_source2middle = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
786
+ adj=transfer_graph,
787
+ begin_nodes=source_classes, end_nodes=middle_classes)
788
+ self.transpose_graph_target2middle = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
789
+ adj=transfer_graph,
790
+ begin_nodes=n_classes, end_nodes=middle_classes)
791
+
792
+
793
+ self.fc_graph_source = gcn.GraphConvolution(hidden_layers * 5, hidden_layers)
794
+ self.fc_graph_target = gcn.GraphConvolution(hidden_layers * 5, hidden_layers)
795
+ self.fc_graph_middle = gcn.GraphConvolution(hidden_layers * 5, hidden_layers)
796
+
797
+
798
+ def top_forward(self, input, adj1_target=None, adj2_source=None, adj3_transfer_s2t=None, adj3_transfer_t2s=None,
799
+ adj4_middle=None,adj5_transfer_s2m=None,adj6_transfer_t2m=None,adj5_transfer_m2s=None,adj6_transfer_m2t=None,):
800
+ x, low_level_features = self.xception_features(input)
801
+ # print(x.size())
802
+ x1 = self.aspp1(x)
803
+ x2 = self.aspp2(x)
804
+ x3 = self.aspp3(x)
805
+ x4 = self.aspp4(x)
806
+ x5 = self.global_avg_pool(x)
807
+ x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
808
+
809
+ x = torch.cat((x1, x2, x3, x4, x5), dim=1)
810
+
811
+ x = self.concat_projection_conv1(x)
812
+ x = self.concat_projection_bn1(x)
813
+ x = self.relu(x)
814
+ # print(x.size())
815
+ x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
816
+
817
+ low_level_features = self.feature_projection_conv1(low_level_features)
818
+ low_level_features = self.feature_projection_bn1(low_level_features)
819
+ low_level_features = self.relu(low_level_features)
820
+ # print(low_level_features.size())
821
+ # print(x.size())
822
+ x = torch.cat((x, low_level_features), dim=1)
823
+ x = self.decoder(x)
824
+
825
+ ### source graph
826
+ source_graph = self.source_featuremap_2_graph(x)
827
+ ### target graph
828
+ target_graph = self.target_featuremap_2_graph(x)
829
+ ### middle graph
830
+ middle_graph = self.middle_featuremap_2_graph(x)
831
+
832
+ ##### end2end multi task
833
+
834
+ ### first task
835
+ # print(source_graph.size(),target_graph.size())
836
+ source_graph1 = self.source_graph_conv1.forward(source_graph, adj=adj2_source, relu=True)
837
+ target_graph1 = self.target_graph_conv1.forward(target_graph, adj=adj1_target, relu=True)
838
+ middle_graph1 = self.target_graph_conv1.forward(middle_graph, adj=adj4_middle, relu=True)
839
+
840
+ # source 2 target & middle
841
+ source_2_target_graph1_v5 = self.transpose_graph_source2target.forward(source_graph1, adj=adj3_transfer_s2t,
842
+ relu=True)
843
+ source_2_middle_graph1_v5 = self.transpose_graph_source2middle.forward(source_graph1,adj=adj5_transfer_s2m,
844
+ relu=True)
845
+ # target 2 source & middle
846
+ target_2_source_graph1_v5 = self.transpose_graph_target2source.forward(target_graph1, adj=adj3_transfer_t2s,
847
+ relu=True)
848
+ target_2_middle_graph1_v5 = self.transpose_graph_target2middle.forward(target_graph1, adj=adj6_transfer_t2m,
849
+ relu=True)
850
+ # middle 2 source & target
851
+ middle_2_source_graph1_v5 = self.transpose_graph_middle2source.forward(middle_graph1, adj=adj5_transfer_m2s,
852
+ relu=True)
853
+ middle_2_target_graph1_v5 = self.transpose_graph_middle2target.forward(middle_graph1, adj=adj6_transfer_m2t,
854
+ relu=True)
855
+ # source 2 middle target
856
+ source_2_target_graph1 = self.similarity_trans(source_graph1, target_graph1)
857
+ source_2_middle_graph1 = self.similarity_trans(source_graph1, middle_graph1)
858
+ # target 2 source middle
859
+ target_2_source_graph1 = self.similarity_trans(target_graph1, source_graph1)
860
+ target_2_middle_graph1 = self.similarity_trans(target_graph1, middle_graph1)
861
+ # middle 2 source target
862
+ middle_2_source_graph1 = self.similarity_trans(middle_graph1, source_graph1)
863
+ middle_2_target_graph1 = self.similarity_trans(middle_graph1, target_graph1)
864
+
865
+ ## concat
866
+ # print(source_graph1.size(), target_2_source_graph1.size(), )
867
+ source_graph1 = torch.cat(
868
+ (source_graph1, target_2_source_graph1, target_2_source_graph1_v5,
869
+ middle_2_source_graph1, middle_2_source_graph1_v5), dim=-1)
870
+ source_graph1 = self.fc_graph_source.forward(source_graph1, relu=True)
871
+ # target
872
+ target_graph1 = torch.cat(
873
+ (target_graph1, source_2_target_graph1, source_2_target_graph1_v5,
874
+ middle_2_target_graph1, middle_2_target_graph1_v5), dim=-1)
875
+ target_graph1 = self.fc_graph_target.forward(target_graph1, relu=True)
876
+ # middle
877
+ middle_graph1 = torch.cat((middle_graph1, source_2_middle_graph1, source_2_middle_graph1_v5,
878
+ target_2_middle_graph1, target_2_middle_graph1_v5), dim=-1)
879
+ middle_graph1 = self.fc_graph_middle.forward(middle_graph1, relu=True)
880
+
881
+
882
+ ### second task
883
+ source_graph2 = self.source_graph_conv1.forward(source_graph1, adj=adj2_source, relu=True)
884
+ target_graph2 = self.target_graph_conv1.forward(target_graph1, adj=adj1_target, relu=True)
885
+ middle_graph2 = self.target_graph_conv1.forward(middle_graph1, adj=adj4_middle, relu=True)
886
+
887
+ # source 2 target & middle
888
+ source_2_target_graph2_v5 = self.transpose_graph_source2target.forward(source_graph2, adj=adj3_transfer_s2t,
889
+ relu=True)
890
+ source_2_middle_graph2_v5 = self.transpose_graph_source2middle.forward(source_graph2, adj=adj5_transfer_s2m,
891
+ relu=True)
892
+ # target 2 source & middle
893
+ target_2_source_graph2_v5 = self.transpose_graph_target2source.forward(target_graph2, adj=adj3_transfer_t2s,
894
+ relu=True)
895
+ target_2_middle_graph2_v5 = self.transpose_graph_target2middle.forward(target_graph2, adj=adj6_transfer_t2m,
896
+ relu=True)
897
+ # middle 2 source & target
898
+ middle_2_source_graph2_v5 = self.transpose_graph_middle2source.forward(middle_graph2, adj=adj5_transfer_m2s,
899
+ relu=True)
900
+ middle_2_target_graph2_v5 = self.transpose_graph_middle2target.forward(middle_graph2, adj=adj6_transfer_m2t,
901
+ relu=True)
902
+ # source 2 middle target
903
+ source_2_target_graph2 = self.similarity_trans(source_graph2, target_graph2)
904
+ source_2_middle_graph2 = self.similarity_trans(source_graph2, middle_graph2)
905
+ # target 2 source middle
906
+ target_2_source_graph2 = self.similarity_trans(target_graph2, source_graph2)
907
+ target_2_middle_graph2 = self.similarity_trans(target_graph2, middle_graph2)
908
+ # middle 2 source target
909
+ middle_2_source_graph2 = self.similarity_trans(middle_graph2, source_graph2)
910
+ middle_2_target_graph2 = self.similarity_trans(middle_graph2, target_graph2)
911
+
912
+ ## concat
913
+ # print(source_graph1.size(), target_2_source_graph1.size(), )
914
+ source_graph2 = torch.cat(
915
+ (source_graph2, target_2_source_graph2, target_2_source_graph2_v5,
916
+ middle_2_source_graph2, middle_2_source_graph2_v5), dim=-1)
917
+ source_graph2 = self.fc_graph_source.forward(source_graph2, relu=True)
918
+ # target
919
+ target_graph2 = torch.cat(
920
+ (target_graph2, source_2_target_graph2, source_2_target_graph2_v5,
921
+ middle_2_target_graph2, middle_2_target_graph2_v5), dim=-1)
922
+ target_graph2 = self.fc_graph_target.forward(target_graph2, relu=True)
923
+ # middle
924
+ middle_graph2 = torch.cat((middle_graph2, source_2_middle_graph2, source_2_middle_graph2_v5,
925
+ target_2_middle_graph2, target_2_middle_graph2_v5), dim=-1)
926
+ middle_graph2 = self.fc_graph_middle.forward(middle_graph2, relu=True)
927
+
928
+
929
+ ### third task
930
+ source_graph3 = self.source_graph_conv1.forward(source_graph2, adj=adj2_source, relu=True)
931
+ target_graph3 = self.target_graph_conv1.forward(target_graph2, adj=adj1_target, relu=True)
932
+ middle_graph3 = self.target_graph_conv1.forward(middle_graph2, adj=adj4_middle, relu=True)
933
+
934
+ # source 2 target & middle
935
+ source_2_target_graph3_v5 = self.transpose_graph_source2target.forward(source_graph3, adj=adj3_transfer_s2t,
936
+ relu=True)
937
+ source_2_middle_graph3_v5 = self.transpose_graph_source2middle.forward(source_graph3, adj=adj5_transfer_s2m,
938
+ relu=True)
939
+ # target 2 source & middle
940
+ target_2_source_graph3_v5 = self.transpose_graph_target2source.forward(target_graph3, adj=adj3_transfer_t2s,
941
+ relu=True)
942
+ target_2_middle_graph3_v5 = self.transpose_graph_target2middle.forward(target_graph3, adj=adj6_transfer_t2m,
943
+ relu=True)
944
+ # middle 2 source & target
945
+ middle_2_source_graph3_v5 = self.transpose_graph_middle2source.forward(middle_graph3, adj=adj5_transfer_m2s,
946
+ relu=True)
947
+ middle_2_target_graph3_v5 = self.transpose_graph_middle2target.forward(middle_graph3, adj=adj6_transfer_m2t,
948
+ relu=True)
949
+ # source 2 middle target
950
+ source_2_target_graph3 = self.similarity_trans(source_graph3, target_graph3)
951
+ source_2_middle_graph3 = self.similarity_trans(source_graph3, middle_graph3)
952
+ # target 2 source middle
953
+ target_2_source_graph3 = self.similarity_trans(target_graph3, source_graph3)
954
+ target_2_middle_graph3 = self.similarity_trans(target_graph3, middle_graph3)
955
+ # middle 2 source target
956
+ middle_2_source_graph3 = self.similarity_trans(middle_graph3, source_graph3)
957
+ middle_2_target_graph3 = self.similarity_trans(middle_graph3, target_graph3)
958
+
959
+ ## concat
960
+ # print(source_graph1.size(), target_2_source_graph1.size(), )
961
+ source_graph3 = torch.cat(
962
+ (source_graph3, target_2_source_graph3, target_2_source_graph3_v5,
963
+ middle_2_source_graph3, middle_2_source_graph3_v5), dim=-1)
964
+ source_graph3 = self.fc_graph_source.forward(source_graph3, relu=True)
965
+ # target
966
+ target_graph3 = torch.cat(
967
+ (target_graph3, source_2_target_graph3, source_2_target_graph3_v5,
968
+ middle_2_target_graph3, middle_2_target_graph3_v5), dim=-1)
969
+ target_graph3 = self.fc_graph_target.forward(target_graph3, relu=True)
970
+ # middle
971
+ middle_graph3 = torch.cat((middle_graph3, source_2_middle_graph3, source_2_middle_graph3_v5,
972
+ target_2_middle_graph3, target_2_middle_graph3_v5), dim=-1)
973
+ middle_graph3 = self.fc_graph_middle.forward(middle_graph3, relu=True)
974
+
975
+ return source_graph3, target_graph3, middle_graph3, x
976
+
977
+ def similarity_trans(self,source,target):
978
+ sim = torch.matmul(F.normalize(target, p=2, dim=-1), F.normalize(source, p=2, dim=-1).transpose(-1, -2))
979
+ sim = F.softmax(sim, dim=-1)
980
+ return torch.matmul(sim, source)
981
+
982
+ def bottom_forward_source(self, input, source_graph):
983
+ # print('input size')
984
+ # print(input.size())
985
+ # print(source_graph.size())
986
+ graph = self.source_graph_2_fea.forward(source_graph, input)
987
+ x = self.source_skip_conv(input)
988
+ x = x + graph
989
+ x = self.source_semantic(x)
990
+ return x
991
+
992
+ def bottom_forward_target(self, input, target_graph):
993
+ graph = self.target_graph_2_fea.forward(target_graph, input)
994
+ x = self.target_skip_conv(input)
995
+ x = x + graph
996
+ x = self.semantic(x)
997
+ return x
998
+
999
+ def bottom_forward_middle(self, input, target_graph):
1000
+ graph = self.middle_graph_2_fea.forward(target_graph, input)
1001
+ x = self.middle_skip_conv(input)
1002
+ x = x + graph
1003
+ x = self.middle_semantic(x)
1004
+ return x
1005
+
1006
+ def forward(self, input_source, input_target=None, input_middle=None, adj1_target=None, adj2_source=None,
1007
+ adj3_transfer_s2t=None, adj3_transfer_t2s=None, adj4_middle=None,adj5_transfer_s2m=None,
1008
+ adj6_transfer_t2m=None,adj5_transfer_m2s=None,adj6_transfer_m2t=None,):
1009
+
1010
+ if input_source is None and input_target is not None and input_middle is None:
1011
+ # target
1012
+ target_batch = input_target.size(0)
1013
+ input = input_target
1014
+
1015
+ source_graph, target_graph, middle_graph, x = self.top_forward(input, adj1_target=adj1_target, adj2_source=adj2_source,
1016
+ adj3_transfer_s2t=adj3_transfer_s2t,
1017
+ adj3_transfer_t2s=adj3_transfer_t2s,
1018
+ adj4_middle=adj4_middle,
1019
+ adj5_transfer_s2m=adj5_transfer_s2m,
1020
+ adj6_transfer_t2m=adj6_transfer_t2m,
1021
+ adj5_transfer_m2s=adj5_transfer_m2s,
1022
+ adj6_transfer_m2t=adj6_transfer_m2t)
1023
+
1024
+ # source_x = self.bottom_forward_source(source_x, source_graph)
1025
+ target_x = self.bottom_forward_target(x, target_graph)
1026
+
1027
+ target_x = F.upsample(target_x, size=input.size()[2:], mode='bilinear', align_corners=True)
1028
+ return None, target_x, None
1029
+
1030
+ if input_source is not None and input_target is None and input_middle is None:
1031
+ # source
1032
+ source_batch = input_source.size(0)
1033
+ source_list = range(source_batch)
1034
+ input = input_source
1035
+
1036
+ source_graph, target_graph, middle_graph, x = self.top_forward(input, adj1_target=adj1_target,
1037
+ adj2_source=adj2_source,
1038
+ adj3_transfer_s2t=adj3_transfer_s2t,
1039
+ adj3_transfer_t2s=adj3_transfer_t2s,
1040
+ adj4_middle=adj4_middle,
1041
+ adj5_transfer_s2m=adj5_transfer_s2m,
1042
+ adj6_transfer_t2m=adj6_transfer_t2m,
1043
+ adj5_transfer_m2s=adj5_transfer_m2s,
1044
+ adj6_transfer_m2t=adj6_transfer_m2t)
1045
+
1046
+ source_x = self.bottom_forward_source(x, source_graph)
1047
+ source_x = F.upsample(source_x, size=input.size()[2:], mode='bilinear', align_corners=True)
1048
+ return source_x, None, None
1049
+
1050
+ if input_middle is not None and input_source is None and input_target is None:
1051
+ # middle
1052
+ input = input_middle
1053
+
1054
+ source_graph, target_graph, middle_graph, x = self.top_forward(input, adj1_target=adj1_target,
1055
+ adj2_source=adj2_source,
1056
+ adj3_transfer_s2t=adj3_transfer_s2t,
1057
+ adj3_transfer_t2s=adj3_transfer_t2s,
1058
+ adj4_middle=adj4_middle,
1059
+ adj5_transfer_s2m=adj5_transfer_s2m,
1060
+ adj6_transfer_t2m=adj6_transfer_t2m,
1061
+ adj5_transfer_m2s=adj5_transfer_m2s,
1062
+ adj6_transfer_m2t=adj6_transfer_m2t)
1063
+
1064
+ middle_x = self.bottom_forward_middle(x, middle_graph)  # middle branch uses the middle graph
1065
+ middle_x = F.upsample(middle_x, size=input.size()[2:], mode='bilinear', align_corners=True)
1066
+ return None, None, middle_x
1067
+
1068
+
1069
+ if __name__ == '__main__':
1070
+ net = deeplab_xception_end2end_3d()
1071
+ net.freeze_totally_bn()
1072
+ img1 = torch.rand((1,3,128,128))
1073
+ img2 = torch.rand((1, 3, 128, 128))
1074
+ a1 = torch.ones((1,1,7,20))
1075
+ a2 = torch.ones((1,1,20,7))
1076
+ net.eval()
1077
+ net.forward(img1,img2,adj3_transfer_t2s=a2,adj3_transfer_s2t=a1)
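
The three-branch `top_forward` above repeatedly projects node features from one graph onto another with `similarity_trans`. As a minimal, self-contained sketch (illustrative only; it mirrors the `similarity_trans` method defined in this file, and the node counts are assumptions), the projection can be exercised on its own:

```python
# Cosine-similarity attention from target nodes over source nodes,
# followed by aggregation of the source features (see similarity_trans above).
import torch
import torch.nn.functional as F

def similarity_trans(source, target):
    sim = torch.matmul(F.normalize(target, p=2, dim=-1),
                       F.normalize(source, p=2, dim=-1).transpose(-1, -2))
    sim = F.softmax(sim, dim=-1)        # each target node: weights over source nodes
    return torch.matmul(sim, source)    # projected features, one per target node

source_nodes = torch.randn(1, 20, 128)  # e.g. 20 CIHP nodes with 128-d hidden features
target_nodes = torch.randn(1, 7, 128)   # e.g. 7 PASCAL nodes
print(similarity_trans(source_nodes, target_nodes).shape)  # torch.Size([1, 7, 128])
```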
TryYours-Virtual-Try-On/Graphonomy-master/networks/gcn.py ADDED
@@ -0,0 +1,279 @@
1
+ import math
2
+ import torch
3
+ from torch.nn.parameter import Parameter
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from networks import graph
7
+ # import pdb
8
+
9
+ class GraphConvolution(nn.Module):
10
+
11
+ def __init__(self,in_features,out_features,bias=False):
12
+ super(GraphConvolution, self).__init__()
13
+ self.in_features = in_features
14
+ self.out_features = out_features
15
+ self.weight = Parameter(torch.FloatTensor(in_features,out_features))
16
+ if bias:
17
+ self.bias = Parameter(torch.FloatTensor(out_features))
18
+ else:
19
+ self.register_parameter('bias',None)
20
+ self.reset_parameters()
21
+
22
+ def reset_parameters(self):
23
+ # stdv = 1./math.sqrt(self.weight(1))
24
+ # self.weight.data.uniform_(-stdv,stdv)
25
+ torch.nn.init.xavier_uniform_(self.weight)
26
+ # if self.bias is not None:
27
+ # self.bias.data.uniform_(-stdv,stdv)
28
+
29
+ def forward(self, input,adj=None,relu=False):
30
+ support = torch.matmul(input, self.weight)
31
+ # print(support.size(),adj.size())
32
+ if adj is not None:
33
+ output = torch.matmul(adj, support)
34
+ else:
35
+ output = support
36
+ # print(output.size())
37
+ if self.bias is not None:
38
+ return output + self.bias
39
+ else:
40
+ if relu:
41
+ return F.relu(output)
42
+ else:
43
+ return output
44
+
45
+ def __repr__(self):
46
+ return self.__class__.__name__ + ' (' \
47
+ + str(self.in_features) + ' -> ' \
48
+ + str(self.out_features) + ')'
49
+
50
+ class Featuremaps_to_Graph(nn.Module):
51
+
52
+ def __init__(self,input_channels,hidden_layers,nodes=7):
53
+ super(Featuremaps_to_Graph, self).__init__()
54
+ self.pre_fea = Parameter(torch.FloatTensor(input_channels,nodes))
55
+ self.weight = Parameter(torch.FloatTensor(input_channels,hidden_layers))
56
+ self.reset_parameters()
57
+
58
+ def forward(self, input):
59
+ n,c,h,w = input.size()
60
+ # print('fea input',input.size())
61
+ input1 = input.view(n,c,h*w)
62
+ input1 = input1.transpose(1,2) # n x hw x c
63
+ # print('fea input1', input1.size())
64
+ ############## Feature maps to node ################
65
+ fea_node = torch.matmul(input1,self.pre_fea) # n x hw x n_classes
66
+ weight_node = torch.matmul(input1,self.weight) # n x hw x hidden_layer
67
+ # softmax fea_node
68
+ fea_node = F.softmax(fea_node,dim=-1)
69
+ # print(fea_node.size(),weight_node.size())
70
+ graph_node = F.relu(torch.matmul(fea_node.transpose(1,2),weight_node))
71
+ return graph_node # n x n_class x hidden_layer
72
+
73
+ def reset_parameters(self):
74
+ for ww in self.parameters():
75
+ torch.nn.init.xavier_uniform_(ww)
76
+ # if self.bias is not None:
77
+ # self.bias.data.uniform_(-stdv,stdv)
78
+
79
+ class Featuremaps_to_Graph_transfer(nn.Module):
80
+
81
+ def __init__(self,input_channels,hidden_layers,nodes=7, source_nodes=20):
82
+ super(Featuremaps_to_Graph_transfer, self).__init__()
83
+ self.pre_fea = Parameter(torch.FloatTensor(input_channels,nodes))
84
+ self.weight = Parameter(torch.FloatTensor(input_channels,hidden_layers))
85
+ self.pre_fea_transfer = nn.Sequential(*[nn.Linear(source_nodes, source_nodes),nn.LeakyReLU(True),
86
+ nn.Linear(source_nodes, nodes), nn.LeakyReLU(True)])
87
+ self.reset_parameters()
88
+
89
+ def forward(self, input, source_pre_fea):
90
+ self.pre_fea.data = self.pre_fea_learn(source_pre_fea)
91
+ n,c,h,w = input.size()
92
+ # print('fea input',input.size())
93
+ input1 = input.view(n,c,h*w)
94
+ input1 = input1.transpose(1,2) # n x hw x c
95
+ # print('fea input1', input1.size())
96
+ ############## Feature maps to node ################
97
+ fea_node = torch.matmul(input1,self.pre_fea) # n x hw x n_classes
98
+ weight_node = torch.matmul(input1,self.weight) # n x hw x hidden_layer
99
+ # softmax fea_node
100
+ fea_node = F.softmax(fea_node,dim=1)
101
+ # print(fea_node.size(),weight_node.size())
102
+ graph_node = F.relu(torch.matmul(fea_node.transpose(1,2),weight_node))
103
+ return graph_node # n x n_class x hidden_layer
104
+
105
+ def pre_fea_learn(self, input):
106
+ pre_fea = self.pre_fea_transfer.forward(input.unsqueeze(0)).squeeze(0)
107
+ return self.pre_fea.data + pre_fea
108
+
109
+ class Graph_to_Featuremaps(nn.Module):
110
+ # this is a special version
111
+ def __init__(self,input_channels,output_channels,hidden_layers,nodes=7):
112
+ super(Graph_to_Featuremaps, self).__init__()
113
+ self.node_fea = Parameter(torch.FloatTensor(input_channels+hidden_layers,1))
114
+ self.weight = Parameter(torch.FloatTensor(hidden_layers,output_channels))
115
+ self.reset_parameters()
116
+
117
+ def reset_parameters(self):
118
+ for ww in self.parameters():
119
+ torch.nn.init.xavier_uniform_(ww)
120
+
121
+ def forward(self, input, res_feature):
122
+ '''
123
+
124
+ :param input: 1 x batch x nodes x hidden_layer
125
+ :param res_feature: batch x channels x h x w
126
+ :return:
127
+ '''
128
+ batchi,channeli,hi,wi = res_feature.size()
129
+ # print(res_feature.size())
130
+ # print(input.size())
131
+ try:
132
+ _,batch,nodes,hidden = input.size()
133
+ except:
134
+ # print(input.size())
135
+ input = input.unsqueeze(0)
136
+ _,batch, nodes, hidden = input.size()
137
+
138
+ assert batch == batchi
139
+ input1 = input.transpose(0,1).expand(batch,hi*wi,nodes,hidden)
140
+ res_feature_after_view = res_feature.view(batch,channeli,hi*wi).transpose(1,2)
141
+ res_feature_after_view1 = res_feature_after_view.unsqueeze(2).expand(batch,hi*wi,nodes,channeli)
142
+ new_fea = torch.cat((res_feature_after_view1,input1),dim=3)
143
+
144
+ # print(self.node_fea.size(),new_fea.size())
145
+ new_node = torch.matmul(new_fea, self.node_fea) # batch x hw x nodes x 1
146
+ new_weight = torch.matmul(input, self.weight) # batch x node x channel
147
+ new_node = new_node.view(batch, hi*wi, nodes)
148
+ # 0721
149
+ new_node = F.softmax(new_node, dim=-1)
150
+ #
151
+ feature_out = torch.matmul(new_node,new_weight)
152
+ # print(feature_out.size())
153
+ feature_out = feature_out.transpose(2,3).contiguous().view(res_feature.size())
154
+ return F.relu(feature_out)
155
+
156
+ class Graph_to_Featuremaps_savemem(nn.Module):
157
+ # this is a special version for saving GPU memory. The process is the same as Graph_to_Featuremaps.
158
+ def __init__(self, input_channels, output_channels, hidden_layers, nodes=7):
159
+ super(Graph_to_Featuremaps_savemem, self).__init__()
160
+ self.node_fea_for_res = Parameter(torch.FloatTensor(input_channels, 1))
161
+ self.node_fea_for_hidden = Parameter(torch.FloatTensor(hidden_layers, 1))
162
+ self.weight = Parameter(torch.FloatTensor(hidden_layers,output_channels))
163
+ self.reset_parameters()
164
+
165
+ def reset_parameters(self):
166
+ for ww in self.parameters():
167
+ torch.nn.init.xavier_uniform_(ww)
168
+
169
+ def forward(self, input, res_feature):
170
+ '''
171
+
172
+ :param input: 1 x batch x nodes x hidden_layer
173
+ :param res_feature: batch x channels x h x w
174
+ :return:
175
+ '''
176
+ batchi,channeli,hi,wi = res_feature.size()
177
+ # print(res_feature.size())
178
+ # print(input.size())
179
+ try:
180
+ _,batch,nodes,hidden = input.size()
181
+ except:
182
+ # print(input.size())
183
+ input = input.unsqueeze(0)
184
+ _,batch, nodes, hidden = input.size()
185
+
186
+ assert batch == batchi
187
+ input1 = input.transpose(0,1).expand(batch,hi*wi,nodes,hidden)
188
+ res_feature_after_view = res_feature.view(batch,channeli,hi*wi).transpose(1,2)
189
+ res_feature_after_view1 = res_feature_after_view.unsqueeze(2).expand(batch,hi*wi,nodes,channeli)
190
+ # new_fea = torch.cat((res_feature_after_view1,input1),dim=3)
191
+ ## sim
192
+ new_node1 = torch.matmul(res_feature_after_view1, self.node_fea_for_res)
193
+ new_node2 = torch.matmul(input1, self.node_fea_for_hidden)
194
+ new_node = new_node1 + new_node2
195
+ ## sim end
196
+ # print(self.node_fea.size(),new_fea.size())
197
+ # new_node = torch.matmul(new_fea, self.node_fea) # batch x hw x nodes x 1
198
+ new_weight = torch.matmul(input, self.weight) # batch x node x channel
199
+ new_node = new_node.view(batch, hi*wi, nodes)
200
+ # 0721
201
+ new_node = F.softmax(new_node, dim=-1)
202
+ #
203
+ feature_out = torch.matmul(new_node,new_weight)
204
+ # print(feature_out.size())
205
+ feature_out = feature_out.transpose(2,3).contiguous().view(res_feature.size())
206
+ return F.relu(feature_out)
207
+
208
+
209
+ class Graph_trans(nn.Module):
210
+
211
+ def __init__(self,in_features,out_features,begin_nodes=7,end_nodes=2,bias=False,adj=None):
212
+ super(Graph_trans, self).__init__()
213
+ self.in_features = in_features
214
+ self.out_features = out_features
215
+ self.weight = Parameter(torch.FloatTensor(in_features,out_features))
216
+ if adj is not None:
217
+ h,w = adj.size()
218
+ assert (h == end_nodes) and (w == begin_nodes)
219
+ self.adj = torch.autograd.Variable(adj,requires_grad=False)
220
+ else:
221
+ self.adj = Parameter(torch.FloatTensor(end_nodes,begin_nodes))
222
+ if bias:
223
+ self.bias = Parameter(torch.FloatTensor(out_features))
224
+ else:
225
+ self.register_parameter('bias',None)
226
+ # self.reset_parameters()
227
+
228
+ def reset_parameters(self):
229
+ # stdv = 1./math.sqrt(self.weight(1))
230
+ # self.weight.data.uniform_(-stdv,stdv)
231
+ torch.nn.init.xavier_uniform_(self.weight)
232
+ # if self.bias is not None:
233
+ # self.bias.data.uniform_(-stdv,stdv)
234
+
235
+ def forward(self, input, relu=False, adj_return=False, adj=None):
236
+ support = torch.matmul(input,self.weight)
237
+ # print(support.size(),self.adj.size())
238
+ if adj is None:
239
+ adj = self.adj
240
+ adj1 = self.norm_trans_adj(adj)
241
+ output = torch.matmul(adj1,support)
242
+ if adj_return:
243
+ output1 = F.normalize(output,p=2,dim=-1)
244
+ self.adj_mat = torch.matmul(output1,output1.transpose(-2,-1))
245
+ if self.bias is not None:
246
+ return output + self.bias
247
+ else:
248
+ if relu:
249
+ return F.relu(output)
250
+ else:
251
+ return output
252
+
253
+ def get_adj_mat(self):
254
+ adj = graph.normalize_adj_torch(F.relu(self.adj_mat))
255
+ return adj
256
+
257
+ def get_encode_adj(self):
258
+ return self.adj
259
+
260
+ def norm_trans_adj(self,adj): # maybe can use softmax
261
+ adj = F.relu(adj)
262
+ r = F.softmax(adj,dim=-1)
263
+ # print(adj.size())
264
+ # row_sum = adj.sum(-1).unsqueeze(-1)
265
+ # d_mat = row_sum.expand(adj.size())
266
+ # r = torch.div(row_sum,d_mat)
267
+ # r[torch.isnan(r)] = 0
268
+
269
+ return r
270
+
271
+
272
+ if __name__ == '__main__':
273
+
274
+ graph = torch.randn((7,128))
275
+ en = GraphConvolution(128,128)
276
+ a = en.forward(graph)
277
+ print(a)
278
+ # a = en.forward(graph,pred)
279
+ # print(a.size())
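
A hedged usage sketch (not part of the repository; it assumes the `networks` package is importable from the Graphonomy-master root): a single `GraphConvolution` layer applied to the 7 PASCAL body-part nodes, with a normalized adjacency built by `preprocess_adj` from `networks/graph.py` (shown next):

```python
import numpy as np
import torch
from networks import graph
from networks.gcn import GraphConvolution

# D^-1/2 (A + I) D^-1/2 normalization of the hand-defined PASCAL part graph.
adj = torch.from_numpy(np.array(graph.preprocess_adj(graph.pascal_graph))).float()  # [7, 7]
nodes = torch.randn(7, 128)                  # 7 part nodes with 128-d features
layer = GraphConvolution(128, 128)
out = layer.forward(nodes, adj=adj, relu=True)
print(out.shape)                             # torch.Size([7, 128])
```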
TryYours-Virtual-Try-On/Graphonomy-master/networks/graph.py ADDED
@@ -0,0 +1,261 @@
1
+ import numpy as np
2
+ import pickle as pkl
3
+ import networkx as nx
4
+ import scipy.sparse as sp
5
+ import torch
6
+
7
+ pascal_graph = {0:[0],
8
+ 1:[1, 2],
9
+ 2:[1, 2, 3, 5],
10
+ 3:[2, 3, 4],
11
+ 4:[3, 4],
12
+ 5:[2, 5, 6],
13
+ 6:[5, 6]}
14
+
15
+ cihp_graph = {0: [],
16
+ 1: [2, 13],
17
+ 2: [1, 13],
18
+ 3: [14, 15],
19
+ 4: [13],
20
+ 5: [6, 7, 9, 10, 11, 12, 14, 15],
21
+ 6: [5, 7, 10, 11, 14, 15, 16, 17],
22
+ 7: [5, 6, 9, 10, 11, 12, 14, 15],
23
+ 8: [16, 17, 18, 19],
24
+ 9: [5, 7, 10, 16, 17, 18, 19],
25
+ 10:[5, 6, 7, 9, 11, 12, 13, 14, 15, 16, 17],
26
+ 11:[5, 6, 7, 10, 13],
27
+ 12:[5, 7, 10, 16, 17],
28
+ 13:[1, 2, 4, 10, 11],
29
+ 14:[3, 5, 6, 7, 10],
30
+ 15:[3, 5, 6, 7, 10],
31
+ 16:[6, 8, 9, 10, 12, 18],
32
+ 17:[6, 8, 9, 10, 12, 19],
33
+ 18:[8, 9, 16],
34
+ 19:[8, 9, 17]}
35
+
36
+ atr_graph = {0: [],
37
+ 1: [2, 11],
38
+ 2: [1, 11],
39
+ 3: [11],
40
+ 4: [5, 6, 7, 11, 14, 15, 17],
41
+ 5: [4, 6, 7, 8, 12, 13],
42
+ 6: [4,5,7,8,9,10,12,13],
43
+ 7: [4,11,12,13,14,15],
44
+ 8: [5,6],
45
+ 9: [6, 12],
46
+ 10:[6, 13],
47
+ 11:[1,2,3,4,7,14,15,17],
48
+ 12:[5,6,7,9],
49
+ 13:[5,6,7,10],
50
+ 14:[4,7,11,16],
51
+ 15:[4,7,11,16],
52
+ 16:[14,15],
53
+ 17:[4,11],
54
+ }
55
+
56
+ cihp2pascal_adj = np.array([[1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
57
+ [0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
58
+ [0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
59
+ [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
60
+ [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
61
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
62
+ [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
63
+
64
+ cihp2pascal_nlp_adj = \
65
+ np.array([[ 1., 0.35333052, 0.32727194, 0.17418084, 0.18757584,
66
+ 0.40608522, 0.37503981, 0.35448462, 0.22598555, 0.23893579,
67
+ 0.33064262, 0.28923404, 0.27986573, 0.4211553 , 0.36915778,
68
+ 0.41377746, 0.32485771, 0.37248222, 0.36865639, 0.41500332],
69
+ [ 0.39615879, 0.46201529, 0.52321467, 0.30826114, 0.25669527,
70
+ 0.54747773, 0.3670523 , 0.3901983 , 0.27519473, 0.3433325 ,
71
+ 0.52728509, 0.32771333, 0.34819325, 0.63882953, 0.68042925,
72
+ 0.69368576, 0.63395791, 0.65344337, 0.59538781, 0.6071375 ],
73
+ [ 0.16373166, 0.21663339, 0.3053872 , 0.28377612, 0.1372435 ,
74
+ 0.4448808 , 0.29479995, 0.31092595, 0.22703953, 0.33983576,
75
+ 0.75778818, 0.2619818 , 0.37069392, 0.35184867, 0.49877512,
76
+ 0.49979437, 0.51853277, 0.52517541, 0.32517741, 0.32377309],
77
+ [ 0.32687232, 0.38482461, 0.37693463, 0.41610834, 0.20415749,
78
+ 0.76749079, 0.35139853, 0.3787411 , 0.28411737, 0.35155421,
79
+ 0.58792618, 0.31141718, 0.40585111, 0.51189218, 0.82042737,
80
+ 0.8342413 , 0.70732188, 0.72752501, 0.60327325, 0.61431337],
81
+ [ 0.34069369, 0.34817292, 0.37525998, 0.36497069, 0.17841617,
82
+ 0.69746208, 0.31731463, 0.34628951, 0.25167277, 0.32072379,
83
+ 0.56711286, 0.24894776, 0.37000453, 0.52600859, 0.82483993,
84
+ 0.84966274, 0.7033991 , 0.73449378, 0.56649608, 0.58888791],
85
+ [ 0.28477487, 0.35139564, 0.42742352, 0.41664321, 0.20004676,
86
+ 0.78566833, 0.42237487, 0.41048549, 0.37933812, 0.46542516,
87
+ 0.62444759, 0.3274493 , 0.49466009, 0.49314658, 0.71244233,
88
+ 0.71497003, 0.8234787 , 0.83566589, 0.62597135, 0.62626812],
89
+ [ 0.3011378 , 0.31775977, 0.42922647, 0.36896257, 0.17597556,
90
+ 0.72214655, 0.39162804, 0.38137872, 0.34980296, 0.43818419,
91
+ 0.60879174, 0.26762545, 0.46271161, 0.51150476, 0.72318109,
92
+ 0.73678399, 0.82620388, 0.84942166, 0.5943811 , 0.60607602]])
93
+
94
+ pascal2atr_nlp_adj = \
95
+ np.array([[ 1., 0.35333052, 0.32727194, 0.18757584, 0.40608522,
96
+ 0.27986573, 0.23893579, 0.27600672, 0.30964391, 0.36865639,
97
+ 0.41500332, 0.4211553 , 0.32485771, 0.37248222, 0.36915778,
98
+ 0.41377746, 0.32006291, 0.28923404],
99
+ [ 0.39615879, 0.46201529, 0.52321467, 0.25669527, 0.54747773,
100
+ 0.34819325, 0.3433325 , 0.26603942, 0.45162929, 0.59538781,
101
+ 0.6071375 , 0.63882953, 0.63395791, 0.65344337, 0.68042925,
102
+ 0.69368576, 0.44354613, 0.32771333],
103
+ [ 0.16373166, 0.21663339, 0.3053872 , 0.1372435 , 0.4448808 ,
104
+ 0.37069392, 0.33983576, 0.26563416, 0.35443504, 0.32517741,
105
+ 0.32377309, 0.35184867, 0.51853277, 0.52517541, 0.49877512,
106
+ 0.49979437, 0.21750868, 0.2619818 ],
107
+ [ 0.32687232, 0.38482461, 0.37693463, 0.20415749, 0.76749079,
108
+ 0.40585111, 0.35155421, 0.28271333, 0.52684576, 0.60327325,
109
+ 0.61431337, 0.51189218, 0.70732188, 0.72752501, 0.82042737,
110
+ 0.8342413 , 0.40137029, 0.31141718],
111
+ [ 0.34069369, 0.34817292, 0.37525998, 0.17841617, 0.69746208,
112
+ 0.37000453, 0.32072379, 0.27268885, 0.47426719, 0.56649608,
113
+ 0.58888791, 0.52600859, 0.7033991 , 0.73449378, 0.82483993,
114
+ 0.84966274, 0.37830796, 0.24894776],
115
+ [ 0.28477487, 0.35139564, 0.42742352, 0.20004676, 0.78566833,
116
+ 0.49466009, 0.46542516, 0.32662614, 0.55780359, 0.62597135,
117
+ 0.62626812, 0.49314658, 0.8234787 , 0.83566589, 0.71244233,
118
+ 0.71497003, 0.41223219, 0.3274493 ],
119
+ [ 0.3011378 , 0.31775977, 0.42922647, 0.17597556, 0.72214655,
120
+ 0.46271161, 0.43818419, 0.3192333 , 0.50979216, 0.5943811 ,
121
+ 0.60607602, 0.51150476, 0.82620388, 0.84942166, 0.72318109,
122
+ 0.73678399, 0.39259827, 0.26762545]])
123
+
124
+ cihp2atr_nlp_adj = np.array([[ 1., 0.35333052, 0.32727194, 0.18757584, 0.40608522,
125
+ 0.27986573, 0.23893579, 0.27600672, 0.30964391, 0.36865639,
126
+ 0.41500332, 0.4211553 , 0.32485771, 0.37248222, 0.36915778,
127
+ 0.41377746, 0.32006291, 0.28923404],
128
+ [ 0.35333052, 1. , 0.39206695, 0.42143438, 0.4736689 ,
129
+ 0.47139544, 0.51999208, 0.38354847, 0.45628529, 0.46514124,
130
+ 0.50083501, 0.4310595 , 0.39371443, 0.4319752 , 0.42938598,
131
+ 0.46384034, 0.44833757, 0.6153155 ],
132
+ [ 0.32727194, 0.39206695, 1. , 0.32836702, 0.52603065,
133
+ 0.39543695, 0.3622627 , 0.43575346, 0.33866223, 0.45202552,
134
+ 0.48421 , 0.53669903, 0.47266611, 0.50925436, 0.42286557,
135
+ 0.45403656, 0.37221304, 0.40999322],
136
+ [ 0.17418084, 0.46892601, 0.25774838, 0.31816231, 0.39330317,
137
+ 0.34218382, 0.48253904, 0.22084125, 0.41335728, 0.52437572,
138
+ 0.5191713 , 0.33576117, 0.44230914, 0.44250678, 0.44330833,
139
+ 0.43887264, 0.50693611, 0.39278795],
140
+ [ 0.18757584, 0.42143438, 0.32836702, 1. , 0.35030067,
141
+ 0.30110947, 0.41055555, 0.34338879, 0.34336307, 0.37704433,
142
+ 0.38810141, 0.34702081, 0.24171562, 0.25433078, 0.24696241,
143
+ 0.2570884 , 0.4465962 , 0.45263213],
144
+ [ 0.40608522, 0.4736689 , 0.52603065, 0.35030067, 1. ,
145
+ 0.54372584, 0.58300258, 0.56674191, 0.555266 , 0.66599594,
146
+ 0.68567555, 0.55716359, 0.62997328, 0.65638548, 0.61219615,
147
+ 0.63183318, 0.54464151, 0.44293752],
148
+ [ 0.37503981, 0.50675565, 0.4761106 , 0.37561813, 0.60419403,
149
+ 0.77912403, 0.64595517, 0.85939662, 0.46037144, 0.52348817,
150
+ 0.55875094, 0.37741886, 0.455671 , 0.49434392, 0.38479954,
151
+ 0.41804074, 0.47285709, 0.57236283],
152
+ [ 0.35448462, 0.50576632, 0.51030446, 0.35841033, 0.55106903,
153
+ 0.50257274, 0.52591451, 0.4283053 , 0.39991808, 0.42327211,
154
+ 0.42853819, 0.42071825, 0.41240559, 0.42259136, 0.38125352,
155
+ 0.3868255 , 0.47604934, 0.51811717],
156
+ [ 0.22598555, 0.5053299 , 0.36301185, 0.38002282, 0.49700941,
157
+ 0.45625243, 0.62876479, 0.4112051 , 0.33944371, 0.48322639,
158
+ 0.50318714, 0.29207815, 0.38801966, 0.41119094, 0.29199072,
159
+ 0.31021029, 0.41594871, 0.54961962],
160
+ [ 0.23893579, 0.51999208, 0.3622627 , 0.41055555, 0.58300258,
161
+ 0.68874251, 1. , 0.56977937, 0.49918447, 0.48484363,
162
+ 0.51615925, 0.41222306, 0.49535971, 0.53134951, 0.3807616 ,
163
+ 0.41050298, 0.48675801, 0.51112664],
164
+ [ 0.33064262, 0.306412 , 0.60679935, 0.25592294, 0.58738706,
165
+ 0.40379627, 0.39679161, 0.33618385, 0.39235148, 0.45474013,
166
+ 0.4648476 , 0.59306762, 0.58976007, 0.60778661, 0.55400397,
167
+ 0.56551297, 0.3698029 , 0.33860535],
168
+ [ 0.28923404, 0.6153155 , 0.40999322, 0.45263213, 0.44293752,
169
+ 0.60359359, 0.51112664, 0.46578181, 0.45656936, 0.38142307,
170
+ 0.38525582, 0.33327223, 0.35360175, 0.36156453, 0.3384992 ,
171
+ 0.34261229, 0.49297863, 1. ],
172
+ [ 0.27986573, 0.47139544, 0.39543695, 0.30110947, 0.54372584,
173
+ 1. , 0.68874251, 0.67765588, 0.48690078, 0.44010641,
174
+ 0.44921156, 0.32321099, 0.48311542, 0.4982002 , 0.39378102,
175
+ 0.40297733, 0.45309735, 0.60359359],
176
+ [ 0.4211553 , 0.4310595 , 0.53669903, 0.34702081, 0.55716359,
177
+ 0.32321099, 0.41222306, 0.25721705, 0.36633509, 0.5397475 ,
178
+ 0.56429928, 1. , 0.55796926, 0.58842844, 0.57930828,
179
+ 0.60410597, 0.41615326, 0.33327223],
180
+ [ 0.36915778, 0.42938598, 0.42286557, 0.24696241, 0.61219615,
181
+ 0.39378102, 0.3807616 , 0.28089866, 0.48450394, 0.77400821,
182
+ 0.68813814, 0.57930828, 0.8856886 , 0.81673412, 1. ,
183
+ 0.92279623, 0.46969152, 0.3384992 ],
184
+ [ 0.41377746, 0.46384034, 0.45403656, 0.2570884 , 0.63183318,
185
+ 0.40297733, 0.41050298, 0.332879 , 0.48799542, 0.69231828,
186
+ 0.77015091, 0.60410597, 0.79788484, 0.88232104, 0.92279623,
187
+ 1. , 0.45685017, 0.34261229],
188
+ [ 0.32485771, 0.39371443, 0.47266611, 0.24171562, 0.62997328,
189
+ 0.48311542, 0.49535971, 0.32477932, 0.51486622, 0.79353556,
190
+ 0.69768738, 0.55796926, 1. , 0.92373745, 0.8856886 ,
191
+ 0.79788484, 0.47883134, 0.35360175],
192
+ [ 0.37248222, 0.4319752 , 0.50925436, 0.25433078, 0.65638548,
193
+ 0.4982002 , 0.53134951, 0.38057074, 0.52403969, 0.72035243,
194
+ 0.78711147, 0.58842844, 0.92373745, 1. , 0.81673412,
195
+ 0.88232104, 0.47109935, 0.36156453],
196
+ [ 0.36865639, 0.46514124, 0.45202552, 0.37704433, 0.66599594,
197
+ 0.44010641, 0.48484363, 0.39636574, 0.50175258, 1. ,
198
+ 0.91320249, 0.5397475 , 0.79353556, 0.72035243, 0.77400821,
199
+ 0.69231828, 0.59087008, 0.38142307],
200
+ [ 0.41500332, 0.50083501, 0.48421, 0.38810141, 0.68567555,
201
+ 0.44921156, 0.51615925, 0.45156472, 0.50438158, 0.91320249,
202
+ 1., 0.56429928, 0.69768738, 0.78711147, 0.68813814,
203
+ 0.77015091, 0.57698754, 0.38525582]])
204
+
205
+
206
+
207
+ def normalize_adj(adj):
208
+ """Symmetrically normalize adjacency matrix."""
209
+ adj = sp.coo_matrix(adj)
210
+ rowsum = np.array(adj.sum(1))
211
+ d_inv_sqrt = np.power(rowsum, -0.5).flatten()
212
+ d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
213
+ d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
214
+ return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
215
+
216
+ def preprocess_adj(adj):
217
+ """Preprocessing of adjacency matrix for simple GCN model and conversion to tuple representation."""
218
+ adj = nx.adjacency_matrix(nx.from_dict_of_lists(adj)) # return the adjacency matrix of adj (scipy sparse)
219
+ adj_normalized = normalize_adj(adj + sp.eye(adj.shape[0])) #
220
+ # return sparse_to_tuple(adj_normalized)
221
+ return adj_normalized.todense()
222
+
223
+ def row_norm(inputs):
224
+ outputs = []
225
+ for x in inputs:
226
+ xsum = x.sum()
227
+ x = x / xsum
228
+ outputs.append(x)
229
+ return outputs
230
+
231
+
232
+ def normalize_adj_torch(adj):
233
+ # print(adj.size())
234
+ if len(adj.size()) == 4:
235
+ new_r = torch.zeros(adj.size()).type_as(adj)
236
+ for i in range(adj.size(1)):
237
+ adj_item = adj[0,i]
238
+ rowsum = adj_item.sum(1)
239
+ d_inv_sqrt = rowsum.pow_(-0.5)
240
+ d_inv_sqrt[torch.isnan(d_inv_sqrt)] = 0
241
+ d_mat_inv_sqrt = torch.diag(d_inv_sqrt)
242
+ r = torch.matmul(torch.matmul(d_mat_inv_sqrt, adj_item), d_mat_inv_sqrt)
243
+ new_r[0,i,...] = r
244
+ return new_r
245
+ rowsum = adj.sum(1)
246
+ d_inv_sqrt = rowsum.pow_(-0.5)
247
+ d_inv_sqrt[torch.isnan(d_inv_sqrt)] = 0
248
+ d_mat_inv_sqrt = torch.diag(d_inv_sqrt)
249
+ r = torch.matmul(torch.matmul(d_mat_inv_sqrt,adj),d_mat_inv_sqrt)
250
+ return r
251
+
252
+ # def row_norm(adj):
253
+
254
+
255
+
256
+
257
+ if __name__ == '__main__':
258
+ a= row_norm(cihp2pascal_adj)
259
+ print(a)
260
+ print(cihp2pascal_adj)
261
+ # print(a.shape)
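
A short hedged sketch (illustrative only, not part of the repository) tying the helpers above together: `preprocess_adj` converts one of the hand-defined part graphs into a dense, symmetrically normalized adjacency, while `normalize_adj_torch` applies the same D^-1/2 A D^-1/2 normalization directly to a torch tensor:

```python
import numpy as np
import torch
from networks import graph

# preprocess_adj: dict-of-lists part graph -> dense D^-1/2 (A + I) D^-1/2 adjacency.
cihp_adj = np.array(graph.preprocess_adj(graph.cihp_graph))    # shape (20, 20)

# normalize_adj_torch: the same symmetric normalization on a torch tensor,
# e.g. for an adjacency constructed or refined at run time.
a = torch.rand(7, 7)
a = (a + a.t()) / 2                                            # make it symmetric
print(graph.normalize_adj_torch(a).shape)                      # torch.Size([7, 7])
```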