Commit 5c718d1
Parent(s): (none listed)
Commit message: first commit

Files changed:
- .gitignore +86 -0
- Dockerfile +15 -0
- README.md +40 -0
- biomap/.gitignore +6 -0
- biomap/.private-key.json +12 -0
- biomap/app.py +110 -0
- biomap/configs/my_train_config.yml +197 -0
- biomap/data.py +584 -0
- biomap/dataset_generator/__init__.py +6 -0
- biomap/dataset_generator/data_loader.py +356 -0
- biomap/dino/utils.py +619 -0
- biomap/dino/vision_transformer.py +314 -0
- biomap/helper.py +179 -0
- biomap/inference.py +261 -0
- biomap/label.png +0 -0
- biomap/model.py +453 -0
- biomap/modules.py +472 -0
- biomap/output/img.png +0 -0
- biomap/output/img_6.png +0 -0
- biomap/output/label.png +0 -0
- biomap/output/labeled_img.png +0 -0
- biomap/plot_functions.py +778 -0
- biomap/train.py +267 -0
- biomap/unet.py +80 -0
- biomap/utils.py +390 -0
- biomap/utils_gee.py +157 -0
- poetry.lock +1625 -0
- pyproject.toml +31 -0
- requirements.txt +133 -0
.gitignore
ADDED
@@ -0,0 +1,86 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# DotEnv configuration
+.env
+
+# Database
+*.db
+*.rdb
+
+# Pycharm
+.idea
+
+# VS Code
+.vscode/
+
+# Spyder
+.spyproject/
+
+# Jupyter NB Checkpoints
+.ipynb_checkpoints/
+
+# Mac OS-specific storage files
+.DS_Store
+
+# vim
+*.swp
+*.swo
+
+# Mypy cache
+.mypy_cache/
Dockerfile
ADDED
@@ -0,0 +1,15 @@
+FROM python:3.9
+COPY requirements.txt /app/requirements.txt
+WORKDIR /app
+RUN pip install seaborn
+RUN pip install gradio
+RUN pip install datetime
+RUN pip install numpy
+RUN pip install opencv-python
+RUN apt-get update
+RUN apt-get install ffmpeg libsm6 libxext6 -y
+RUN pip install -r requirements.txt
+COPY . /app
+# EXPOSE 7860
+CMD python app.py
+# hello world
README.md
ADDED
@@ -0,0 +1,40 @@
+# Welcome to the project inno-satellite-images-segmentation-gan
+
+
+- **Project name**: inno-satellite-images-segmentation-gan
+- **Library name**: library
+- **Authors**: Ekimetrics
+- **Description**: Segmenting satellite images at large scale is challenging because ground-truth labels are spurious for medium-resolution images (Sentinel-2). We want to improve our algorithm either with data augmentation from a GAN, or by correcting or adjusting Corine labels.
+
+
+
+## Project Structure
+```
+- library/           # Your python library
+- data/
+    - raw/
+    - processed/
+- docs/
+- tests/             # Where each unit test goes
+- scripts/           # Where each automation script goes
+- requirements.txt   # Where you pin the library versions used by your library
+```
+
+
+## Branch strategy
+TBD
+
+
+## Ethics checklist
+TBD
+
+
+
+## Starter package
+This project has been created using the Ekimetrics Python Starter Package to enforce best coding practices, reusability and industrialization. <br>
+If you have any questions, please reach out to the inno team and [Théo Alves Da Costa](mailto:[email protected])
+
+
+
+
+
biomap/.gitignore
ADDED
@@ -0,0 +1,6 @@
+#wsl
+*.Zone.Identifier
+
+#python
+*__pycache__
+
biomap/.private-key.json
ADDED
@@ -0,0 +1,12 @@
+{
+"type": "service_account",
+"project_id": "cvimg-377115",
+"private_key_id": "a162152bd26f4bcc287c44b130109892b5517875",
+"private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDCr/zwOTwyVVdF\nk1cKcdju9jeDWceVJ/zj73b5IS8IbQJbFGWSjaL6Ft1HfJ3rSdGbj+Xy26jY9OFJ\n4rIhpY0M0cCWpYo+gqS9p4JAL6lHqZvSnkRTglpx9QYOT8o9ibCWhMVVAPH71QZ/\n4BEfGC6s2+zdEn+cbCkGIqLLhZTq655kDOaGSycwV/bk+TOLI/An4gjMoEIimhsD\nS6TRmqZnQGoI6m6aj3xPZGVMkid3I37h+BOC64YjKeXnpAhTQ4LbpQz1BxvIKDt6\ncJ1FBOmdSvUC+dLEGMN3yijpJZ74nXUnSVbwYYt3K8Kz2PtmgDYvswER53NlrMoW\n1AF9ImDFAgMBAAECggEAASO7IXWBo90d6CiEdhQF9uFy2S5d9ol9N6EBl+JHOuwB\nwne4tSZJ3jT/Rus7wvX67tXI9dgf3fIgjsv92NLnwEn1Wq/xQppMm8iyK1DKU3vH\n8xrvf8iG048ojGqQXqf0ZEUoWd+/YDGZ2qNuZjmVgwZKwF2h2pcnQ25uIvWdYHrb\n3XhYLDAROVTTtyscYcl8UKmAZ35moVVBQxdakGunYg6o/s6rESRbc+gCyqHR5v+r\nCl3Z4XEKDdukIVI72Ybk0F8eZpQztN97uzK/zm9jl4NmAPXrnWLEwuJdwdm1cWUF\n/LTTuNPmRzCm7IGUpkx0AKEs6s0BRbJbwlZaj4QVJwKBgQDjb2rSO6kRLRHyv+4w\ny/OLmqOrMY7fpSCj0mH41GhiZUaoZqhDznmuhqEjo1kVipfuW94Fn5NsGmWpbmEC\nJlObUEg1umX/ceOJrtRdY3AQMSQXR6u7oc2mTgj3Opd0V1L1Lopj4Ijj43ARg/fU\nu4RnrCGHcXXzT2LCchY0ZhLg3wKBgQDbI6bzt/RNW8+IGKCvLLi41bxM/9r83GNO\nQI4a6yTT09N11okjP9h00JKYBgU3fYivA1aBloFB4kOYaBzomfWSZHEyyFWCr9y0\ndGyIDbfUaI/jFx2CaKomLnPDF5LA3IWHAsTRZ/c1JGhiOUseEq/TR0cJAo69kgf0\nkVmoGjo+2wKBgQCo7crkGJg9P8LDEbgz2mktWlETCR5cE2SpCczna62U2DChSI7W\nvng3H5x0whGbJHQxAV9pwdtYQksci/XWCO20wO7BqY+1KrydOZRXQVKtVDLAb+Wo\n2kfLrM6QA58XNP1TS5xTDyXeTsKg3+qmwhlYf8vvtGCttltenirMBL0k9QKBgFpL\nanNqDOQDPJQbcbo8dzDSAPDJS/Z86P5JY0R8N4SA99TKPV+k4w/fEUhK0sN2mmdi\nvLZQyZnYHXojDCZbqfBUKsB+A54B0LMadc3puSFwpDkyQRqG/fUVluWARRvqwapL\n3cVbTWU8RzaR3P3bPU+VQxPXVfGOxnBjo8m8ZNuZAoGBANTC20T9rZ9Won9FCbi3\nSMkGY59smx19CdytQ2rjGFOEeAVMVotP5viXFuKfv5g2E/kvyJUjuOoulA3dxddN\nQzXnOIT3dlyBmvXkHJHUIKiidyuX4JqQFdPTAmkt6KaTceRNb7VN1OqIk1AJ1SDb\nkGxerLg4WuGfSqOIV0Wk4cLI\n-----END PRIVATE KEY-----\n",
+"client_email": "[email protected]",
+"client_id": "115144831673857322488",
+"auth_uri": "https://accounts.google.com/o/oauth2/auth",
+"token_uri": "https://oauth2.googleapis.com/token",
+"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
+"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/cvimg-355%40cvimg-377115.iam.gserviceaccount.com"
+}
biomap/app.py
ADDED
@@ -0,0 +1,110 @@
+from plot_functions import *
+import hydra
+
+import torch
+from model import LitUnsupervisedSegmenter
+from helper import inference_on_location_and_month, inference_on_location
+from plot_functions import segment_region
+
+from functools import partial
+import gradio as gr
+import logging
+
+import geopandas as gpd
+mapbox_access_token = "pk.eyJ1IjoiamVyZW15LWVraW1ldHJpY3MiLCJhIjoiY2xrNjBwNGU2MDRhMjNqbWw0YTJrbnpvNCJ9.poVyIzhJuJmD6ffrL9lm2w"
+geo_df = gpd.read_file(gpd.datasets.get_path('naturalearth_cities'))
+
+def get_geomap(long, lat ):
+
+
+    fig = go.Figure(go.Scattermapbox(
+        lat=geo_df.geometry.y,
+        lon=geo_df.geometry.x,
+        mode='markers',
+        marker=go.scattermapbox.Marker(
+            size=14
+        ),
+        text=geo_df.name,
+    ))
+
+    fig.add_trace(go.Scattermapbox(lat=[lat],
+        lon=[long],
+        mode='markers',
+        marker=go.scattermapbox.Marker(
+            size=14
+        ),
+        marker_color="green",
+        text=['Actual position']))
+
+    fig.update_layout(
+        showlegend=False,
+        hovermode='closest',
+        mapbox=dict(
+            accesstoken=mapbox_access_token,
+            center=go.layout.mapbox.Center(
+                lat=lat,
+                lon=long
+            ),
+            zoom=3
+        )
+    )
+
+    return fig
+
+
+if __name__ == "__main__":
+
+
+    logging.basicConfig(filename='example.log', encoding='utf-8', level=logging.INFO)
+    # Initialize hydra with configs
+    #hydra.initialize(config_path="configs", job_name="corine")
+    cfg = hydra.compose(config_name="my_train_config.yml")
+    logging.info(f"config : {cfg}")
+    # Load the model
+
+    nbclasses = cfg.dir_dataset_n_classes
+    model = LitUnsupervisedSegmenter(nbclasses, cfg)
+    logging.info(f"Model Initialiazed")
+
+    model_path = "checkpoint/model/model.pt"
+    saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
+    logging.info(f"Model weights Loaded")
+    model.load_state_dict(saved_state_dict)
+    logging.info(f"Model Loaded")
+    # css=".VIDEO video{height: 100%;width:50%;margin:auto};.VIDEO{height: 50%;};.svelte-1vnmhm4{height:auto}"
+    with gr.Blocks() as demo:
+        gr.Markdown("Estimate Biodiversity in the world.")
+        with gr.Tab("Single Image"):
+            with gr.Row():
+                input_map = gr.Plot().style()
+                with gr.Column():
+                    input_latitude = gr.Number(label="lattitude", value=2.98)
+                    input_longitude = gr.Number(label="longitude", value=48.81)
+                    input_date = gr.Textbox(label="start_date", value="2020-03-20")
+
+                    single_button = gr.Button("Predict")
+            with gr.Row():
+                raw_image = gr.Image(label = "Localisation visualization")
+                output_image = gr.Image(label = "Labeled visualisation")
+                score_biodiv = gr.Number(label = "Biodiversity score")
+
+        with gr.Tab("TimeLapse"):
+            with gr.Row():
+                input_map_2 = gr.Plot().style()
+            with gr.Row():
+                timelapse_input_latitude = gr.Number(value=2.98, label="Latitude")
+                timelapse_input_longitude = gr.Number(value=48.81, label="Longitude")
+                timelapse_start_date = gr.Textbox(value='2020-05-01', label="Start Date")
+                timelapse_end_date = gr.Textbox(value='2020-06-30', label="End Date")
+            segmentation = gr.CheckboxGroup(choices=['month', 'year', '2months'], value=['month'], label="Select Segmentation Level:")
+            timelapse_button = gr.Button(value="Predict")
+            map = gr.Plot().style()
+
+        demo.load(get_geomap, [input_latitude, input_longitude], input_map)
+        single_button.click(get_geomap, [input_latitude, input_longitude], input_map)
+        single_button.click(partial(inference_on_location_and_month, model), inputs=[input_latitude, input_longitude, input_date], outputs=[raw_image, output_image,score_biodiv])
+
+        demo.load(get_geomap, [timelapse_input_latitude, timelapse_input_longitude], input_map_2)
+        timelapse_button.click(get_geomap, [timelapse_input_latitude, timelapse_input_longitude], input_map_2)
+        timelapse_button.click(segment_region, inputs=[timelapse_input_latitude, timelapse_input_longitude, timelapse_start_date, timelapse_end_date,segmentation], outputs=[map])
    demo.launch(share=True)
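Note how the click handlers above use functools.partial to bind the loaded model as the first argument of the inference function, so Gradio only supplies the UI inputs. The following is a minimal, self-contained sketch of that wiring pattern; the `predict` function and `"demo-model"` tag are illustrative stand-ins, not part of the Space:

```python
from functools import partial
import gradio as gr

def predict(model_tag, latitude, longitude):
    # `model_tag` stands in for the loaded model object; it is bound once below.
    return f"{model_tag}: inference at ({latitude}, {longitude})"

with gr.Blocks() as demo:
    lat = gr.Number(label="Latitude", value=2.98)
    lon = gr.Number(label="Longitude", value=48.81)
    out = gr.Textbox(label="Result")
    # Bind the first positional argument so the callback only receives UI inputs.
    gr.Button("Predict").click(partial(predict, "demo-model"), inputs=[lat, lon], outputs=out)

demo.launch()
```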
biomap/configs/my_train_config.yml
ADDED
@@ -0,0 +1,197 @@
+output_root: '../'
+pytorch_data_dir: '/home/duong_nguyen/pytorch-data'
+experiment_name: "unet_7classes"
+log_dir: "france"
+# experiment_name: "unet"
+# log_dir: "potsdam"
+azureml_logging: False
+submitting_to_aml: False
+full_name: ~
+
+# Loader params
+num_workers: 24
+max_steps: 80000
+batch_size: 16
+
+num_neighbors: 7
+dataset_name: "directory"
+# dataset_name: "potsdam"
+
+# Used if dataset_name is "directory"
+dir_dataset_name: "corine"
+dir_dataset_n_classes: 7
+
+has_labels: False
+# crop_type: "five"
+crop_type: ~
+crop_ratio: .5
+res: 224
+loader_crop_type: "center"
+
+# Model Params
+extra_clusters: 0
+use_true_labels: False
+use_recalibrator: False
+model_type: "vit_small"
+arch: "dino"
+use_fit_model: False
+dino_feat_type: "feat"
+projection_type: "nonlinear"
+#projection_type: linear
+dino_patch_size: 8
+granularity: 1
+continuous: True
+dim: 70
+dropout: True
+zero_clamp: True
+
+lr: 5e-4
+pretrained_weights: ~
+use_salience: False
+stabalize: False
+stop_at_zero: True
+
+# Feature Contrastive params
+pointwise: True
+feature_samples: 11
+neg_samples: 5
+aug_alignment_weight: 0.0
+
+correspondence_weight: 1.0
+
+
+# # Corine vit small 24/11/22
+neg_inter_weight: 0.63
+pos_inter_weight: 0.25
+pos_intra_weight: 0.67
+neg_inter_shift: 0.46
+pos_inter_shift: 0.02
+pos_intra_shift: 0.08
+
+# # Corine vit small 11/09/22
+# neg_inter_weight: 0.63
+# pos_inter_weight: 0.25
+# pos_intra_weight: 0.67
+# neg_inter_shift: 0.46
+# pos_inter_shift: 0.24
+# pos_intra_shift: 0.36
+
+# # IAROA vit small 1/31/22
+# neg_inter_weight: 0.63
+# pos_inter_weight: 0.25
+# pos_intra_weight: 0.67
+# neg_inter_shift: 0.46
+# pos_inter_shift: 0.12
+# pos_intra_shift: 0.18
+
+# Potsdam vit small 1/31/22
+# neg_inter_weight: 0.63
+# pos_inter_weight: 0.25
+# pos_intra_weight: 0.67
+# neg_inter_shift: 0.46
+# pos_inter_shift: 0.02
+# pos_intra_shift: 0.08
+
+# Cocostuff27 vit small 1/31/22
+#neg_inter_weight: 0.63
+#pos_inter_weight: 0.25
+#pos_intra_weight: 0.67
+#neg_inter_shift: 0.66
+#pos_inter_shift: 0.02
+#pos_intra_shift: 0.08
+
+
+## Cocostuff27 10/3 vit_base
+
+#neg_inter_weight: 0.1538476246415498
+#pos_inter_weight: 1
+#pos_intra_weight: 0.1
+#
+#neg_inter_shift: 1
+#pos_inter_shift: 0.2
+#pos_intra_shift: 0.12
+
+
+## Cocostuff27 10/3 vit_small
+#neg_inter_weight: .63
+#pos_inter_weight: .25
+#pos_intra_weight: .67
+#
+#neg_inter_shift: .16
+#pos_inter_shift: .02
+#pos_intra_shift: .08
+
+
+
+## Cocostuff27 10/3 moco
+#neg_inter_weight: .63
+#pos_inter_weight: .25
+#pos_intra_weight: .67
+#
+#neg_inter_shift: .26
+#pos_inter_shift: .36
+#pos_intra_shift: .32
+
+#pos_inter_shift: .12
+#pos_intra_shift: .18
+
+## Cocostuff27
+#neg_inter_weight: .72
+#pos_inter_weight: .80
+#pos_intra_weight: .29
+#
+#neg_inter_shift: .86
+#pos_inter_shift: .04
+#pos_intra_shift: .34
+
+# Cityscapes 10/3
+
+# neg_inter_weight: 0.9058762625226623
+# pos_inter_weight: 0.577453483136995
+# pos_intra_weight: 1
+
+# neg_inter_shift: 0.31361241889448443
+# pos_inter_shift: 0.1754346515479633
+# pos_intra_shift: 0.45828472207
+
+
+# Cityscapes
+#neg_inter_weight: .72
+#pos_inter_weight: .18
+#pos_intra_weight: .46
+#
+#neg_inter_shift: .25
+#pos_inter_shift: .20
+#pos_intra_shift: .25
+
+
+rec_weight: 0.0
+repulsion_weight: 0.0
+
+# CRF Params
+crf_weight: 0.0
+alpha: .5
+beta: .15
+gamma: .05
+w1: 10.0
+w2: 3.0
+shift: 0.00
+crf_samples: 1000
+color_space: "rgb"
+
+reset_probe_steps: ~
+
+# Logging params
+n_images: 5
+scalar_log_freq: 10
+checkpoint_freq: 50
+val_freq: 100
+hist_freq: 100
+
+
+hydra:
+  run:
+    dir: "."
+  output_subdir: ~
+  #job_logging: "disabled"
+  #hydra_logging: "disabled"
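For reference, app.py above loads this file through Hydra's compose API. A minimal sketch of that loading step, assuming the file sits in a configs/ directory next to the calling script (exact initialize arguments vary slightly across Hydra versions):

```python
import hydra
from omegaconf import OmegaConf

# Compose the config relative to the calling script; job_name is arbitrary here.
with hydra.initialize(config_path="configs", job_name="biomap"):
    cfg = hydra.compose(config_name="my_train_config.yml")

print(cfg.dir_dataset_n_classes)   # -> 7, drives the number of segmentation classes
print(OmegaConf.to_yaml(cfg))      # full resolved configuration
```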
biomap/data.py
ADDED
@@ -0,0 +1,584 @@
+import os
+import random
+from os.path import join
+
+import numpy as np
+import torch.multiprocessing
+from PIL import Image
+from scipy.io import loadmat
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
+from torchvision.datasets.cityscapes import Cityscapes
+from torchvision.transforms.functional import to_pil_image
+from tqdm import tqdm
+
+
+def bit_get(val, idx):
+    """Gets the bit value.
+    Args:
+      val: Input value, int or numpy int array.
+      idx: Which bit of the input val.
+    Returns:
+      The "idx"-th bit of input val.
+    """
+    return (val >> idx) & 1
+
+
+def create_pascal_label_colormap():
+    """Creates a label colormap used in PASCAL VOC segmentation benchmark.
+    Returns:
+      A colormap for visualizing segmentation results.
+    """
+    colormap = np.zeros((512, 3), dtype=int)
+    ind = np.arange(512, dtype=int)
+
+    for shift in reversed(list(range(8))):
+        for channel in range(3):
+            colormap[:, channel] |= bit_get(ind, channel) << shift
+        ind >>= 3
+
+    return colormap
+
+
+def create_cityscapes_colormap():
+    colors = [(128, 64, 128),
+              (244, 35, 232),
+              (250, 170, 160),
+              (230, 150, 140),
+              (70, 70, 70),
+              (102, 102, 156),
+              (190, 153, 153),
+              (180, 165, 180),
+              (150, 100, 100),
+              (150, 120, 90),
+              (153, 153, 153),
+              (153, 153, 153),
+              (250, 170, 30),
+              (220, 220, 0),
+              (107, 142, 35),
+              (152, 251, 152),
+              (70, 130, 180),
+              (220, 20, 60),
+              (255, 0, 0),
+              (0, 0, 142),
+              (0, 0, 70),
+              (0, 60, 100),
+              (0, 0, 90),
+              (0, 0, 110),
+              (0, 80, 100),
+              (0, 0, 230),
+              (119, 11, 32),
+              (0, 0, 0)]
+    return np.array(colors)
+
+
+class DirectoryDataset(Dataset):
+    def __init__(self, root, path, image_set, transform, target_transform):
+        super(DirectoryDataset, self).__init__()
+        self.split = image_set
+        self.dir = join(root, path)
+        self.img_dir = join(self.dir, "imgs", self.split)
+        self.label_dir = join(self.dir, "labels", self.split)
+
+        self.transform = transform
+        self.target_transform = target_transform
+
+        self.img_files = np.array(sorted(os.listdir(self.img_dir)))
+        assert len(self.img_files) > 0
+        if os.path.exists(join(self.dir, "labels")):
+            self.label_files = np.array(sorted(os.listdir(self.label_dir)))
+            assert len(self.img_files) == len(self.label_files)
+        else:
+            self.label_files = None
+        self.fine_to_coarse = {0: 0,
+                               1: 1,
+                               2: 2,
+                               3: 3,
+                               4: 4,
+                               5: 5,
+                               6: 6,
+                               7: -1,
+                               }
+
+    def __getitem__(self, index):
+        image_fn = self.img_files[index]
+        img = Image.open(join(self.img_dir, image_fn))
+
+        if self.label_files is not None:
+            label_fn = self.label_files[index]
+            label = Image.open(join(self.label_dir, label_fn))
+
+        seed = np.random.randint(2147483647)
+        random.seed(seed)
+        torch.manual_seed(seed)
+        img = self.transform(img)
+
+        if self.label_files is not None:
+            random.seed(seed)
+            torch.manual_seed(seed)
+            label = self.target_transform(label)
+            new_label_map = torch.zeros_like(label)
+            for fine, coarse in self.fine_to_coarse.items():
+                new_label_map[label == fine] = coarse
+            label = new_label_map
+        else:
+            label = torch.zeros(img.shape[1], img.shape[2], dtype=torch.int64) - 1
+
+        mask = (label > 0).to(torch.float32)
+        return img, label, mask
+
+
+    def __len__(self):
+        return len(self.img_files)
+
+
+class Potsdam(Dataset):
+    def __init__(self, root, image_set, transform, target_transform, coarse_labels):
+        super(Potsdam, self).__init__()
+        self.split = image_set
+        self.root = os.path.join(root, "potsdam")
+        self.transform = transform
+        self.target_transform = target_transform
+        split_files = {
+            "train": ["labelled_train.txt"],
+            "unlabelled_train": ["unlabelled_train.txt"],
+            # "train": ["unlabelled_train.txt"],
+            "val": ["labelled_test.txt"],
+            "train+val": ["labelled_train.txt", "labelled_test.txt"],
+            "all": ["all.txt"]
+        }
+        assert self.split in split_files.keys()
+
+        self.files = []
+        for split_file in split_files[self.split]:
+            with open(join(self.root, split_file), "r") as f:
+                self.files.extend(fn.rstrip() for fn in f.readlines())
+
+        self.coarse_labels = coarse_labels
+        self.fine_to_coarse = {0: 0, 4: 0,  # roads and cars
+                               1: 1, 5: 1,  # buildings and clutter
+                               2: 2, 3: 2,  # vegetation and trees
+                               255: -1
+                               }
+
+    def __getitem__(self, index):
+        image_id = self.files[index]
+        img = loadmat(join(self.root, "imgs", image_id + ".mat"))["img"]
+        img = to_pil_image(torch.from_numpy(img).permute(2, 0, 1)[:3])  # TODO add ir channel back
+        try:
+            label = loadmat(join(self.root, "gt", image_id + ".mat"))["gt"]
+            label = to_pil_image(torch.from_numpy(label).unsqueeze(-1).permute(2, 0, 1))
+        except FileNotFoundError:
+            label = to_pil_image(torch.ones(1, img.height, img.width))
+
+        seed = np.random.randint(2147483647)
+        random.seed(seed)
+        torch.manual_seed(seed)
+        img = self.transform(img)
+
+        random.seed(seed)
+        torch.manual_seed(seed)
+        label = self.target_transform(label).squeeze(0)
+        if self.coarse_labels:
+            new_label_map = torch.zeros_like(label)
+            for fine, coarse in self.fine_to_coarse.items():
+                new_label_map[label == fine] = coarse
+            label = new_label_map
+
+        mask = (label > 0).to(torch.float32)
+        return img, label, mask
+
+    def __len__(self):
+        return len(self.files)
+
+
+class PotsdamRaw(Dataset):
+    def __init__(self, root, image_set, transform, target_transform, coarse_labels):
+        super(PotsdamRaw, self).__init__()
+        self.split = image_set
+        self.root = os.path.join(root, "potsdamraw", "processed")
+        self.transform = transform
+        self.target_transform = target_transform
+        self.files = []
+        for im_num in range(38):
+            for i_h in range(15):
+                for i_w in range(15):
+                    self.files.append("{}_{}_{}.mat".format(im_num, i_h, i_w))
+
+        self.coarse_labels = coarse_labels
+        self.fine_to_coarse = {0: 0, 4: 0,  # roads and cars
+                               1: 1, 5: 1,  # buildings and clutter
+                               2: 2, 3: 2,  # vegetation and trees
+                               255: -1
+                               }
+
+    def __getitem__(self, index):
+        image_id = self.files[index]
+        img = loadmat(join(self.root, "imgs", image_id))["img"]
+        img = to_pil_image(torch.from_numpy(img).permute(2, 0, 1)[:3])  # TODO add ir channel back
+        try:
+            label = loadmat(join(self.root, "gt", image_id))["gt"]
+            label = to_pil_image(torch.from_numpy(label).unsqueeze(-1).permute(2, 0, 1))
+        except FileNotFoundError:
+            label = to_pil_image(torch.ones(1, img.height, img.width))
+
+        seed = np.random.randint(2147483647)
+        random.seed(seed)
+        torch.manual_seed(seed)
+        img = self.transform(img)
+
+        random.seed(seed)
+        torch.manual_seed(seed)
+        label = self.target_transform(label).squeeze(0)
+        if self.coarse_labels:
+            new_label_map = torch.zeros_like(label)
+            for fine, coarse in self.fine_to_coarse.items():
+                new_label_map[label == fine] = coarse
+            label = new_label_map
+
+        mask = (label > 0).to(torch.float32)
+        return img, label, mask
+
+    def __len__(self):
+        return len(self.files)
+
+
+class Coco(Dataset):
+    def __init__(self, root, image_set, transform, target_transform,
+                 coarse_labels, exclude_things, subset=None):
+        super(Coco, self).__init__()
+        self.split = image_set
+        self.root = join(root, "cocostuff")
+        self.coarse_labels = coarse_labels
+        self.transform = transform
+        self.label_transform = target_transform
+        self.subset = subset
+        self.exclude_things = exclude_things
+
+        if self.subset is None:
+            self.image_list = "Coco164kFull_Stuff_Coarse.txt"
+        elif self.subset == 6:  # IIC Coarse
+            self.image_list = "Coco164kFew_Stuff_6.txt"
+        elif self.subset == 7:  # IIC Fine
+            self.image_list = "Coco164kFull_Stuff_Coarse_7.txt"
+
+        assert self.split in ["train", "val", "train+val"]
+        split_dirs = {
+            "train": ["train2017"],
+            "val": ["val2017"],
+            "train+val": ["train2017", "val2017"]
+        }
+
+        self.image_files = []
+        self.label_files = []
+        for split_dir in split_dirs[self.split]:
+            with open(join(self.root, "curated", split_dir, self.image_list), "r") as f:
+                img_ids = [fn.rstrip() for fn in f.readlines()]
+                for img_id in img_ids:
+                    self.image_files.append(join(self.root, "images", split_dir, img_id + ".jpg"))
+                    self.label_files.append(join(self.root, "annotations", split_dir, img_id + ".png"))
+
+        self.fine_to_coarse = {0: 9, 1: 11, 2: 11, 3: 11, 4: 11, 5: 11, 6: 11, 7: 11, 8: 11, 9: 8, 10: 8, 11: 8, 12: 8,
+                               13: 8, 14: 8, 15: 7, 16: 7, 17: 7, 18: 7, 19: 7, 20: 7, 21: 7, 22: 7, 23: 7, 24: 7,
+                               25: 6, 26: 6, 27: 6, 28: 6, 29: 6, 30: 6, 31: 6, 32: 6, 33: 10, 34: 10, 35: 10, 36: 10,
+                               37: 10, 38: 10, 39: 10, 40: 10, 41: 10, 42: 10, 43: 5, 44: 5, 45: 5, 46: 5, 47: 5, 48: 5,
+                               49: 5, 50: 5, 51: 2, 52: 2, 53: 2, 54: 2, 55: 2, 56: 2, 57: 2, 58: 2, 59: 2, 60: 2,
+                               61: 3, 62: 3, 63: 3, 64: 3, 65: 3, 66: 3, 67: 3, 68: 3, 69: 3, 70: 3, 71: 0, 72: 0,
+                               73: 0, 74: 0, 75: 0, 76: 0, 77: 1, 78: 1, 79: 1, 80: 1, 81: 1, 82: 1, 83: 4, 84: 4,
+                               85: 4, 86: 4, 87: 4, 88: 4, 89: 4, 90: 4, 91: 17, 92: 17, 93: 22, 94: 20, 95: 20, 96: 22,
+                               97: 15, 98: 25, 99: 16, 100: 13, 101: 12, 102: 12, 103: 17, 104: 17, 105: 23, 106: 15,
+                               107: 15, 108: 17, 109: 15, 110: 21, 111: 15, 112: 25, 113: 13, 114: 13, 115: 13, 116: 13,
+                               117: 13, 118: 22, 119: 26, 120: 14, 121: 14, 122: 15, 123: 22, 124: 21, 125: 21, 126: 24,
+                               127: 20, 128: 22, 129: 15, 130: 17, 131: 16, 132: 15, 133: 22, 134: 24, 135: 21, 136: 17,
+                               137: 25, 138: 16, 139: 21, 140: 17, 141: 22, 142: 16, 143: 21, 144: 21, 145: 25, 146: 21,
+                               147: 26, 148: 21, 149: 24, 150: 20, 151: 17, 152: 14, 153: 21, 154: 26, 155: 15, 156: 23,
+                               157: 20, 158: 21, 159: 24, 160: 15, 161: 24, 162: 22, 163: 25, 164: 15, 165: 20, 166: 17,
+                               167: 17, 168: 22, 169: 14, 170: 18, 171: 18, 172: 18, 173: 18, 174: 18, 175: 18, 176: 18,
+                               177: 26, 178: 26, 179: 19, 180: 19, 181: 24}
+
+        self._label_names = [
+            "ground-stuff",
+            "plant-stuff",
+            "sky-stuff",
+        ]
+        self.cocostuff3_coarse_classes = [23, 22, 21]
+        self.first_stuff_index = 12
+
+    def __getitem__(self, index):
+        image_path = self.image_files[index]
+        label_path = self.label_files[index]
+        seed = np.random.randint(2147483647)
+        random.seed(seed)
+        torch.manual_seed(seed)
+        img = self.transform(Image.open(image_path).convert("RGB"))
+
+        random.seed(seed)
+        torch.manual_seed(seed)
+        label = self.label_transform(Image.open(label_path)).squeeze(0)
+        label[label == 255] = -1  # to be consistent with 10k
+        coarse_label = torch.zeros_like(label)
+        for fine, coarse in self.fine_to_coarse.items():
+            coarse_label[label == fine] = coarse
+        coarse_label[label == -1] = -1
+
+        if self.coarse_labels:
+            coarser_labels = -torch.ones_like(label)
+            for i, c in enumerate(self.cocostuff3_coarse_classes):
+                coarser_labels[coarse_label == c] = i
+            return img, coarser_labels, coarser_labels >= 0
+        else:
+            if self.exclude_things:
+                return img, coarse_label - self.first_stuff_index, (coarse_label >= self.first_stuff_index)
+            else:
+                return img, coarse_label, coarse_label >= 0
+
+    def __len__(self):
+        return len(self.image_files)
+
+
+class CityscapesSeg(Dataset):
+    def __init__(self, root, image_set, transform, target_transform):
+        super(CityscapesSeg, self).__init__()
+        self.split = image_set
+        self.root = join(root, "cityscapes")
+        if image_set == "train":
+            # our_image_set = "train_extra"
+            # mode = "coarse"
+            our_image_set = "train"
+            mode = "fine"
+        else:
+            our_image_set = image_set
+            mode = "fine"
+        self.inner_loader = Cityscapes(self.root, our_image_set,
+                                       mode=mode,
+                                       target_type="semantic",
+                                       transform=None,
+                                       target_transform=None)
+        self.transform = transform
+        self.target_transform = target_transform
+        self.first_nonvoid = 7
+
+    def __getitem__(self, index):
+        if self.transform is not None:
+            image, target = self.inner_loader[index]
+
+            seed = np.random.randint(2147483647)
+            random.seed(seed)
+            torch.manual_seed(seed)
+            image = self.transform(image)
+            random.seed(seed)
+            torch.manual_seed(seed)
+            target = self.target_transform(target)
+
+            target = target - self.first_nonvoid
+            target[target < 0] = -1
+            mask = target == -1
+            return image, target.squeeze(0), mask
+        else:
+            return self.inner_loader[index]
+
+    def __len__(self):
+        return len(self.inner_loader)
+
+
+class CroppedDataset(Dataset):
+    def __init__(self, root, dataset_name, crop_type, crop_ratio, image_set, transform, target_transform):
+        super(CroppedDataset, self).__init__()
+        self.dataset_name = dataset_name
+        self.split = image_set
+        self.root = join(root, "cropped", "{}_{}_crop_{}".format(dataset_name, crop_type, crop_ratio))
+        self.transform = transform
+        self.target_transform = target_transform
+        self.img_dir = join(self.root, "img", self.split)
+        self.label_dir = join(self.root, "label", self.split)
+        self.num_images = len(os.listdir(self.img_dir))
+        assert self.num_images == len(os.listdir(self.label_dir))
+
+    def __getitem__(self, index):
+        image = Image.open(join(self.img_dir, "{}.jpg".format(index))).convert('RGB')
+        target = Image.open(join(self.label_dir, "{}.png".format(index)))
+
+        seed = np.random.randint(2147483647)
+        random.seed(seed)
+        torch.manual_seed(seed)
+        image = self.transform(image)
+        random.seed(seed)
+        torch.manual_seed(seed)
+        target = self.target_transform(target)
+
+        target = target - 1
+        mask = target == -1
+        return image, target.squeeze(0), mask
+
+    def __len__(self):
+        return self.num_images
+
+
+class MaterializedDataset(Dataset):
+
+    def __init__(self, ds):
+        self.ds = ds
+        self.materialized = []
+        loader = DataLoader(ds, num_workers=12, collate_fn=lambda l: l[0])
+        for batch in tqdm(loader):
+            self.materialized.append(batch)
+
+    def __len__(self):
+        return len(self.ds)
+
+    def __getitem__(self, ind):
+        return self.materialized[ind]
+
+
+class ContrastiveSegDataset(Dataset):
+    def __init__(self,
+                 pytorch_data_dir,
+                 dataset_name,
+                 crop_type,
+                 image_set,
+                 transform,
+                 target_transform,
+                 cfg,
+                 aug_geometric_transform=None,
+                 aug_photometric_transform=None,
+                 num_neighbors=5,
+                 compute_knns=False,
+                 mask=False,
+                 pos_labels=False,
+                 pos_images=False,
+                 extra_transform=None,
+                 model_type_override=None
+                 ):
+        super(ContrastiveSegDataset).__init__()
+        self.num_neighbors = num_neighbors
+        self.image_set = image_set
+        self.dataset_name = dataset_name
+        self.mask = mask
+        self.pos_labels = pos_labels
+        self.pos_images = pos_images
+        self.extra_transform = extra_transform
+
+        if dataset_name == "potsdam":
+            self.n_classes = 3
+            dataset_class = Potsdam
+            extra_args = dict(coarse_labels=True)
+        elif dataset_name == "potsdamraw":
+            self.n_classes = 3
+            dataset_class = PotsdamRaw
+            extra_args = dict(coarse_labels=True)
+        elif dataset_name == "directory":
+            self.n_classes = cfg.dir_dataset_n_classes
+            dataset_class = DirectoryDataset
+            extra_args = dict(path=cfg.dir_dataset_name)
+        elif dataset_name == "cityscapes" and crop_type is None:
+            self.n_classes = 27
+            dataset_class = CityscapesSeg
+            extra_args = dict()
+        elif dataset_name == "cityscapes" and crop_type is not None:
+            self.n_classes = 27
+            dataset_class = CroppedDataset
+            extra_args = dict(dataset_name="cityscapes", crop_type=crop_type, crop_ratio=cfg.crop_ratio)
+        elif dataset_name == "cocostuff3":
+            self.n_classes = 3
+            dataset_class = Coco
+            extra_args = dict(coarse_labels=True, subset=6, exclude_things=True)
+        elif dataset_name == "cocostuff15":
+            self.n_classes = 15
+            dataset_class = Coco
+            extra_args = dict(coarse_labels=False, subset=7, exclude_things=True)
+        elif dataset_name == "cocostuff27" and crop_type is not None:
+            self.n_classes = 27
+            dataset_class = CroppedDataset
+            extra_args = dict(dataset_name="cocostuff27", crop_type=cfg.crop_type, crop_ratio=cfg.crop_ratio)
+        elif dataset_name == "cocostuff27" and crop_type is None:
+            self.n_classes = 27
+            dataset_class = Coco
+            extra_args = dict(coarse_labels=False, subset=None, exclude_things=False)
+            if image_set == "val":
+                extra_args["subset"] = 7
+        else:
+            raise ValueError("Unknown dataset: {}".format(dataset_name))
+
+        self.aug_geometric_transform = aug_geometric_transform
+        self.aug_photometric_transform = aug_photometric_transform
+
+        self.dataset = dataset_class(
+            root=pytorch_data_dir,
+            image_set=self.image_set,
+            transform=transform,
+            target_transform=target_transform, **extra_args)
+
+        if model_type_override is not None:
+            model_type = model_type_override
+        else:
+            model_type = cfg.model_type
+
+        nice_dataset_name = cfg.dir_dataset_name if dataset_name == "directory" else dataset_name
+        feature_cache_file = join(pytorch_data_dir, "nns", "nns_{}_{}_{}_{}_{}.npz".format(
+            model_type, nice_dataset_name, image_set, crop_type, cfg.res))
+        if pos_labels or pos_images:
+            if not os.path.exists(feature_cache_file) or compute_knns:
+                raise ValueError("could not find nn file {} please run precompute_knns".format(feature_cache_file))
+            else:
+                loaded = np.load(feature_cache_file)
+                self.nns = loaded["nns"]
+            assert len(self.dataset) == self.nns.shape[0]
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def _set_seed(self, seed):
+        random.seed(seed)  # apply this seed to img tranfsorms
+        torch.manual_seed(seed)  # needed for torchvision 0.7
+
+    def __getitem__(self, ind):
+        pack = self.dataset[ind]
+
+        if self.pos_images or self.pos_labels:
+            ind_pos = self.nns[ind][torch.randint(low=1, high=self.num_neighbors + 1, size=[]).item()]
+            pack_pos = self.dataset[ind_pos]
+
+        seed = np.random.randint(2147483647)  # make a seed with numpy generator
+
+        self._set_seed(seed)
+        coord_entries = torch.meshgrid([torch.linspace(-1, 1, pack[0].shape[1]),
+                                        torch.linspace(-1, 1, pack[0].shape[2])])
+        coord = torch.cat([t.unsqueeze(0) for t in coord_entries], 0)
+
+        if self.extra_transform is not None:
+            extra_trans = self.extra_transform
+        else:
+            extra_trans = lambda i, x: x
+
+        def squeeze_tuple(label_raw):
+            if type(label_raw) == tuple:
+                return tuple(x.squeeze() for x in label_raw)
+            else:
+                return label_raw.squeeze()
+        ret = {
+            "ind": ind,
+            "img": extra_trans(ind, pack[0]),
+            "label": squeeze_tuple(extra_trans(ind, pack[1]))
+        }
+
+        if self.pos_images:
+            ret["img_pos"] = extra_trans(ind, pack_pos[0])
+            ret["ind_pos"] = ind_pos
+
+        if self.mask:
+            ret["mask"] = pack[2]
+
+        if self.pos_labels:
+            ret["label_pos"] = squeeze_tuple(extra_trans(ind, pack_pos[1]))
+            ret["mask_pos"] = pack_pos[2]
+
+        if self.aug_photometric_transform is not None:
+            img_aug = self.aug_photometric_transform(self.aug_geometric_transform(pack[0]))
+
+            self._set_seed(seed)
+            coord_aug = self.aug_geometric_transform(coord)
+
+            ret["img_aug"] = img_aug
+            ret["coord_aug"] = coord_aug.permute(1, 2, 0)
+
+        return ret
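DirectoryDataset expects an imgs/<split> and labels/<split> layout under <root>/<path> and reuses the same random seed for the image and label transforms so paired augmentations stay aligned. A minimal usage sketch, with hypothetical paths and transforms (the real values come from the training config via pytorch_data_dir and dir_dataset_name), assuming it is run from the biomap/ directory:

```python
from torch.utils.data import DataLoader
from torchvision import transforms as T

from data import DirectoryDataset  # the module shown above

# Illustrative transforms: nearest-neighbour resizing keeps label ids intact.
img_tf = T.Compose([T.Resize((224, 224)), T.ToTensor()])
label_tf = T.Compose([T.Resize((224, 224), interpolation=T.InterpolationMode.NEAREST),
                      T.PILToTensor()])

ds = DirectoryDataset(root="/data", path="corine", image_set="train",
                      transform=img_tf, target_transform=label_tf)
img, label, mask = ds[0]            # label pixels are remapped through fine_to_coarse
loader = DataLoader(ds, batch_size=4, shuffle=True)
```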
biomap/dataset_generator/__init__.py
ADDED
@@ -0,0 +1,6 @@
+from .data_loader import DataLoader
+
+
+__all__ = [
+    'DataLoader',
+]
biomap/dataset_generator/data_loader.py
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
import ee
|
| 3 |
+
from func_timeout import func_set_timeout
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import requests
|
| 7 |
+
import tempfile
|
| 8 |
+
import io
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import functools
|
| 11 |
+
import re # Used in an eval statement
|
| 12 |
+
from typing import List
|
| 13 |
+
from typing import Union
|
| 14 |
+
from typing import Any
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class DataLoader:
|
| 18 |
+
"""
|
| 19 |
+
Main class for loading and exploring data from satellite images.
|
| 20 |
+
The goal is to load an ImageCollection and to filter that collection according to needs, with methods like
|
| 21 |
+
filter, filterDate, filterBounds, select. These will work just like earth engine's methods with the same names.
|
| 22 |
+
|
| 23 |
+
This class, just like earth engine, works with lazy loading and compute. This means that running filterBounds
|
| 24 |
+
will not actually filter the image collection until required, e.g. when counting the images by accessing .count
|
| 25 |
+
property.
|
| 26 |
+
However, it will only load once the information it needs, unless additional filtering is made.
|
| 27 |
+
|
| 28 |
+
This works thanks to the signal_change decorator. If you develop a new filtering method for this class,
|
| 29 |
+
you will need to decorate your method with @signal_change.
|
| 30 |
+
In addition, if you develop a new method that will require to run getInfo to actually load data from
|
| 31 |
+
Google Earth Engine, you will need to use _get_timeout_info(your object before getInfo). This will run
|
| 32 |
+
getInfo with a timeout (currently set to 10 seconds).
|
| 33 |
+
It is important to use a timeout to avoid unexpected run times.
|
| 34 |
+
|
| 35 |
+
Usage:
|
| 36 |
+
>>> dl = DataLoader(satellite_name="COPERNICUS/S2_SR", \
|
| 37 |
+
start_date='2021-01-01', \
|
| 38 |
+
end_date='2021-01-15', \
|
| 39 |
+
bands=["TCI_R", "TCI_G", "TCI_B"], \
|
| 40 |
+
geographic_bounds=ee.Geometry.Point(*[5.238728194366604, 44.474864056855935]).buffer(500) \
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
Get a pandas dataframe with all pixel values as a timeseries:
|
| 44 |
+
>>> dl.getRegion(dl.bounds, 500)
|
| 45 |
+
>>> dl.region.head(2)
|
| 46 |
+
[Out]
|
| 47 |
+
id longitude latitude time B1 B2 B3 B4 B5 B6 ... WVP SCL TCI_R TCI_G TCI_B MSK_CLDPRB MSK_SNWPRB QA10 QA20 QA60
|
| 48 |
+
0 20210102T104441_20210102T104435_T31TFK 5.234932 44.473344 2021-01-02 10:48:36.299 6297 5955 5768 5773 5965 5883 ... 393 8 255 255 255 0 95 0 0 1024
|
| 49 |
+
1 20210104T103329_20210104T103331_T31TFK 5.234932 44.473344 2021-01-04 10:38:38.304 5547 5355 5184 5090 5254 5229 ... 314 9 255 255 255 29 9 0 0 1024
|
| 50 |
+
|
| 51 |
+
>>> dl.date_range
|
| 52 |
+
[Out]
|
| 53 |
+
{'max': datetime.datetime(2021, 1, 14, 11, 38, 39, 208000),
|
| 54 |
+
'min': datetime.datetime(2021, 1, 2, 11, 48, 36, 299000)}
|
| 55 |
+
|
| 56 |
+
>>> dl.count
|
| 57 |
+
[Out]
|
| 58 |
+
6
|
| 59 |
+
|
| 60 |
+
>>> dl.collection_info # constains a html description of the dataset in "description"
|
| 61 |
+
|
| 62 |
+
>>> dl.image_ids
|
| 63 |
+
[Out]
|
| 64 |
+
['COPERNICUS/S2_SR/20210102T104441_20210102T104435_T31TFK',
|
| 65 |
+
'COPERNICUS/S2_SR/20210104T103329_20210104T103331_T31TFK',
|
| 66 |
+
'COPERNICUS/S2_SR/20210107T104329_20210107T104328_T31TFK',
|
| 67 |
+
'COPERNICUS/S2_SR/20210109T103421_20210109T103431_T31TFK',
|
| 68 |
+
'COPERNICUS/S2_SR/20210112T104411_20210112T104438_T31TFK',
|
| 69 |
+
'COPERNICUS/S2_SR/20210114T103309_20210114T103305_T31TFK']
|
| 70 |
+
|
| 71 |
+
# Download the image
|
| 72 |
+
>>> img = dl.download_image(dl.image_ids[3])
|
| 73 |
+
|
| 74 |
+
# Download all images as a list
|
| 75 |
+
>>> imgs = dl.download_all_images(scale=1)
|
| 76 |
+
|
| 77 |
+
"""
|
| 78 |
+
def __init__(self,
|
| 79 |
+
satellite_name: str,
|
| 80 |
+
bands: Union[List, str] = None,
|
| 81 |
+
start_date: str = None,
|
| 82 |
+
end_date: str = None,
|
| 83 |
+
geographic_bounds: ee.geometry = None,
|
| 84 |
+
scale: int = 10,
|
| 85 |
+
crs: str = "EPSG:32630"
|
| 86 |
+
):
|
| 87 |
+
"""
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
satellite_name: satellite to use. Examples: COPERNICUS/S2_SR, COPERNICUS/CORINE/V20/100m.
|
| 91 |
+
See https://developers.google.com/earth-engine/datasets for the full list.
|
| 92 |
+
bands: list of bands to load.
|
| 93 |
+
start_date: lowest possible date. Might be lower than the actual date of the first picture.
|
| 94 |
+
end_date: Latest possible date.
|
| 95 |
+
geographic_bounds: Region of interest.
|
| 96 |
+
"""
|
| 97 |
+
self.satellite_name = satellite_name
|
| 98 |
+
if isinstance(bands, str):
|
| 99 |
+
bands = [bands]
|
| 100 |
+
self.bands = bands if bands is not None else list()
|
| 101 |
+
if start_date is None or end_date is None:
|
| 102 |
+
assert (start_date is not None) and (end_date is not None), "start_date and end_date must both be provided"
|
| 103 |
+
self.start_date = start_date
|
| 104 |
+
self.end_date = end_date
|
| 105 |
+
self.bounds = geographic_bounds
|
| 106 |
+
|
| 107 |
+
# Lazy computed
|
| 108 |
+
self._available_images = None
|
| 109 |
+
|
| 110 |
+
# Start getting info from google cloud
|
| 111 |
+
if satellite_name:
|
| 112 |
+
self.image_collection = ee.ImageCollection(self.satellite_name)
|
| 113 |
+
if self.bounds:
|
| 114 |
+
self.filterBounds(self.bounds)
|
| 115 |
+
if self.start_date is not None:
|
| 116 |
+
self.filterDate(self.start_date, self.end_date)
|
| 117 |
+
self.scale = scale
|
| 118 |
+
self.crs = crs
|
| 119 |
+
self.image_list = None
|
| 120 |
+
self._df_image_list = None
|
| 121 |
+
self.image_collection_info = None
|
| 122 |
+
self._date_range = None
|
| 123 |
+
self.date_filter_change = False
|
| 124 |
+
self._count = None
|
| 125 |
+
|
| 126 |
+
# Bool for caching
|
| 127 |
+
self.filter_change = True
|
| 128 |
+
self._describe = None
|
| 129 |
+
|
| 130 |
+
def signal_change(func):
|
| 131 |
+
"""Signals that additional filtering was performed. To be used
|
| 132 |
+
as a decorator."""
|
| 133 |
+
@functools.wraps(func)
|
| 134 |
+
def wrap(self, *args, **kwargs):
|
| 135 |
+
self.filter_change = True
|
| 136 |
+
self.date_filter_change = True
|
| 137 |
+
return func(self, *args, **kwargs)
|
| 138 |
+
return wrap
|
| 139 |
+
|
| 140 |
+
@staticmethod
|
| 141 |
+
@func_set_timeout(10)
|
| 142 |
+
def _get_timeout_info(instance: Any):
|
| 143 |
+
"""Runs getInfo on anything that is passed, with a timeout."""
|
| 144 |
+
return instance.getInfo()
|
| 145 |
+
|
| 146 |
+
@staticmethod
|
| 147 |
+
def _authenticate_gee():
|
| 148 |
+
"""Authenticates earth engine if needed, and initializes."""
|
| 149 |
+
try:
|
| 150 |
+
ee.Initialize()
|
| 151 |
+
except Exception as e:
|
| 152 |
+
# Trigger the authentication flow.
|
| 153 |
+
ee.Authenticate()
|
| 154 |
+
# Initialize the library.
|
| 155 |
+
ee.Initialize()
|
| 156 |
+
|
| 157 |
+
def filter(self, ee_filter: ee.Filter):
|
| 158 |
+
"""Applies a filter to the image_collection attribute. This can be useful for example
|
| 159 |
+
to filter out clouds
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
ee_filter: Filter to apply, must be an instance of ee.Filter.
|
| 163 |
+
|
| 164 |
+
Returns: self, for operation chaining as possible with the earth engine API.
|
| 165 |
+
|
| 166 |
+
"""
|
| 167 |
+
self.image_collection = self.image_collection.filter(ee_filter)
|
| 168 |
+
|
| 169 |
+
return self
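    # Illustrative sketch (not part of the original method): with a Sentinel-2
    # collection, cloudy scenes can be dropped through this wrapper, e.g.
    #   dl.filter(ee.Filter.lt("CLOUDY_PIXEL_PERCENTAGE", 20))
    # The property name is specific to COPERNICUS/S2_SR; other collections
    # expose different cloud-cover fields, so treat it as an example only.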
|
| 170 |
+
|
| 171 |
+
@property
|
| 172 |
+
def count(self):
|
| 173 |
+
"""Number of images in the ImageCollection"""
|
| 174 |
+
if self.filter_change or self._count is None:
|
| 175 |
+
self._count = self._get_timeout_info(self.image_collection.size())
|
| 176 |
+
self.filter_change = False
|
| 177 |
+
return self._count
|
| 178 |
+
|
| 179 |
+
@property
|
| 180 |
+
def available_images(self):
|
| 181 |
+
"""Gets the ImageCollection info"""
|
| 182 |
+
if self.filter_change or self._available_images is None:
|
| 183 |
+
self._available_images = self._get_timeout_info(self.image_collection)
|
| 184 |
+
return self._available_images
|
| 185 |
+
|
| 186 |
+
@signal_change
|
| 187 |
+
def filterDate(self, *args, **kwargs):
|
| 188 |
+
"""Wrapper for the filterDate method in earth engine on the ImageCollection"""
|
| 189 |
+
self.image_collection = self.image_collection.filterDate(*args, **kwargs)
|
| 190 |
+
return self
|
| 191 |
+
|
| 192 |
+
@signal_change
|
| 193 |
+
def getRegion(self, *args, **kwargs):
|
| 194 |
+
"""Wrapper for the getRegion method in earth engine on the ImageCollection.
|
| 195 |
+
Caveat! getRegion does not return an image collection, so the image_list attribute gets
|
| 196 |
+
updated instead of the image_collection attribute. However, the instance of the DataLoader class
|
| 197 |
+
is still returned, so this could be chained with another method on ImageCollection, which wouldn't be
|
| 198 |
+
possible using earth engine.
|
| 199 |
+
"""
|
| 200 |
+
self.image_list = self.image_collection.getRegion(*args, **kwargs)
|
| 201 |
+
return self
|
| 202 |
+
|
| 203 |
+
@signal_change
|
| 204 |
+
def filterBounds(self, geometry, *args, **kwargs):
|
| 205 |
+
"""Wrapper for the filterBounds method in earth engine on the ImageCollection"""
|
| 206 |
+
self.image_collection = self.image_collection.filterBounds(geometry, *args, **kwargs)
|
| 207 |
+
self.bounds = geometry
|
| 208 |
+
return self
|
| 209 |
+
|
| 210 |
+
@signal_change
|
| 211 |
+
def select(self, *bands, **kwargs):
|
| 212 |
+
"""Wrapper for the select method in earth engine on the ImageCollection"""
|
| 213 |
+
self.image_collection = self.image_collection.select(*bands, **kwargs)
|
| 214 |
+
self.bands = list(set(self.bands) | set(bands)) # Unique bands
|
| 215 |
+
return self
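    # Illustrative sketch: bands are usually narrowed before downloading, e.g.
    #   dl.select("B4", "B3", "B2")   # Sentinel-2 RGB bands, example values only
    # Since the call returns self, it can be chained with filterDate or filterBounds.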
|
| 216 |
+
|
| 217 |
+
@property
|
| 218 |
+
def date_range(self):
|
| 219 |
+
"""Gets the actual date range of the images in the image collection."""
|
| 220 |
+
if self.date_filter_change or self._date_range is None:
|
| 221 |
+
date_range = self.image_collection.reduceColumns(ee.Reducer.minMax(), ["system:time_start"]).getInfo()
|
| 222 |
+
self._date_range = {key: datetime.fromtimestamp(value/1e3) for key, value in date_range.items()}
|
| 223 |
+
self.date_filter_change = False
|
| 224 |
+
return self._date_range
|
| 225 |
+
|
| 226 |
+
@property
|
| 227 |
+
def region(self):
|
| 228 |
+
"""Gets a time series as a pandas DataFrame of the band values for the specified region."""
|
| 229 |
+
if self.filter_change:
|
| 230 |
+
if self.image_list is None:
|
| 231 |
+
self.getRegion()
|
| 232 |
+
res_list = self._get_timeout_info(self.image_list)
|
| 233 |
+
df = pd.DataFrame(res_list[1:], columns=res_list[0])
|
| 234 |
+
df.loc[:, "time"] = pd.to_datetime(df.loc[:, "time"], unit="ms")
|
| 235 |
+
self._df_image_list = df
|
| 236 |
+
self.filter_change = False
|
| 237 |
+
return self._df_image_list
|
| 238 |
+
|
| 239 |
+
@property
|
| 240 |
+
def collection_info(self):
|
| 241 |
+
"""Runs getInfo on the image collection (the first time the next time the previously
|
| 242 |
+
populated attribute will be returned)."""
|
| 243 |
+
if self.count > 5000:
|
| 244 |
+
raise Exception("Too many images to load. Try filtering more")
|
| 245 |
+
if self.filter_change or self.image_collection_info is None:
|
| 246 |
+
self.image_collection_info = self._get_timeout_info(self.image_collection)
|
| 247 |
+
return self.image_collection_info
|
| 248 |
+
|
| 249 |
+
@property
|
| 250 |
+
def image_ids(self):
|
| 251 |
+
"""list of names of available images in the image collection"""
|
| 252 |
+
return [i["id"] for i in self.collection_info["features"]]
|
| 253 |
+
|
| 254 |
+
def __repr__(self):
|
| 255 |
+
try:
|
| 256 |
+
return f"""
|
| 257 |
+
Size: {self.count}
|
| 258 |
+
|
| 259 |
+
Dataset date ranges:
|
| 260 |
+
From: {self.date_range["min"]}
|
| 261 |
+
To: {self.date_range["max"]}
|
| 262 |
+
|
| 263 |
+
Selected bands:
|
| 264 |
+
{self.bands}
|
| 265 |
+
|
| 266 |
+
"""
|
| 267 |
+
except Exception as e:
|
| 268 |
+
raise Exception("Impossible to represent the dataset. Try filtering more. Error handling to do.")
|
| 269 |
+
|
| 270 |
+
def reproject(self, image, **kwargs):
|
| 271 |
+
def resolve(name: str):
|
| 272 |
+
# Resolve crs
|
| 273 |
+
if name in kwargs:
|
| 274 |
+
item = kwargs[name]
|
| 275 |
+
elif getattr(self, name):
|
| 276 |
+
item = getattr(self, name)
|
| 277 |
+
else:
|
| 278 |
+
item = None
|
| 279 |
+
return item
|
| 280 |
+
crs = resolve("crs")
|
| 281 |
+
scale = resolve("scale")
|
| 282 |
+
if crs is not None or scale is not None:
|
| 283 |
+
image = image.reproject(crs, None, scale)
|
| 284 |
+
return image
|
| 285 |
+
|
| 286 |
+
def download_image(self, image_id: str, **kwargs):
|
| 287 |
+
"""Downloads an image based on its id / name. The additional arguments are passed
|
| 288 |
+
to getThumbUrl, and could be scale, max, min...
|
| 289 |
+
"""
|
| 290 |
+
img = ee.Image(image_id).select(*self.bands)
|
| 291 |
+
img = self.reproject(img, **kwargs)
|
| 292 |
+
input_args = {'region': self.bounds}
|
| 293 |
+
input_args.update(**kwargs)
|
| 294 |
+
all_bands = self.collection_info["features"][0]["bands"]
|
| 295 |
+
selected_bands = [band for i, band in enumerate(all_bands) if all_bands[i]["id"] in self.bands]
|
| 296 |
+
if "min" not in input_args:
|
| 297 |
+
input_args.update({"min": selected_bands[0]["data_type"]["min"]})
|
| 298 |
+
if "max" not in input_args:
|
| 299 |
+
input_args.update({"max": selected_bands[0]["data_type"]["max"]})
|
| 300 |
+
url = img.getThumbUrl(input_args)
|
| 301 |
+
buffer = tempfile.SpooledTemporaryFile(max_size=1e9)
|
| 302 |
+
r = requests.get(url, stream=True)
|
| 303 |
+
if r.status_code == 200:
|
| 304 |
+
downloaded = 0
|
| 305 |
+
# filesize = int(r.headers['content-length'])
|
| 306 |
+
for chunk in r.iter_content(chunk_size=1024):
|
| 307 |
+
downloaded += len(chunk)
|
| 308 |
+
buffer.write(chunk)
|
| 309 |
+
buffer.seek(0)
|
| 310 |
+
img = Image.open(io.BytesIO(buffer.read()))
|
| 311 |
+
buffer.close()
|
| 312 |
+
return img
|
| 313 |
+
|
| 314 |
+
@staticmethod
|
| 315 |
+
def _regex(regex: str, im_id_list: List[str], include: bool) -> list:
|
| 316 |
+
"""
|
| 317 |
+
Filters the im_id_list based on a regular expression. This is useful before downloading
|
| 318 |
+
a collection of images. For example, using (.*)TXT with include=True will only download images
|
| 319 |
+
that end with TXT, which for Nantes means filtering out empty or half-empty images.
|
| 320 |
+
Args:
|
| 321 |
+
regex: python regex as a string
|
| 322 |
+
im_id_list: list, image id list
|
| 323 |
+
include: whether to include or exclude elements that match the regex.
|
| 324 |
+
|
| 325 |
+
Returns: filtered list.
|
| 326 |
+
|
| 327 |
+
"""
|
| 328 |
+
expression = "re.match('{regex}', '{im_id}') is not None"
|
| 329 |
+
if not include:
|
| 330 |
+
expression = "not " + expression
|
| 331 |
+
filtered_list = list()
|
| 332 |
+
for im_id in im_id_list:
|
| 333 |
+
if eval(expression.format(regex=regex, im_id=im_id)):
|
| 334 |
+
filtered_list.append(im_id)
|
| 335 |
+
return filtered_list
|
| 336 |
+
|
| 337 |
+
def download_all_images(self, regex_exclude: str = None, regex_include: str = None, **kwargs):
|
| 338 |
+
"""
|
| 339 |
+
Runs download_image in a for loop around the available images.
|
| 340 |
+
Makes it possible to filter images to download based on a regex.
|
| 341 |
+
Args:
|
| 342 |
+
regex_exclude: any image that matches this regex will be excluded.
|
| 343 |
+
regex_include: any image that matches this regex will be included
|
| 344 |
+
**kwargs: arguments to be passed to getThumbUrl
|
| 345 |
+
|
| 346 |
+
Returns: list of PIL images
|
| 347 |
+
"""
|
| 348 |
+
images = list()
|
| 349 |
+
image_ids = self.image_ids
|
| 350 |
+
if regex_exclude is not None:
|
| 351 |
+
image_ids = self._regex(regex_exclude, image_ids, include=False)
|
| 352 |
+
if regex_include is not None:
|
| 353 |
+
image_ids = self._regex(regex_include, image_ids, include=True)
|
| 354 |
+
for i in tqdm(range(len(image_ids))):
|
| 355 |
+
images.append(self.download_image(image_ids[i], **kwargs))
|
| 356 |
+
return images
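    # Illustrative sketch (regex and scale values are examples, not defaults of
    # this class):
    #   imgs = dl.download_all_images(regex_include=r"(.*)TFK", scale=10)
    # keeps only image ids ending in "TFK", matching the tiles listed in the
    # class docstring above.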
|
biomap/dino/utils.py
ADDED
|
@@ -0,0 +1,619 @@
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Misc functions.
|
| 16 |
+
|
| 17 |
+
Mostly copy-paste from torchvision references or other public repos like DETR:
|
| 18 |
+
https://github.com/facebookresearch/detr/blob/master/util/misc.py
|
| 19 |
+
"""
|
| 20 |
+
import os
|
| 21 |
+
import sys
|
| 22 |
+
import time
|
| 23 |
+
import math
|
| 24 |
+
import random
|
| 25 |
+
import datetime
|
| 26 |
+
import subprocess
|
| 27 |
+
from collections import defaultdict, deque
|
| 28 |
+
|
| 29 |
+
import numpy as np
|
| 30 |
+
import torch
|
| 31 |
+
from torch import nn
|
| 32 |
+
import torch.distributed as dist
|
| 33 |
+
from PIL import ImageFilter, ImageOps
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class GaussianBlur(object):
|
| 37 |
+
"""
|
| 38 |
+
Apply Gaussian Blur to the PIL image.
|
| 39 |
+
"""
|
| 40 |
+
def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
|
| 41 |
+
self.prob = p
|
| 42 |
+
self.radius_min = radius_min
|
| 43 |
+
self.radius_max = radius_max
|
| 44 |
+
|
| 45 |
+
def __call__(self, img):
|
| 46 |
+
do_it = random.random() <= self.prob
|
| 47 |
+
if not do_it:
|
| 48 |
+
return img
|
| 49 |
+
|
| 50 |
+
return img.filter(
|
| 51 |
+
ImageFilter.GaussianBlur(
|
| 52 |
+
radius=random.uniform(self.radius_min, self.radius_max)
|
| 53 |
+
)
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class Solarization(object):
|
| 58 |
+
"""
|
| 59 |
+
Apply Solarization to the PIL image.
|
| 60 |
+
"""
|
| 61 |
+
def __init__(self, p):
|
| 62 |
+
self.p = p
|
| 63 |
+
|
| 64 |
+
def __call__(self, img):
|
| 65 |
+
if random.random() < self.p:
|
| 66 |
+
return ImageOps.solarize(img)
|
| 67 |
+
else:
|
| 68 |
+
return img
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size):
|
| 72 |
+
if os.path.isfile(pretrained_weights):
|
| 73 |
+
state_dict = torch.load(pretrained_weights, map_location="cpu")
|
| 74 |
+
if checkpoint_key is not None and checkpoint_key in state_dict:
|
| 75 |
+
print(f"Take key {checkpoint_key} in provided checkpoint dict")
|
| 76 |
+
state_dict = state_dict[checkpoint_key]
|
| 77 |
+
# remove `module.` prefix
|
| 78 |
+
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
| 79 |
+
# remove `backbone.` prefix induced by multicrop wrapper
|
| 80 |
+
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
|
| 81 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
| 82 |
+
print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
|
| 83 |
+
else:
|
| 84 |
+
print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
|
| 85 |
+
url = None
|
| 86 |
+
if model_name == "vit_small" and patch_size == 16:
|
| 87 |
+
url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
|
| 88 |
+
elif model_name == "vit_small" and patch_size == 8:
|
| 89 |
+
url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
|
| 90 |
+
elif model_name == "vit_base" and patch_size == 16:
|
| 91 |
+
url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
|
| 92 |
+
elif model_name == "vit_base" and patch_size == 8:
|
| 93 |
+
url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
|
| 94 |
+
if url is not None:
|
| 95 |
+
print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
|
| 96 |
+
state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
|
| 97 |
+
model.load_state_dict(state_dict, strict=True)
|
| 98 |
+
else:
|
| 99 |
+
print("There is no reference weights available for this model => We use random weights.")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def clip_gradients(model, clip):
|
| 103 |
+
norms = []
|
| 104 |
+
for name, p in model.named_parameters():
|
| 105 |
+
if p.grad is not None:
|
| 106 |
+
param_norm = p.grad.data.norm(2)
|
| 107 |
+
norms.append(param_norm.item())
|
| 108 |
+
clip_coef = clip / (param_norm + 1e-6)
|
| 109 |
+
if clip_coef < 1:
|
| 110 |
+
p.grad.data.mul_(clip_coef)
|
| 111 |
+
return norms
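# Illustrative sketch of where this is typically called in a training loop
# (the clip value is an example, not a setting from this repo):
#   loss.backward()
#   param_norms = clip_gradients(model, clip=3.0)
#   optimizer.step()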
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
|
| 115 |
+
if epoch >= freeze_last_layer:
|
| 116 |
+
return
|
| 117 |
+
for n, p in model.named_parameters():
|
| 118 |
+
if "last_layer" in n:
|
| 119 |
+
p.grad = None
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
|
| 123 |
+
"""
|
| 124 |
+
Re-start from checkpoint
|
| 125 |
+
"""
|
| 126 |
+
if not os.path.isfile(ckp_path):
|
| 127 |
+
return
|
| 128 |
+
print("Found checkpoint at {}".format(ckp_path))
|
| 129 |
+
|
| 130 |
+
# open checkpoint file
|
| 131 |
+
checkpoint = torch.load(ckp_path, map_location="cpu")
|
| 132 |
+
|
| 133 |
+
# key is what to look for in the checkpoint file
|
| 134 |
+
# value is the object to load
|
| 135 |
+
# example: {'state_dict': model}
|
| 136 |
+
for key, value in kwargs.items():
|
| 137 |
+
if key in checkpoint and value is not None:
|
| 138 |
+
try:
|
| 139 |
+
msg = value.load_state_dict(checkpoint[key], strict=False)
|
| 140 |
+
print("=> loaded {} from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
|
| 141 |
+
except TypeError:
|
| 142 |
+
try:
|
| 143 |
+
msg = value.load_state_dict(checkpoint[key])
|
| 144 |
+
print("=> loaded {} from checkpoint '{}'".format(key, ckp_path))
|
| 145 |
+
except ValueError:
|
| 146 |
+
print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path))
|
| 147 |
+
else:
|
| 148 |
+
print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path))
|
| 149 |
+
|
| 150 |
+
# re load variable important for the run
|
| 151 |
+
if run_variables is not None:
|
| 152 |
+
for var_name in run_variables:
|
| 153 |
+
if var_name in checkpoint:
|
| 154 |
+
run_variables[var_name] = checkpoint[var_name]
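# Illustrative sketch of resuming a run (file name and keys are examples only):
#   to_restore = {"epoch": 0}
#   restart_from_checkpoint("checkpoint.pth", run_variables=to_restore,
#                           state_dict=model, optimizer=optimizer)
#   start_epoch = to_restore["epoch"]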
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
|
| 158 |
+
warmup_schedule = np.array([])
|
| 159 |
+
warmup_iters = warmup_epochs * niter_per_ep
|
| 160 |
+
if warmup_epochs > 0:
|
| 161 |
+
warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
|
| 162 |
+
|
| 163 |
+
iters = np.arange(epochs * niter_per_ep - warmup_iters)
|
| 164 |
+
schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
|
| 165 |
+
|
| 166 |
+
schedule = np.concatenate((warmup_schedule, schedule))
|
| 167 |
+
assert len(schedule) == epochs * niter_per_ep
|
| 168 |
+
return schedule
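# Illustrative sketch: the schedule holds one value per training iteration,
# e.g. a learning-rate schedule with a 10-epoch warmup (all numbers are
# examples, not repo defaults):
#   lr_schedule = cosine_scheduler(5e-4, 1e-6, epochs=100, niter_per_ep=1250,
#                                  warmup_epochs=10)
#   # lr_schedule[it] is then the value to use at global iteration `it`.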
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def bool_flag(s):
|
| 172 |
+
"""
|
| 173 |
+
Parse boolean arguments from the command line.
|
| 174 |
+
"""
|
| 175 |
+
FALSY_STRINGS = {"off", "false", "0"}
|
| 176 |
+
TRUTHY_STRINGS = {"on", "true", "1"}
|
| 177 |
+
if s.lower() in FALSY_STRINGS:
|
| 178 |
+
return False
|
| 179 |
+
elif s.lower() in TRUTHY_STRINGS:
|
| 180 |
+
return True
|
| 181 |
+
else:
|
| 182 |
+
raise argparse.ArgumentTypeError("invalid value for a boolean flag")
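# Illustrative sketch: bool_flag is meant to be used as an argparse type, e.g.
#   parser.add_argument("--use_fp16", type=bool_flag, default=True)
# so "--use_fp16 false" on the command line parses to the boolean False.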
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def fix_random_seeds(seed=31):
|
| 186 |
+
"""
|
| 187 |
+
Fix random seeds.
|
| 188 |
+
"""
|
| 189 |
+
torch.manual_seed(seed)
|
| 190 |
+
torch.cuda.manual_seed_all(seed)
|
| 191 |
+
np.random.seed(seed)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class SmoothedValue(object):
|
| 195 |
+
"""Track a series of values and provide access to smoothed values over a
|
| 196 |
+
window or the global series average.
|
| 197 |
+
"""
|
| 198 |
+
|
| 199 |
+
def __init__(self, window_size=20, fmt=None):
|
| 200 |
+
if fmt is None:
|
| 201 |
+
fmt = "{median:.6f} ({global_avg:.6f})"
|
| 202 |
+
self.deque = deque(maxlen=window_size)
|
| 203 |
+
self.total = 0.0
|
| 204 |
+
self.count = 0
|
| 205 |
+
self.fmt = fmt
|
| 206 |
+
|
| 207 |
+
def update(self, value, n=1):
|
| 208 |
+
self.deque.append(value)
|
| 209 |
+
self.count += n
|
| 210 |
+
self.total += value * n
|
| 211 |
+
|
| 212 |
+
def synchronize_between_processes(self):
|
| 213 |
+
"""
|
| 214 |
+
Warning: does not synchronize the deque!
|
| 215 |
+
"""
|
| 216 |
+
if not is_dist_avail_and_initialized():
|
| 217 |
+
return
|
| 218 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
| 219 |
+
dist.barrier()
|
| 220 |
+
dist.all_reduce(t)
|
| 221 |
+
t = t.tolist()
|
| 222 |
+
self.count = int(t[0])
|
| 223 |
+
self.total = t[1]
|
| 224 |
+
|
| 225 |
+
@property
|
| 226 |
+
def median(self):
|
| 227 |
+
d = torch.tensor(list(self.deque))
|
| 228 |
+
return d.median().item()
|
| 229 |
+
|
| 230 |
+
@property
|
| 231 |
+
def avg(self):
|
| 232 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
| 233 |
+
return d.mean().item()
|
| 234 |
+
|
| 235 |
+
@property
|
| 236 |
+
def global_avg(self):
|
| 237 |
+
return self.total / self.count
|
| 238 |
+
|
| 239 |
+
@property
|
| 240 |
+
def max(self):
|
| 241 |
+
return max(self.deque)
|
| 242 |
+
|
| 243 |
+
@property
|
| 244 |
+
def value(self):
|
| 245 |
+
return self.deque[-1]
|
| 246 |
+
|
| 247 |
+
def __str__(self):
|
| 248 |
+
return self.fmt.format(
|
| 249 |
+
median=self.median,
|
| 250 |
+
avg=self.avg,
|
| 251 |
+
global_avg=self.global_avg,
|
| 252 |
+
max=self.max,
|
| 253 |
+
value=self.value)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def reduce_dict(input_dict, average=True):
|
| 257 |
+
"""
|
| 258 |
+
Args:
|
| 259 |
+
input_dict (dict): all the values will be reduced
|
| 260 |
+
average (bool): whether to do average or sum
|
| 261 |
+
Reduce the values in the dictionary from all processes so that all processes
|
| 262 |
+
have the averaged results. Returns a dict with the same fields as
|
| 263 |
+
input_dict, after reduction.
|
| 264 |
+
"""
|
| 265 |
+
world_size = get_world_size()
|
| 266 |
+
if world_size < 2:
|
| 267 |
+
return input_dict
|
| 268 |
+
with torch.no_grad():
|
| 269 |
+
names = []
|
| 270 |
+
values = []
|
| 271 |
+
# sort the keys so that they are consistent across processes
|
| 272 |
+
for k in sorted(input_dict.keys()):
|
| 273 |
+
names.append(k)
|
| 274 |
+
values.append(input_dict[k])
|
| 275 |
+
values = torch.stack(values, dim=0)
|
| 276 |
+
dist.all_reduce(values)
|
| 277 |
+
if average:
|
| 278 |
+
values /= world_size
|
| 279 |
+
reduced_dict = {k: v for k, v in zip(names, values)}
|
| 280 |
+
return reduced_dict
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class MetricLogger(object):
|
| 284 |
+
def __init__(self, delimiter="\t"):
|
| 285 |
+
self.meters = defaultdict(SmoothedValue)
|
| 286 |
+
self.delimiter = delimiter
|
| 287 |
+
|
| 288 |
+
def update(self, **kwargs):
|
| 289 |
+
for k, v in kwargs.items():
|
| 290 |
+
if isinstance(v, torch.Tensor):
|
| 291 |
+
v = v.item()
|
| 292 |
+
assert isinstance(v, (float, int))
|
| 293 |
+
self.meters[k].update(v)
|
| 294 |
+
|
| 295 |
+
def __getattr__(self, attr):
|
| 296 |
+
if attr in self.meters:
|
| 297 |
+
return self.meters[attr]
|
| 298 |
+
if attr in self.__dict__:
|
| 299 |
+
return self.__dict__[attr]
|
| 300 |
+
raise AttributeError("'{}' object has no attribute '{}'".format(
|
| 301 |
+
type(self).__name__, attr))
|
| 302 |
+
|
| 303 |
+
def __str__(self):
|
| 304 |
+
loss_str = []
|
| 305 |
+
for name, meter in self.meters.items():
|
| 306 |
+
loss_str.append(
|
| 307 |
+
"{}: {}".format(name, str(meter))
|
| 308 |
+
)
|
| 309 |
+
return self.delimiter.join(loss_str)
|
| 310 |
+
|
| 311 |
+
def synchronize_between_processes(self):
|
| 312 |
+
for meter in self.meters.values():
|
| 313 |
+
meter.synchronize_between_processes()
|
| 314 |
+
|
| 315 |
+
def add_meter(self, name, meter):
|
| 316 |
+
self.meters[name] = meter
|
| 317 |
+
|
| 318 |
+
def log_every(self, iterable, print_freq, header=None):
|
| 319 |
+
i = 0
|
| 320 |
+
if not header:
|
| 321 |
+
header = ''
|
| 322 |
+
start_time = time.time()
|
| 323 |
+
end = time.time()
|
| 324 |
+
iter_time = SmoothedValue(fmt='{avg:.6f}')
|
| 325 |
+
data_time = SmoothedValue(fmt='{avg:.6f}')
|
| 326 |
+
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
| 327 |
+
if torch.cuda.is_available():
|
| 328 |
+
log_msg = self.delimiter.join([
|
| 329 |
+
header,
|
| 330 |
+
'[{0' + space_fmt + '}/{1}]',
|
| 331 |
+
'eta: {eta}',
|
| 332 |
+
'{meters}',
|
| 333 |
+
'time: {time}',
|
| 334 |
+
'data: {data}',
|
| 335 |
+
'max mem: {memory:.0f}'
|
| 336 |
+
])
|
| 337 |
+
else:
|
| 338 |
+
log_msg = self.delimiter.join([
|
| 339 |
+
header,
|
| 340 |
+
'[{0' + space_fmt + '}/{1}]',
|
| 341 |
+
'eta: {eta}',
|
| 342 |
+
'{meters}',
|
| 343 |
+
'time: {time}',
|
| 344 |
+
'data: {data}'
|
| 345 |
+
])
|
| 346 |
+
MB = 1024.0 * 1024.0
|
| 347 |
+
for obj in iterable:
|
| 348 |
+
data_time.update(time.time() - end)
|
| 349 |
+
yield obj
|
| 350 |
+
iter_time.update(time.time() - end)
|
| 351 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
| 352 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
| 353 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
| 354 |
+
if torch.cuda.is_available():
|
| 355 |
+
print(log_msg.format(
|
| 356 |
+
i, len(iterable), eta=eta_string,
|
| 357 |
+
meters=str(self),
|
| 358 |
+
time=str(iter_time), data=str(data_time),
|
| 359 |
+
memory=torch.cuda.max_memory_allocated() / MB))
|
| 360 |
+
else:
|
| 361 |
+
print(log_msg.format(
|
| 362 |
+
i, len(iterable), eta=eta_string,
|
| 363 |
+
meters=str(self),
|
| 364 |
+
time=str(iter_time), data=str(data_time)))
|
| 365 |
+
i += 1
|
| 366 |
+
end = time.time()
|
| 367 |
+
total_time = time.time() - start_time
|
| 368 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
| 369 |
+
print('{} Total time: {} ({:.6f} s / it)'.format(
|
| 370 |
+
header, total_time_str, total_time / len(iterable)))
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def get_sha():
|
| 374 |
+
cwd = os.path.dirname(os.path.abspath(__file__))
|
| 375 |
+
|
| 376 |
+
def _run(command):
|
| 377 |
+
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
|
| 378 |
+
sha = 'N/A'
|
| 379 |
+
diff = "clean"
|
| 380 |
+
branch = 'N/A'
|
| 381 |
+
try:
|
| 382 |
+
sha = _run(['git', 'rev-parse', 'HEAD'])
|
| 383 |
+
subprocess.check_output(['git', 'diff'], cwd=cwd)
|
| 384 |
+
diff = _run(['git', 'diff-index', 'HEAD'])
|
| 385 |
+
diff = "has uncommited changes" if diff else "clean"
|
| 386 |
+
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
|
| 387 |
+
except Exception:
|
| 388 |
+
pass
|
| 389 |
+
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
| 390 |
+
return message
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def is_dist_avail_and_initialized():
|
| 394 |
+
if not dist.is_available():
|
| 395 |
+
return False
|
| 396 |
+
if not dist.is_initialized():
|
| 397 |
+
return False
|
| 398 |
+
return True
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def get_world_size():
|
| 402 |
+
if not is_dist_avail_and_initialized():
|
| 403 |
+
return 1
|
| 404 |
+
return dist.get_world_size()
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def get_rank():
|
| 408 |
+
if not is_dist_avail_and_initialized():
|
| 409 |
+
return 0
|
| 410 |
+
return dist.get_rank()
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def is_main_process():
|
| 414 |
+
return get_rank() == 0
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def save_on_master(*args, **kwargs):
|
| 418 |
+
if is_main_process():
|
| 419 |
+
torch.save(*args, **kwargs)
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def setup_for_distributed(is_master):
|
| 423 |
+
"""
|
| 424 |
+
This function disables printing when not in master process
|
| 425 |
+
"""
|
| 426 |
+
import builtins as __builtin__
|
| 427 |
+
builtin_print = __builtin__.print
|
| 428 |
+
|
| 429 |
+
def print(*args, **kwargs):
|
| 430 |
+
force = kwargs.pop('force', False)
|
| 431 |
+
if is_master or force:
|
| 432 |
+
builtin_print(*args, **kwargs)
|
| 433 |
+
|
| 434 |
+
__builtin__.print = print
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def init_distributed_mode(args):
|
| 438 |
+
# launched with torch.distributed.launch
|
| 439 |
+
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
| 440 |
+
args.rank = int(os.environ["RANK"])
|
| 441 |
+
args.world_size = int(os.environ['WORLD_SIZE'])
|
| 442 |
+
args.gpu = int(os.environ['LOCAL_RANK'])
|
| 443 |
+
# launched with submitit on a slurm cluster
|
| 444 |
+
elif 'SLURM_PROCID' in os.environ:
|
| 445 |
+
args.rank = int(os.environ['SLURM_PROCID'])
|
| 446 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
| 447 |
+
# launched naively with `python main_dino.py`
|
| 448 |
+
# we manually add MASTER_ADDR and MASTER_PORT to env variables
|
| 449 |
+
elif torch.cuda.is_available():
|
| 450 |
+
print('Will run the code on one GPU.')
|
| 451 |
+
args.rank, args.gpu, args.world_size = 0, 0, 1
|
| 452 |
+
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
| 453 |
+
os.environ['MASTER_PORT'] = '29500'
|
| 454 |
+
else:
|
| 455 |
+
print('Does not support training without GPU.')
|
| 456 |
+
sys.exit(1)
|
| 457 |
+
|
| 458 |
+
dist.init_process_group(
|
| 459 |
+
backend="nccl",
|
| 460 |
+
init_method=args.dist_url,
|
| 461 |
+
world_size=args.world_size,
|
| 462 |
+
rank=args.rank,
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
torch.cuda.set_device(args.gpu)
|
| 466 |
+
print('| distributed init (rank {}): {}'.format(
|
| 467 |
+
args.rank, args.dist_url), flush=True)
|
| 468 |
+
dist.barrier()
|
| 469 |
+
setup_for_distributed(args.rank == 0)
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def accuracy(output, target, topk=(1,)):
|
| 473 |
+
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
| 474 |
+
maxk = max(topk)
|
| 475 |
+
batch_size = target.size(0)
|
| 476 |
+
_, pred = output.topk(maxk, 1, True, True)
|
| 477 |
+
pred = pred.t()
|
| 478 |
+
correct = pred.eq(target.reshape(1, -1).expand_as(pred))
|
| 479 |
+
return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
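# Illustrative sketch (tensor names are placeholders):
#   top1, top5 = accuracy(logits, targets, topk=(1, 5))
# returns accuracy percentages over the batch for the requested values of k.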
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
| 483 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
| 484 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
| 485 |
+
def norm_cdf(x):
|
| 486 |
+
# Computes standard normal cumulative distribution function
|
| 487 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
| 488 |
+
|
| 489 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
| 490 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
| 491 |
+
"The distribution of values may be incorrect.",
|
| 492 |
+
stacklevel=2)
|
| 493 |
+
|
| 494 |
+
with torch.no_grad():
|
| 495 |
+
# Values are generated by using a truncated uniform distribution and
|
| 496 |
+
# then using the inverse CDF for the normal distribution.
|
| 497 |
+
# Get upper and lower cdf values
|
| 498 |
+
l = norm_cdf((a - mean) / std)
|
| 499 |
+
u = norm_cdf((b - mean) / std)
|
| 500 |
+
|
| 501 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
| 502 |
+
# [2l-1, 2u-1].
|
| 503 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
| 504 |
+
|
| 505 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
| 506 |
+
# standard normal
|
| 507 |
+
tensor.erfinv_()
|
| 508 |
+
|
| 509 |
+
# Transform to proper mean, std
|
| 510 |
+
tensor.mul_(std * math.sqrt(2.))
|
| 511 |
+
tensor.add_(mean)
|
| 512 |
+
|
| 513 |
+
# Clamp to ensure it's in the proper range
|
| 514 |
+
tensor.clamp_(min=a, max=b)
|
| 515 |
+
return tensor
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
| 519 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
| 520 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
class LARS(torch.optim.Optimizer):
|
| 524 |
+
"""
|
| 525 |
+
Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
|
| 526 |
+
"""
|
| 527 |
+
def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,
|
| 528 |
+
weight_decay_filter=None, lars_adaptation_filter=None):
|
| 529 |
+
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
|
| 530 |
+
eta=eta, weight_decay_filter=weight_decay_filter,
|
| 531 |
+
lars_adaptation_filter=lars_adaptation_filter)
|
| 532 |
+
super().__init__(params, defaults)
|
| 533 |
+
|
| 534 |
+
@torch.no_grad()
|
| 535 |
+
def step(self):
|
| 536 |
+
for g in self.param_groups:
|
| 537 |
+
for p in g['params']:
|
| 538 |
+
dp = p.grad
|
| 539 |
+
|
| 540 |
+
if dp is None:
|
| 541 |
+
continue
|
| 542 |
+
|
| 543 |
+
if p.ndim != 1:
|
| 544 |
+
dp = dp.add(p, alpha=g['weight_decay'])
|
| 545 |
+
|
| 546 |
+
if p.ndim != 1:
|
| 547 |
+
param_norm = torch.norm(p)
|
| 548 |
+
update_norm = torch.norm(dp)
|
| 549 |
+
one = torch.ones_like(param_norm)
|
| 550 |
+
q = torch.where(param_norm > 0.,
|
| 551 |
+
torch.where(update_norm > 0,
|
| 552 |
+
(g['eta'] * param_norm / update_norm), one), one)
|
| 553 |
+
dp = dp.mul(q)
|
| 554 |
+
|
| 555 |
+
param_state = self.state[p]
|
| 556 |
+
if 'mu' not in param_state:
|
| 557 |
+
param_state['mu'] = torch.zeros_like(p)
|
| 558 |
+
mu = param_state['mu']
|
| 559 |
+
mu.mul_(g['momentum']).add_(dp)
|
| 560 |
+
|
| 561 |
+
p.add_(mu, alpha=-g['lr'])
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
class MultiCropWrapper(nn.Module):
|
| 565 |
+
"""
|
| 566 |
+
Perform forward pass separately on each resolution input.
|
| 567 |
+
The inputs corresponding to a single resolution are clubbed and single
|
| 568 |
+
forward is run on the same resolution inputs. Hence we do several
|
| 569 |
+
forward passes = number of different resolutions used. We then
|
| 570 |
+
concatenate all the output features and run the head forward on these
|
| 571 |
+
concatenated features.
|
| 572 |
+
"""
|
| 573 |
+
def __init__(self, backbone, head):
|
| 574 |
+
super(MultiCropWrapper, self).__init__()
|
| 575 |
+
# disable layers dedicated to ImageNet labels classification
|
| 576 |
+
backbone.fc, backbone.head = nn.Identity(), nn.Identity()
|
| 577 |
+
self.backbone = backbone
|
| 578 |
+
self.head = head
|
| 579 |
+
|
| 580 |
+
def forward(self, x):
|
| 581 |
+
# convert to list
|
| 582 |
+
if not isinstance(x, list):
|
| 583 |
+
x = [x]
|
| 584 |
+
idx_crops = torch.cumsum(torch.unique_consecutive(
|
| 585 |
+
torch.tensor([inp.shape[-1] for inp in x]),
|
| 586 |
+
return_counts=True,
|
| 587 |
+
)[1], 0)
|
| 588 |
+
start_idx = 0
|
| 589 |
+
for end_idx in idx_crops:
|
| 590 |
+
_out = self.backbone(torch.cat(x[start_idx: end_idx]))
|
| 591 |
+
if start_idx == 0:
|
| 592 |
+
output = _out
|
| 593 |
+
else:
|
| 594 |
+
output = torch.cat((output, _out))
|
| 595 |
+
start_idx = end_idx
|
| 596 |
+
# Run the head forward on the concatenated features.
|
| 597 |
+
return self.head(output)
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
def get_params_groups(model):
|
| 601 |
+
regularized = []
|
| 602 |
+
not_regularized = []
|
| 603 |
+
for name, param in model.named_parameters():
|
| 604 |
+
if not param.requires_grad:
|
| 605 |
+
continue
|
| 606 |
+
# we do not regularize biases nor Norm parameters
|
| 607 |
+
if name.endswith(".bias") or len(param.shape) == 1:
|
| 608 |
+
not_regularized.append(param)
|
| 609 |
+
else:
|
| 610 |
+
regularized.append(param)
|
| 611 |
+
return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]
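# Illustrative sketch: the two groups let the optimizer skip weight decay on
# biases and norm parameters (hyper-parameters below are examples only):
#   optimizer = torch.optim.AdamW(get_params_groups(model), lr=5e-4,
#                                 weight_decay=0.04)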
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
def has_batchnorms(model):
|
| 615 |
+
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
|
| 616 |
+
for name, module in model.named_modules():
|
| 617 |
+
if isinstance(module, bn_types):
|
| 618 |
+
return True
|
| 619 |
+
return False
|
biomap/dino/vision_transformer.py
ADDED
|
@@ -0,0 +1,314 @@
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Mostly copy-paste from timm library.
|
| 16 |
+
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
| 17 |
+
"""
|
| 18 |
+
import math
|
| 19 |
+
from functools import partial
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
from dino.utils import trunc_normal_
|
| 24 |
+
|
| 25 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
| 26 |
+
if drop_prob == 0. or not training:
|
| 27 |
+
return x
|
| 28 |
+
keep_prob = 1 - drop_prob
|
| 29 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 30 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
| 31 |
+
random_tensor.floor_() # binarize
|
| 32 |
+
output = x.div(keep_prob) * random_tensor
|
| 33 |
+
return output
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class DropPath(nn.Module):
|
| 37 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 38 |
+
"""
|
| 39 |
+
def __init__(self, drop_prob=None):
|
| 40 |
+
super(DropPath, self).__init__()
|
| 41 |
+
self.drop_prob = drop_prob
|
| 42 |
+
|
| 43 |
+
def forward(self, x):
|
| 44 |
+
return drop_path(x, self.drop_prob, self.training)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class Mlp(nn.Module):
|
| 48 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
| 49 |
+
super().__init__()
|
| 50 |
+
out_features = out_features or in_features
|
| 51 |
+
hidden_features = hidden_features or in_features
|
| 52 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 53 |
+
self.act = act_layer()
|
| 54 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 55 |
+
self.drop = nn.Dropout(drop)
|
| 56 |
+
|
| 57 |
+
def forward(self, x):
|
| 58 |
+
x = self.fc1(x)
|
| 59 |
+
x = self.act(x)
|
| 60 |
+
x = self.drop(x)
|
| 61 |
+
x = self.fc2(x)
|
| 62 |
+
x = self.drop(x)
|
| 63 |
+
return x
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class Attention(nn.Module):
|
| 67 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
| 68 |
+
super().__init__()
|
| 69 |
+
self.num_heads = num_heads
|
| 70 |
+
head_dim = dim // num_heads
|
| 71 |
+
self.scale = qk_scale or head_dim ** -0.5
|
| 72 |
+
|
| 73 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 74 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 75 |
+
self.proj = nn.Linear(dim, dim)
|
| 76 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 77 |
+
|
| 78 |
+
def forward(self, x, return_qkv=False):
|
| 79 |
+
B, N, C = x.shape
|
| 80 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 81 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 82 |
+
|
| 83 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
| 84 |
+
attn = attn.softmax(dim=-1)
|
| 85 |
+
attn = self.attn_drop(attn)
|
| 86 |
+
|
| 87 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 88 |
+
x = self.proj(x)
|
| 89 |
+
x = self.proj_drop(x)
|
| 90 |
+
return x, attn, qkv
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class Block(nn.Module):
|
| 95 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
| 96 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
| 97 |
+
super().__init__()
|
| 98 |
+
self.norm1 = norm_layer(dim)
|
| 99 |
+
self.attn = Attention(
|
| 100 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
| 101 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 102 |
+
self.norm2 = norm_layer(dim)
|
| 103 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 104 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 105 |
+
|
| 106 |
+
def forward(self, x, return_attention=False, return_qkv = False):
|
| 107 |
+
y, attn, qkv = self.attn(self.norm1(x))
|
| 108 |
+
if return_attention:
|
| 109 |
+
return attn
|
| 110 |
+
x = x + self.drop_path(y)
|
| 111 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 112 |
+
if return_qkv:
|
| 113 |
+
return x, attn, qkv
|
| 114 |
+
return x
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class PatchEmbed(nn.Module):
|
| 118 |
+
""" Image to Patch Embedding
|
| 119 |
+
"""
|
| 120 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
| 121 |
+
super().__init__()
|
| 122 |
+
num_patches = (img_size // patch_size) * (img_size // patch_size)
|
| 123 |
+
self.img_size = img_size
|
| 124 |
+
self.patch_size = patch_size
|
| 125 |
+
self.num_patches = num_patches
|
| 126 |
+
|
| 127 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
| 128 |
+
|
| 129 |
+
def forward(self, x):
|
| 130 |
+
B, C, H, W = x.shape
|
| 131 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
| 132 |
+
return x
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class VisionTransformer(nn.Module):
|
| 136 |
+
""" Vision Transformer """
|
| 137 |
+
def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
|
| 138 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
| 139 |
+
drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
|
| 140 |
+
super().__init__()
|
| 141 |
+
|
| 142 |
+
self.num_features = self.embed_dim = embed_dim
|
| 143 |
+
|
| 144 |
+
self.patch_embed = PatchEmbed(
|
| 145 |
+
img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
| 146 |
+
num_patches = self.patch_embed.num_patches
|
| 147 |
+
|
| 148 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
| 149 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
| 150 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 151 |
+
|
| 152 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
| 153 |
+
self.blocks = nn.ModuleList([
|
| 154 |
+
Block(
|
| 155 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 156 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
|
| 157 |
+
for i in range(depth)])
|
| 158 |
+
self.norm = norm_layer(embed_dim)
|
| 159 |
+
|
| 160 |
+
# Classifier head
|
| 161 |
+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
| 162 |
+
|
| 163 |
+
trunc_normal_(self.pos_embed, std=.02)
|
| 164 |
+
trunc_normal_(self.cls_token, std=.02)
|
| 165 |
+
self.apply(self._init_weights)
|
| 166 |
+
|
| 167 |
+
def _init_weights(self, m):
|
| 168 |
+
if isinstance(m, nn.Linear):
|
| 169 |
+
trunc_normal_(m.weight, std=.02)
|
| 170 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 171 |
+
nn.init.constant_(m.bias, 0)
|
| 172 |
+
elif isinstance(m, nn.LayerNorm):
|
| 173 |
+
nn.init.constant_(m.bias, 0)
|
| 174 |
+
nn.init.constant_(m.weight, 1.0)
|
| 175 |
+
|
| 176 |
+
def interpolate_pos_encoding(self, x, w, h):
|
| 177 |
+
npatch = x.shape[1] - 1
|
| 178 |
+
N = self.pos_embed.shape[1] - 1
|
| 179 |
+
if npatch == N and w == h:
|
| 180 |
+
return self.pos_embed
|
| 181 |
+
class_pos_embed = self.pos_embed[:, 0]
|
| 182 |
+
patch_pos_embed = self.pos_embed[:, 1:]
|
| 183 |
+
dim = x.shape[-1]
|
| 184 |
+
w0 = w // self.patch_embed.patch_size
|
| 185 |
+
h0 = h // self.patch_embed.patch_size
|
| 186 |
+
# we add a small number to avoid floating point error in the interpolation
|
| 187 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
| 188 |
+
w0, h0 = w0 + 0.1, h0 + 0.1
|
| 189 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 190 |
+
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
|
| 191 |
+
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
| 192 |
+
mode='bicubic',
|
| 193 |
+
)
|
| 194 |
+
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
|
| 195 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 196 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
| 197 |
+
|
| 198 |
+
def prepare_tokens(self, x):
|
| 199 |
+
B, nc, w, h = x.shape
|
| 200 |
+
x = self.patch_embed(x) # patch linear embedding
|
| 201 |
+
|
| 202 |
+
# add the [CLS] token to the embed patch tokens
|
| 203 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
| 204 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 205 |
+
|
| 206 |
+
# add positional encoding to each token
|
| 207 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
| 208 |
+
|
| 209 |
+
return self.pos_drop(x)
|
| 210 |
+
|
| 211 |
+
def forward(self, x):
|
| 212 |
+
x = self.prepare_tokens(x)
|
| 213 |
+
for blk in self.blocks:
|
| 214 |
+
x = blk(x)
|
| 215 |
+
x = self.norm(x)
|
| 216 |
+
return x[:, 0]
|
| 217 |
+
|
| 218 |
+
def forward_feats(self, x):
|
| 219 |
+
x = self.prepare_tokens(x)
|
| 220 |
+
for blk in self.blocks:
|
| 221 |
+
x = blk(x)
|
| 222 |
+
x = self.norm(x)
|
| 223 |
+
return x
|
| 224 |
+
|
| 225 |
+
def get_intermediate_feat(self, x, n=1):
|
| 226 |
+
x = self.prepare_tokens(x)
|
| 227 |
+
# we return the output tokens from the `n` last blocks
|
| 228 |
+
feat = []
|
| 229 |
+
attns = []
|
| 230 |
+
qkvs = []
|
| 231 |
+
for i, blk in enumerate(self.blocks):
|
| 232 |
+
x, attn, qkv = blk(x, return_qkv=True)
|
| 233 |
+
if len(self.blocks) - i <= n:
|
| 234 |
+
feat.append(self.norm(x))
|
| 235 |
+
qkvs.append(qkv)
|
| 236 |
+
attns.append(attn)
|
| 237 |
+
return feat, attns, qkvs
|
| 238 |
+
|
| 239 |
+
def get_last_selfattention(self, x):
|
| 240 |
+
x = self.prepare_tokens(x)
|
| 241 |
+
for i, blk in enumerate(self.blocks):
|
| 242 |
+
if i < len(self.blocks) - 1:
|
| 243 |
+
x = blk(x)
|
| 244 |
+
else:
|
| 245 |
+
# return attention of the last block
|
| 246 |
+
return blk(x, return_attention=True)
|
| 247 |
+
|
| 248 |
+
def get_intermediate_layers(self, x, n=1):
|
| 249 |
+
x = self.prepare_tokens(x)
|
| 250 |
+
# we return the output tokens from the `n` last blocks
|
| 251 |
+
output = []
|
| 252 |
+
for i, blk in enumerate(self.blocks):
|
| 253 |
+
x = blk(x)
|
| 254 |
+
if len(self.blocks) - i <= n:
|
| 255 |
+
output.append(self.norm(x))
|
| 256 |
+
return output
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def vit_tiny(patch_size=16, **kwargs):
|
| 260 |
+
model = VisionTransformer(
|
| 261 |
+
patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
|
| 262 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 263 |
+
return model
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def vit_small(patch_size=16, **kwargs):
|
| 267 |
+
model = VisionTransformer(
|
| 268 |
+
patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
|
| 269 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 270 |
+
return model
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def vit_base(patch_size=16, **kwargs):
|
| 274 |
+
model = VisionTransformer(
|
| 275 |
+
patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
|
| 276 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 277 |
+
return model
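# Illustrative sketch: building a backbone and loading the reference DINO
# weights with load_pretrained_weights from dino.utils (an empty path triggers
# the download of the public checkpoint; all values are examples):
#   backbone = vit_small(patch_size=8)
#   load_pretrained_weights(backbone, "", None, "vit_small", 8)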
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class DINOHead(nn.Module):
|
| 281 |
+
def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
|
| 282 |
+
super().__init__()
|
| 283 |
+
nlayers = max(nlayers, 1)
|
| 284 |
+
if nlayers == 1:
|
| 285 |
+
self.mlp = nn.Linear(in_dim, bottleneck_dim)
|
| 286 |
+
else:
|
| 287 |
+
layers = [nn.Linear(in_dim, hidden_dim)]
|
| 288 |
+
if use_bn:
|
| 289 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
| 290 |
+
layers.append(nn.GELU())
|
| 291 |
+
for _ in range(nlayers - 2):
|
| 292 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim))
|
| 293 |
+
if use_bn:
|
| 294 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
| 295 |
+
layers.append(nn.GELU())
|
| 296 |
+
layers.append(nn.Linear(hidden_dim, bottleneck_dim))
|
| 297 |
+
self.mlp = nn.Sequential(*layers)
|
| 298 |
+
self.apply(self._init_weights)
|
| 299 |
+
self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
|
| 300 |
+
self.last_layer.weight_g.data.fill_(1)
|
| 301 |
+
if norm_last_layer:
|
| 302 |
+
self.last_layer.weight_g.requires_grad = False
|
| 303 |
+
|
| 304 |
+
def _init_weights(self, m):
|
| 305 |
+
if isinstance(m, nn.Linear):
|
| 306 |
+
trunc_normal_(m.weight, std=.02)
|
| 307 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 308 |
+
nn.init.constant_(m.bias, 0)
|
| 309 |
+
|
| 310 |
+
def forward(self, x):
|
| 311 |
+
x = self.mlp(x)
|
| 312 |
+
x = nn.functional.normalize(x, dim=-1, p=2)
|
| 313 |
+
x = self.last_layer(x)
|
| 314 |
+
return x
|
biomap/helper.py
ADDED
|
@@ -0,0 +1,179 @@
| 1 |
+
import torch.multiprocessing
|
| 2 |
+
import torchvision.transforms as T
|
| 3 |
+
import numpy as np
|
| 4 |
+
from utils import transform_to_pil, create_video
|
| 5 |
+
from utils_gee import extract_img, transform_ee_img
|
| 6 |
+
from dateutil.relativedelta import relativedelta
|
| 7 |
+
import datetime
|
| 8 |
+
|
| 9 |
+
import cv2
|
| 10 |
+
|
| 11 |
+
from joblib import Parallel, cpu_count, delayed
|
| 12 |
+
|
| 13 |
+
def get_image(location, d1, d2):
|
| 14 |
+
print(f"getting image for {d1} to {d2}")
|
| 15 |
+
try:
|
| 16 |
+
img = extract_img(location, d1, d2)
|
| 17 |
+
img_test = transform_ee_img(
|
| 18 |
+
img, max=0.3
|
| 19 |
+
)
|
| 20 |
+
return img_test
|
| 21 |
+
except Exception as err:
|
| 22 |
+
print(err)
|
| 23 |
+
return
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def inference_on_location(model, latitude = 2.98, longitude = 48.81, start_date=2020, end_date=2022):
|
| 27 |
+
"""Performe an inference on the latitude and longitude between the start date and the end date
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
latitude (float): the latitude of the landscape
|
| 31 |
+
longitude (float): the longitude of the landscape
|
| 32 |
+
start_date (str): the start date for our inference
|
| 33 |
+
end_date (int): the end year for the inference
|
| 34 |
+
model: the trained model used to run the inference.
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
str: path to the generated video of labeled landscapes ('output.mp4')
|
| 38 |
+
"""
|
| 39 |
+
assert end_date > start_date, "end_date must be strictly greater than start_date"
|
| 40 |
+
location = [float(latitude), float(longitude)]
|
| 41 |
+
|
| 42 |
+
# Extract img numpy from earth engine and transform it to PIL img
|
| 43 |
+
dates = [datetime.datetime(start_date, 1, 1, 0, 0, 0)]
|
| 44 |
+
while dates[-1] < datetime.datetime(end_date, 1, 1, 0, 0, 0):
|
| 45 |
+
dates.append(dates[-1] + relativedelta(months=1))
|
| 46 |
+
|
| 47 |
+
dates = [d.strftime("%Y-%m-%d") for d in dates]
|
| 48 |
+
|
| 49 |
+
all_image = Parallel(n_jobs=cpu_count(), prefer="threads")(delayed(get_image)(location, d1,d2) for d1, d2 in zip(dates[:-1],dates[1:]))
|
| 50 |
+
all_image = [image for image in all_image if image is not None]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# tensorize & normalize img
|
| 56 |
+
preprocess = T.Compose(
|
| 57 |
+
[
|
| 58 |
+
T.ToPILImage(),
|
| 59 |
+
T.Resize((320, 320)),
|
| 60 |
+
# T.CenterCrop(224),
|
| 61 |
+
T.ToTensor(),
|
| 62 |
+
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 63 |
+
]
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
# Preprocess opened img
|
| 67 |
+
x = torch.stack([preprocess(imag) for imag in all_image]).cpu()
|
| 68 |
+
|
| 69 |
+
# launch inference on cpu
|
| 70 |
+
# x = torch.unsqueeze(x, dim=0).cpu()
|
| 71 |
+
model = model.cpu()
|
| 72 |
+
|
| 73 |
+
with torch.no_grad():
|
| 74 |
+
feats, code = model.net(x)
|
| 75 |
+
linear_pred = model.linear_probe(x, code)
|
| 76 |
+
linear_pred = linear_pred.argmax(1)
|
| 77 |
+
outputs = [{
|
| 78 |
+
"img": torch.unsqueeze(img, dim=0).detach().cpu(),
|
| 79 |
+
"linear_preds": torch.unsqueeze(linear_pred, dim=0).detach().cpu(),
|
| 80 |
+
} for img, linear_pred in zip(x, linear_pred)]
|
| 81 |
+
all_img = []
|
| 82 |
+
all_label = []
|
| 83 |
+
all_labeled_img = []
|
| 84 |
+
for output in outputs:
|
| 85 |
+
img, label, labeled_img = transform_to_pil(output)
|
| 86 |
+
all_img.append(img)
|
| 87 |
+
all_label.append(label)
|
| 88 |
+
all_labeled_img.append(labeled_img)
|
| 89 |
+
|
| 90 |
+
all_labeled_img = [np.array(pil_image)[:, :, ::-1] for pil_image in all_labeled_img]
|
| 91 |
+
create_video(all_labeled_img, output_path='output/output.mp4')
|
| 92 |
+
|
| 93 |
+
# all_labeled_img = [np.array(pil_image)[:, :, ::-1] for pil_image in all_img]
|
| 94 |
+
# create_video(all_labeled_img, output_path='raw.mp4')
|
| 95 |
+
|
| 96 |
+
return 'output.mp4'
|
| 97 |
+
|
| 98 |
+
def inference_on_location_and_month(model, latitude = 2.98, longitude = 48.81, start_date = '2020-03-20'):
|
| 99 |
+
"""Performe an inference on the latitude and longitude between the start date and the end date
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
latitude (float): the latitude of the landscape
|
| 103 |
+
longitude (float): the longitude of the landscape
|
| 104 |
+
start_date (str): the start date for our inference
|
| 105 |
+
end_date (str): the end date for our inference
|
| 106 |
+
model (_type_, optional): _description_. Defaults to model.
|
| 107 |
+
|
| 108 |
+
Returns:
|
| 109 |
+
img, labeled_img,biodiv_score: the original landscape, the labeled landscape and the biodiversity score and the landscape
|
| 110 |
+
"""
|
| 111 |
+
location = [float(latitude), float(longitude)]
|
| 112 |
+
|
| 113 |
+
# Extract img numpy from earth engine and transform it to PIL img
|
| 114 |
+
end_date = datetime.datetime.strptime(start_date, "%Y-%m-%d") + relativedelta(months=1)
|
| 115 |
+
end_date = datetime.datetime.strftime(end_date, "%Y-%m-%d")
|
| 116 |
+
img = extract_img(location, start_date, end_date)
|
| 117 |
+
img_test = transform_ee_img(
|
| 118 |
+
img, max=0.3
|
| 119 |
+
) # max value is the value from numpy file that will be equal to 255
|
| 120 |
+
|
| 121 |
+
# tensorize & normalize img
|
| 122 |
+
preprocess = T.Compose(
|
| 123 |
+
[
|
| 124 |
+
T.ToPILImage(),
|
| 125 |
+
T.Resize((320, 320)),
|
| 126 |
+
# T.CenterCrop(224),
|
| 127 |
+
T.ToTensor(),
|
| 128 |
+
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 129 |
+
]
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# Preprocess opened img
|
| 133 |
+
x = preprocess(img_test)
|
| 134 |
+
|
| 135 |
+
# launch inference on cpu
|
| 136 |
+
x = torch.unsqueeze(x, dim=0).cpu()
|
| 137 |
+
model = model.cpu()
|
| 138 |
+
|
| 139 |
+
with torch.no_grad():
|
| 140 |
+
feats, code = model.net(x)
|
| 141 |
+
linear_pred = model.linear_probe(x, code)
|
| 142 |
+
linear_pred = linear_pred.argmax(1)
|
| 143 |
+
output = {
|
| 144 |
+
"img": x[: model.cfg.n_images].detach().cpu(),
|
| 145 |
+
"linear_preds": linear_pred[: model.cfg.n_images].detach().cpu(),
|
| 146 |
+
}
|
| 147 |
+
nb_values = []
|
| 148 |
+
for i in range(7):
|
| 149 |
+
nb_values.append(np.count_nonzero(output['linear_preds'][0] == i+1))
|
| 150 |
+
scores_init = [2,3,4,3,1,4,0]
|
| 151 |
+
score = sum(x * y for x, y in zip(scores_init, nb_values)) / sum(nb_values) / max(scores_init)
|
| 152 |
+
|
| 153 |
+
img, label, labeled_img = transform_to_pil(output)
|
| 154 |
+
return img, labeled_img,score
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
if __name__ == "__main__":
|
| 158 |
+
import logging
|
| 159 |
+
import hydra
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
from model import LitUnsupervisedSegmenter
|
| 163 |
+
logging.basicConfig(filename='example.log', encoding='utf-8', level=logging.INFO)
|
| 164 |
+
# Initialize hydra with configs
|
| 165 |
+
hydra.initialize(config_path="configs", job_name="corine")
|
| 166 |
+
cfg = hydra.compose(config_name="my_train_config.yml")
|
| 167 |
+
logging.info(f"config : {cfg}")
|
| 168 |
+
# Load the model
|
| 169 |
+
|
| 170 |
+
nbclasses = cfg.dir_dataset_n_classes
|
| 171 |
+
model = LitUnsupervisedSegmenter(nbclasses, cfg)
|
| 172 |
+
logging.info(f"Model Initialiazed")
|
| 173 |
+
|
| 174 |
+
model_path = "checkpoint/model/model.pt"
|
| 175 |
+
saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
|
| 176 |
+
logging.info(f"Model weights Loaded")
|
| 177 |
+
model.load_state_dict(saved_state_dict)
|
| 178 |
+
logging.info(f"Model Loaded")
|
| 179 |
+
inference_on_location(model)
|
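Note: the biodiversity score computed in inference_on_location_and_month above is a weighted mean of per-class pixel counts, normalised by the largest class weight so that it falls in [0, 1]. A minimal standalone sketch of that step follows; only the weights come from the code above, the pixel counts are made-up example values.

    # Illustrative sketch of the score above; nb_values here is fabricated.
    scores_init = [2, 3, 4, 3, 1, 4, 0]                  # per-class weights for classes 1..7
    nb_values = [1200, 300, 4500, 800, 150, 2000, 50]    # example pixel counts per class
    score = sum(w * n for w, n in zip(scores_init, nb_values)) / sum(nb_values) / max(scores_init)
    print(round(score, 3))                               # 0.885 with these example counts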
biomap/inference.py
ADDED
@@ -0,0 +1,261 @@
import torch.multiprocessing
import torchvision.transforms as T
from utils import transform_to_pil

def inference(image, model):
    # tensorize & normalize img
    preprocess = T.Compose(
        [
            T.ToPILImage(),
            T.Resize((320, 320)),
            # T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # Preprocess opened img
    x = preprocess(image)

    # launch inference on cpu
    x = torch.unsqueeze(x, dim=0).cpu()
    model = model.cpu()

    with torch.no_grad():
        feats, code = model.net(x)
        linear_pred = model.linear_probe(x, code)
        linear_pred = linear_pred.argmax(1)
        output = {
            "img": x[: model.cfg.n_images].detach().cpu(),
            "linear_preds": linear_pred[: model.cfg.n_images].detach().cpu(),
        }

    img, label, labeled_img = transform_to_pil(output)
    return img, labeled_img, label


if __name__ == "__main__":
    import hydra
    from model import LitUnsupervisedSegmenter
    from utils_gee import extract_img, transform_ee_img
    latitude = 2.98
    longitude = 48.81
    start_date = '2020-03-20'
    end_date = '2020-04-20'

    location = [float(latitude), float(longitude)]
    # Extract img numpy from earth engine and transform it to PIL img
    img = extract_img(location, start_date, end_date)
    image = transform_ee_img(
        img, max=0.3
    )  # max value is the value from numpy file that will be equal to 255
    print("image loaded")
    # Initialize hydra with configs
    hydra.initialize(config_path="configs", job_name="corine")
    cfg = hydra.compose(config_name="my_train_config.yml")

    # Load the model
    model_path = "checkpoint/model/model.pt"
    saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))

    nbclasses = cfg.dir_dataset_n_classes

    model = LitUnsupervisedSegmenter(nbclasses, cfg)
    print("model initialized")
    model.load_state_dict(saved_state_dict)
    print("model loaded")
    # img.save("output/image.png")
    img, labeled_img, label = inference(image, model)
    img.save("output/img.png")
    label.save("output/label.png")
    labeled_img.save("output/labeled_img.png")




# def get_list_date(start_date, end_date):
#     """Get all the dates between the start date and the end date
#
#     Args:
#         start_date (str): start date at the format '%Y-%m-%d'
#         end_date (str): end date at the format '%Y-%m-%d'
#
#     Returns:
#         list[str]: all the dates between the start date and the end date
#     """
#     start_date = datetime.datetime.strptime(start_date, "%Y-%m-%d").date()
#     end_date = datetime.datetime.strptime(end_date, "%Y-%m-%d").date()
#     list_date = [start_date]
#     date = start_date
#     while date < end_date:
#         date = date + datetime.timedelta(days=1)
#         list_date.append(date)
#     list_date.append(end_date)
#     list_date2 = [x.strftime("%Y-%m-%d") for x in list_date]
#     return list_date2


# def get_length_interval(start_date, end_date):
#     """Return how many days there are between the start date and the end date
#
#     Args:
#         start_date (str): start date at the format '%Y-%m-%d'
#         end_date (str): end date at the format '%Y-%m-%d'
#
#     Returns:
#         int : number of days between the start date and the end date
#     """
#     try:
#         return len(get_list_date(start_date, end_date))
#     except ValueError:
#         return 0


# def infer_unique_date(latitude, longitude, date, model=model):
#     """Perform an inference on a latitude and a longitude at a specific date
#
#     Args:
#         latitude (float): the latitude of the landscape
#         longitude (float): the longitude of the landscape
#         date (str): date for the inference at the format '%Y-%m-%d'
#         model (_type_, optional): _description_. Defaults to model.
#
#     Returns:
#         img, labeled_img, biodiv_score: the original landscape, the labeled landscape and the biodiversity score of the landscape
#     """
#     start_date = date
#     end_date = date
#     location = [float(latitude), float(longitude)]
#     # Extract img numpy from earth engine and transform it to PIL img
#     img = extract_img(location, start_date, end_date)
#     img_test = transform_ee_img(
#         img, max=0.3
#     )  # max value is the value from numpy file that will be equal to 255
#
#     # tensorize & normalize img
#     preprocess = T.Compose(
#         [
#             T.ToPILImage(),
#             T.Resize((320, 320)),
#             # T.CenterCrop(224),
#             T.ToTensor(),
#             T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#         ]
#     )
#
#     # Preprocess opened img
#     x = preprocess(img_test)
#
#     # launch inference on cpu
#     x = torch.unsqueeze(x, dim=0).cpu()
#     model = model.cpu()
#
#     with torch.no_grad():
#         feats, code = model.net(x)
#         linear_pred = model.linear_probe(x, code)
#         linear_pred = linear_pred.argmax(1)
#         output = {
#             "img": x[: model.cfg.n_images].detach().cpu(),
#             "linear_preds": linear_pred[: model.cfg.n_images].detach().cpu(),
#         }
#
#     img, label, labeled_img = transform_to_pil(output)
#     biodiv_score = compute_biodiv_score(labeled_img)
#     return img, labeled_img, biodiv_score


# def get_img_array(start_date, end_date, latitude, longitude, model=model):
#     list_date = get_list_date(start_date, end_date)
#     list_img = []
#     for date in list_date:
#         list_img.append(img)
#     return list_img


# def variable_outputs(start_date, end_date, latitude, longitude, day, model=model):
#     """Perform an inference on the day number day starting from the start at the latitude and longitude selected
#
#     Args:
#         latitude (float): the latitude of the landscape
#         longitude (float): the longitude of the landscape
#         start_date (str): the start date for our inference
#         end_date (str): the end date for our inference
#         model (_type_, optional): _description_. Defaults to model.
#
#     Returns:
#         img, labeled_img, biodiv_score: the original landscape, the labeled landscape and the biodiversity score at the selected longitude, latitude and date
#     """
#     list_date = get_list_date(start_date, end_date)
#     k = int(day)
#     date = list_date[k]
#     img, labeled_img, biodiv_score = infer_unique_date(
#         latitude, longitude, date, model=model
#     )
#     return img, labeled_img, biodiv_score


# def variable_outputs2(
#     start_date, end_date, latitude, longitude, day_number, model=model
# ):
#     """Perform an inference on the day number day starting from the start at the latitude and longitude selected
#
#     Args:
#         latitude (float): the latitude of the landscape
#         longitude (float): the longitude of the landscape
#         start_date (str): the start date for our inference
#         end_date (str): the end date for our inference
#         model (_type_, optional): _description_. Defaults to model.
#
#     Returns:
#         list[img, labeled_img, biodiv_score]: the original landscape, the labeled landscape and the biodiversity score at the selected longitude, latitude and date
#     """
#     list_date = get_list_date(start_date, end_date)
#     k = int(day_number)
#     date = list_date[k]
#     img, labeled_img, biodiv_score = infer_unique_date(
#         latitude, longitude, date, model=model
#     )
#     return [img, labeled_img, biodiv_score]


# def gif_maker(img_array):
#     output_file = "test2.mkv"
#     image_test = img_array[0]
#     size = (320, 320)
#     print(size)
#     out = cv2.VideoWriter(
#         output_file, cv2.VideoWriter_fourcc(*"avc1"), 15, frameSize=size
#     )
#     for i in range(len(img_array)):
#         image = img_array[i]
#         pix = np.array(image.getdata())
#         out.write(pix)
#     out.release()
#     return output_file


# def infer_multiple_date(start_date, end_date, latitude, longitude, model=model):
#     """Perform an inference on all the dates between the start date and the end date at the latitude and longitude
#
#     Args:
#         latitude (float): the latitude of the landscape
#         longitude (float): the longitude of the landscape
#         start_date (str): the start date for our inference
#         end_date (str): the end date for our inference
#         model (_type_, optional): _description_. Defaults to model.
#
#     Returns:
#         list_img, list_labeled_img, list_biodiv_score: lists of the original landscapes, the labeled landscapes and the biodiversity scores
#     """
#     list_date = get_list_date(start_date, end_date)
#     list_img = []
#     list_labeled_img = []
#     list_biodiv_score = []
#     for date in list_date:
#         img, labeled_img, biodiv_score = infer_unique_date(
#             latitude, longitude, date, model=model
#         )
#         list_img.append(img)
#         list_labeled_img.append(labeled_img)
#         list_biodiv_score.append(biodiv_score)
#     return gif_maker(list_img), gif_maker(list_labeled_img), list_biodiv_score[0]
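Note: inference only needs an RGB array (its transform starts with T.ToPILImage()), so it can also be exercised on a local image without Earth Engine. A minimal sketch, assuming a placeholder file path and a model loaded as in the __main__ block above:

    # Sketch only: "landscape.png" is a placeholder, not a file shipped with this repository.
    import numpy as np
    from PIL import Image
    from inference import inference
    local_image = np.array(Image.open("landscape.png").convert("RGB"))
    img, labeled_img, label = inference(local_image, model)
    labeled_img.save("output/labeled_local.png")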
biomap/label.png
ADDED
biomap/model.py
ADDED
@@ -0,0 +1,453 @@
from utils import *
from modules import *
from data import *
import torch.nn.functional as F
import pytorch_lightning as pl
import torch.multiprocessing
import seaborn as sns
import unet

class LitUnsupervisedSegmenter(pl.LightningModule):
    def __init__(self, n_classes, cfg):
        super().__init__()
        self.cfg = cfg
        self.n_classes = n_classes

        if not cfg.continuous:
            dim = n_classes
        else:
            dim = cfg.dim

        data_dir = join(cfg.output_root, "data")
        if cfg.arch == "feature-pyramid":
            cut_model = load_model(cfg.model_type, data_dir).cuda()
            self.net = FeaturePyramidNet(
                cfg.granularity, cut_model, dim, cfg.continuous
            )
        elif cfg.arch == "dino":
            self.net = DinoFeaturizer(dim, cfg)
        else:
            raise ValueError("Unknown arch {}".format(cfg.arch))

        self.train_cluster_probe = ClusterLookup(dim, n_classes)

        self.cluster_probe = ClusterLookup(dim, n_classes + cfg.extra_clusters)
        # self.linear_probe = nn.Conv2d(dim, n_classes, (1, 1))
        # self.linear_probe = nn.Sequential(OrderedDict([
        #     ('conv1', nn.Conv2d(dim, 2*n_classes, (7, 7), padding='same')),
        #     ('relu1', nn.ReLU()),
        #     ('conv2', nn.Conv2d(2*n_classes, n_classes, (3, 3), padding='same'))
        # ]))
        self.linear_probe = unet.AuxUNet(
            enc_chs=(3, 32, 64, 128, 256),
            dec_chs=(256, 128, 64, 32),
            aux_ch=70,
            num_class=n_classes,
        )

        self.decoder = nn.Conv2d(dim, self.net.n_feats, (1, 1))

        self.cluster_metrics = UnsupervisedMetrics(
            "test/cluster/", n_classes, cfg.extra_clusters, True
        )
        self.linear_metrics = UnsupervisedMetrics("test/linear/", n_classes, 0, False)

        self.test_cluster_metrics = UnsupervisedMetrics(
            "final/cluster/", n_classes, cfg.extra_clusters, True
        )
        self.test_linear_metrics = UnsupervisedMetrics(
            "final/linear/", n_classes, 0, False
        )

        self.linear_probe_loss_fn = torch.nn.CrossEntropyLoss()
        self.crf_loss_fn = ContrastiveCRFLoss(
            cfg.crf_samples, cfg.alpha, cfg.beta, cfg.gamma, cfg.w1, cfg.w2, cfg.shift
        )

        self.contrastive_corr_loss_fn = ContrastiveCorrelationLoss(cfg)
        for p in self.contrastive_corr_loss_fn.parameters():
            p.requires_grad = False

        self.automatic_optimization = False

        if self.cfg.dataset_name.startswith("cityscapes"):
            self.label_cmap = create_cityscapes_colormap()
        else:
            self.label_cmap = create_pascal_label_colormap()

        self.val_steps = 0
        self.save_hyperparameters()

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        return self.net(x)[1]

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # It is independent of forward
        net_optim, linear_probe_optim, cluster_probe_optim = self.optimizers()

        net_optim.zero_grad()
        linear_probe_optim.zero_grad()
        cluster_probe_optim.zero_grad()

        with torch.no_grad():
            ind = batch["ind"]
            img = batch["img"]
            img_aug = batch["img_aug"]
            coord_aug = batch["coord_aug"]
            img_pos = batch["img_pos"]
            label = batch["label"]
            label_pos = batch["label_pos"]

        feats, code = self.net(img)
        if self.cfg.correspondence_weight > 0:
            feats_pos, code_pos = self.net(img_pos)
        log_args = dict(sync_dist=False, rank_zero_only=True)

        if self.cfg.use_true_labels:
            signal = one_hot_feats(label + 1, self.n_classes + 1)
            signal_pos = one_hot_feats(label_pos + 1, self.n_classes + 1)
        else:
            signal = feats
            signal_pos = feats_pos

        loss = 0

        should_log_hist = (
            (self.cfg.hist_freq is not None)
            and (self.global_step % self.cfg.hist_freq == 0)
            and (self.global_step > 0)
        )
        if self.cfg.use_salience:
            salience = batch["mask"].to(torch.float32).squeeze(1)
            salience_pos = batch["mask_pos"].to(torch.float32).squeeze(1)
        else:
            salience = None
            salience_pos = None

        if self.cfg.correspondence_weight > 0:
            (
                pos_intra_loss,
                pos_intra_cd,
                pos_inter_loss,
                pos_inter_cd,
                neg_inter_loss,
                neg_inter_cd,
            ) = self.contrastive_corr_loss_fn(
                signal,
                signal_pos,
                salience,
                salience_pos,
                code,
                code_pos,
            )

            if should_log_hist:
                self.logger.experiment.add_histogram(
                    "intra_cd", pos_intra_cd, self.global_step
                )
                self.logger.experiment.add_histogram(
                    "inter_cd", pos_inter_cd, self.global_step
                )
                self.logger.experiment.add_histogram(
                    "neg_cd", neg_inter_cd, self.global_step
                )
            neg_inter_loss = neg_inter_loss.mean()
            pos_intra_loss = pos_intra_loss.mean()
            pos_inter_loss = pos_inter_loss.mean()
            self.log("loss/pos_intra", pos_intra_loss, **log_args)
            self.log("loss/pos_inter", pos_inter_loss, **log_args)
            self.log("loss/neg_inter", neg_inter_loss, **log_args)
            self.log("cd/pos_intra", pos_intra_cd.mean(), **log_args)
            self.log("cd/pos_inter", pos_inter_cd.mean(), **log_args)
            self.log("cd/neg_inter", neg_inter_cd.mean(), **log_args)

            loss += (
                self.cfg.pos_inter_weight * pos_inter_loss
                + self.cfg.pos_intra_weight * pos_intra_loss
                + self.cfg.neg_inter_weight * neg_inter_loss
            ) * self.cfg.correspondence_weight

        if self.cfg.rec_weight > 0:
            rec_feats = self.decoder(code)
            rec_loss = -(norm(rec_feats) * norm(feats)).sum(1).mean()
            self.log("loss/rec", rec_loss, **log_args)
            loss += self.cfg.rec_weight * rec_loss

        if self.cfg.aug_alignment_weight > 0:
            orig_feats_aug, orig_code_aug = self.net(img_aug)
            downsampled_coord_aug = resize(
                coord_aug.permute(0, 3, 1, 2), orig_code_aug.shape[2]
            ).permute(0, 2, 3, 1)
            aug_alignment = -torch.einsum(
                "bkhw,bkhw->bhw",
                norm(sample(code, downsampled_coord_aug)),
                norm(orig_code_aug),
            ).mean()
            self.log("loss/aug_alignment", aug_alignment, **log_args)
            loss += self.cfg.aug_alignment_weight * aug_alignment

        if self.cfg.crf_weight > 0:
            crf = self.crf_loss_fn(resize(img, 56), norm(resize(code, 56))).mean()
            self.log("loss/crf", crf, **log_args)
            loss += self.cfg.crf_weight * crf

        flat_label = label.reshape(-1)
        mask = (flat_label >= 0) & (flat_label < self.n_classes)

        detached_code = torch.clone(code.detach())

        # pdb.set_trace()

        linear_logits = self.linear_probe(img, detached_code)
        linear_logits = F.interpolate(
            linear_logits, label.shape[-2:], mode="bilinear", align_corners=False
        )
        linear_logits = linear_logits.permute(0, 2, 3, 1).reshape(-1, self.n_classes)
        linear_loss = self.linear_probe_loss_fn(
            linear_logits[mask], flat_label[mask]
        ).mean()
        loss += linear_loss
        self.log("loss/linear", linear_loss, **log_args)

        cluster_loss, cluster_probs = self.cluster_probe(detached_code, None)
        loss += cluster_loss
        self.log("loss/cluster", cluster_loss, **log_args)
        self.log("loss/total", loss, **log_args)

        self.manual_backward(loss)
        net_optim.step()
        cluster_probe_optim.step()
        linear_probe_optim.step()

        if (
            self.cfg.reset_probe_steps is not None
            and self.global_step == self.cfg.reset_probe_steps
        ):
            print("RESETTING PROBES")
            self.linear_probe.reset_parameters()
            self.cluster_probe.reset_parameters()
            self.trainer.optimizers[1] = torch.optim.Adam(
                list(self.linear_probe.parameters()), lr=5e-3
            )
            self.trainer.optimizers[2] = torch.optim.Adam(
                list(self.cluster_probe.parameters()), lr=5e-3
            )

        if self.global_step % 2000 == 0 and self.global_step > 0:
            print("RESETTING TFEVENT FILE")
            # Make a new tfevent file
            self.logger.experiment.close()
            self.logger.experiment._get_file_writer()

        return loss

    def on_train_start(self):
        tb_metrics = {**self.linear_metrics.compute(), **self.cluster_metrics.compute()}
        self.logger.log_hyperparams(self.cfg, tb_metrics)

    def validation_step(self, batch, batch_idx):
        img = batch["img"]
        label = batch["label"]
        self.net.eval()

        with torch.no_grad():
            feats, code = self.net(img)

            # code = F.interpolate(code, label.shape[-2:], mode='bilinear', align_corners=False)
            # linear_preds = self.linear_probe(code)
            linear_preds = self.linear_probe(img, code)
            linear_preds = linear_preds.argmax(1)
            self.linear_metrics.update(linear_preds, label)

            code = F.interpolate(
                code, label.shape[-2:], mode="bilinear", align_corners=False
            )
            cluster_loss, cluster_preds = self.cluster_probe(code, None)
            cluster_preds = cluster_preds.argmax(1)
            self.cluster_metrics.update(cluster_preds, label)

            return {
                "img": img[: self.cfg.n_images].detach().cpu(),
                "linear_preds": linear_preds[: self.cfg.n_images].detach().cpu(),
                "cluster_preds": cluster_preds[: self.cfg.n_images].detach().cpu(),
                "label": label[: self.cfg.n_images].detach().cpu(),
            }

    def validation_epoch_end(self, outputs) -> None:
        super().validation_epoch_end(outputs)
        with torch.no_grad():
            tb_metrics = {
                **self.linear_metrics.compute(),
                **self.cluster_metrics.compute(),
            }

            if self.trainer.is_global_zero and not self.cfg.submitting_to_aml:
                # output_num = 0
                output_num = random.randint(0, len(outputs) - 1)
                output = {k: v.detach().cpu() for k, v in outputs[output_num].items()}

                # pdb.set_trace()
                alpha = 0.4
                n_rows = 6
                fig, ax = plt.subplots(
                    n_rows,
                    self.cfg.n_images,
                    figsize=(self.cfg.n_images * 3, n_rows * 3),
                )
                for i in range(self.cfg.n_images):
                    try:
                        rbg_img = prep_for_plot(output["img"][i])
                        true_label = output["label"].squeeze()[i]
                        true_label[true_label == -1] = 7
                    except:
                        continue
                    # ax[0, i].imshow(prep_for_plot(output["img"][i]))
                    # ax[1, i].imshow(self.label_cmap[output["label"].squeeze()[i]])
                    # ax[2, i].imshow(self.label_cmap[output["linear_preds"][i]])
                    # ax[3, i].imshow(self.label_cmap[self.cluster_metrics.map_clusters(output["cluster_preds"][i])])
                    ax[0, i].imshow(rbg_img)

                    ax[1, i].imshow(rbg_img)
                    ax[1, i].imshow(true_label, alpha=alpha, cmap=cmap, norm=norm)

                    ax[2, i].imshow(rbg_img)
                    pred_label = output["linear_preds"][i]
                    ax[2, i].imshow(pred_label, alpha=alpha, cmap=cmap, norm=norm)

                    ax[3, i].imshow(rbg_img)
                    retouched_label = retouch_label(pred_label.numpy(), true_label)
                    ax[3, i].imshow(retouched_label, alpha=alpha, cmap=cmap, norm=norm)

                    ax[4, i].imshow(rbg_img)
                    pred_label = self.cluster_metrics.map_clusters(
                        output["cluster_preds"][i]
                    )
                    ax[4, i].imshow(pred_label, alpha=alpha, cmap=cmap, norm=norm)
                    # ax[3, i].imshow(map_clusters_with_label(true_label, pred_label), alpha=0.5, cmap=cmap, norm=norm)

                    ax[5, i].imshow(rbg_img)
                    retouched_label = retouch_label(pred_label.numpy(), true_label)
                    ax[5, i].imshow(retouched_label, alpha=alpha, cmap=cmap, norm=norm)

                ax[0, 0].set_ylabel("Image", fontsize=16)
                ax[1, 0].set_ylabel("Label", fontsize=16)
                ax[2, 0].set_ylabel("UNet Probe", fontsize=16)
                ax[3, 0].set_ylabel("Retouched UNet Probe", fontsize=16)
                ax[4, 0].set_ylabel("Cluster Probe", fontsize=16)
                ax[5, 0].set_ylabel("Retouched cluster Probe", fontsize=16)
                remove_axes(ax)
                plt.tight_layout()
                add_plot(self.logger.experiment, "plot_labels", self.global_step)

                if self.cfg.has_labels:
                    fig = plt.figure(figsize=(13, 10))
                    ax = fig.gca()
                    hist = (
                        self.cluster_metrics.histogram.detach().cpu().to(torch.float32)
                    )
                    hist /= torch.clamp_min(hist.sum(dim=0, keepdim=True), 1)
                    sns.heatmap(hist.t(), annot=False, fmt="g", ax=ax, cmap="Blues")
                    ax.set_xlabel("Predicted labels")
                    ax.set_ylabel("True labels")
                    names = get_class_labels(self.cfg.dataset_name)
                    if self.cfg.extra_clusters:
                        names = names + ["Extra"]
                    ax.set_xticks(np.arange(0, len(names)) + 0.5)
                    ax.set_yticks(np.arange(0, len(names)) + 0.5)
                    ax.xaxis.tick_top()
                    ax.xaxis.set_ticklabels(names, fontsize=14)
                    ax.yaxis.set_ticklabels(names, fontsize=14)
                    colors = [self.label_cmap[i] / 255.0 for i in range(len(names))]
                    [
                        t.set_color(colors[i])
                        for i, t in enumerate(ax.xaxis.get_ticklabels())
                    ]
                    [
                        t.set_color(colors[i])
                        for i, t in enumerate(ax.yaxis.get_ticklabels())
                    ]
                    # ax.yaxis.get_ticklabels()[-1].set_color(self.label_cmap[0] / 255.0)
                    # ax.xaxis.get_ticklabels()[-1].set_color(self.label_cmap[0] / 255.0)
                    plt.xticks(rotation=90)
                    plt.yticks(rotation=0)
                    ax.vlines(
                        np.arange(0, len(names) + 1),
                        color=[0.5, 0.5, 0.5],
                        *ax.get_xlim()
                    )
                    ax.hlines(
                        np.arange(0, len(names) + 1),
                        color=[0.5, 0.5, 0.5],
                        *ax.get_ylim()
                    )
                    plt.tight_layout()
                    add_plot(self.logger.experiment, "conf_matrix", self.global_step)

                    all_bars = torch.cat(
                        [
                            self.cluster_metrics.histogram.sum(0).cpu(),
                            self.cluster_metrics.histogram.sum(1).cpu(),
                        ],
                        axis=0,
                    )
                    ymin = max(all_bars.min() * 0.8, 1)
                    ymax = all_bars.max() * 1.2

                    fig, ax = plt.subplots(1, 2, figsize=(2 * 5, 1 * 4))
                    ax[0].bar(
                        range(self.n_classes + self.cfg.extra_clusters),
                        self.cluster_metrics.histogram.sum(0).cpu(),
                        tick_label=names,
                        color=colors,
                    )
                    ax[0].set_ylim(ymin, ymax)
                    ax[0].set_title("Label Frequency")
                    ax[0].set_yscale("log")
                    ax[0].tick_params(axis="x", labelrotation=90)

                    ax[1].bar(
                        range(self.n_classes + self.cfg.extra_clusters),
                        self.cluster_metrics.histogram.sum(1).cpu(),
                        tick_label=names,
                        color=colors,
                    )
                    ax[1].set_ylim(ymin, ymax)
                    ax[1].set_title("Cluster Frequency")
                    ax[1].set_yscale("log")
                    ax[1].tick_params(axis="x", labelrotation=90)

                    plt.tight_layout()
                    add_plot(
                        self.logger.experiment, "label frequency", self.global_step
                    )

            if self.global_step > 2:
                self.log_dict(tb_metrics)

                if self.trainer.is_global_zero and self.cfg.azureml_logging:
                    from azureml.core.run import Run

                    run_logger = Run.get_context()
                    for metric, value in tb_metrics.items():
                        run_logger.log(metric, value)

            self.linear_metrics.reset()
            self.cluster_metrics.reset()

    def configure_optimizers(self):
        main_params = list(self.net.parameters())

        if self.cfg.rec_weight > 0:
            main_params.extend(self.decoder.parameters())

        net_optim = torch.optim.Adam(main_params, lr=self.cfg.lr)
        linear_probe_optim = torch.optim.Adam(
            list(self.linear_probe.parameters()), lr=5e-3
        )
        cluster_probe_optim = torch.optim.Adam(
            list(self.cluster_probe.parameters()), lr=5e-3
        )

        return net_optim, linear_probe_optim, cluster_probe_optim
biomap/modules.py
ADDED
@@ -0,0 +1,472 @@
import torch

from utils import *
import torch.nn.functional as F
import dino.vision_transformer as vits

import pdb

class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class DinoFeaturizer(nn.Module):

    def __init__(self, dim, cfg):
        super().__init__()
        self.cfg = cfg
        self.dim = dim
        patch_size = self.cfg.dino_patch_size
        self.patch_size = patch_size
        self.feat_type = self.cfg.dino_feat_type
        arch = self.cfg.model_type
        self.model = vits.__dict__[arch](
            patch_size=patch_size,
            num_classes=0)
        for p in self.model.parameters():
            p.requires_grad = False
        # pdb.set_trace()
        self.model = self.model.cpu()
        self.model.eval()
        self.dropout = torch.nn.Dropout2d(p=.1)

        if arch == "vit_small" and patch_size == 16:
            url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
        elif arch == "vit_small" and patch_size == 8:
            url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
        elif arch == "vit_base" and patch_size == 16:
            url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
        elif arch == "vit_base" and patch_size == 8:
            url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
        else:
            raise ValueError("Unknown arch and patch size")

        if cfg.pretrained_weights is not None:
            state_dict = torch.load(cfg.pretrained_weights, map_location="cpu")
            state_dict = state_dict["teacher"]
            # remove `module.` prefix
            state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
            # remove `backbone.` prefix induced by multicrop wrapper
            state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}

            # state_dict = {k.replace("projection_head", "mlp"): v for k, v in state_dict.items()}
            # state_dict = {k.replace("prototypes", "last_layer"): v for k, v in state_dict.items()}

            msg = self.model.load_state_dict(state_dict, strict=False)
            print('Pretrained weights found at {} and loaded with msg: {}'.format(cfg.pretrained_weights, msg))
        else:
            print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
            state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
            self.model.load_state_dict(state_dict, strict=True)

        if arch == "vit_small":
            self.n_feats = 384
        else:
            self.n_feats = 768
        self.cluster1 = self.make_clusterer(self.n_feats)
        self.proj_type = cfg.projection_type
        if self.proj_type == "nonlinear":
            self.cluster2 = self.make_nonlinear_clusterer(self.n_feats)

    def make_clusterer(self, in_channels):
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, self.dim, (1, 1)))  # ,

    def make_nonlinear_clusterer(self, in_channels):
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, in_channels, (1, 1)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels, self.dim, (1, 1)))

    def forward(self, img, n=1, return_class_feat=False):
        self.model.eval()
        with torch.no_grad():
            assert (img.shape[2] % self.patch_size == 0)
            assert (img.shape[3] % self.patch_size == 0)

            # get selected layer activations
            feat, attn, qkv = self.model.get_intermediate_feat(img, n=n)
            feat, attn, qkv = feat[0], attn[0], qkv[0]

            feat_h = img.shape[2] // self.patch_size
            feat_w = img.shape[3] // self.patch_size

            if self.feat_type == "feat":
                image_feat = feat[:, 1:, :].reshape(feat.shape[0], feat_h, feat_w, -1).permute(0, 3, 1, 2)
            elif self.feat_type == "KK":
                image_k = qkv[1, :, :, 1:, :].reshape(feat.shape[0], 6, feat_h, feat_w, -1)
                B, H, I, J, D = image_k.shape
                image_feat = image_k.permute(0, 1, 4, 2, 3).reshape(B, H * D, I, J)
            else:
                raise ValueError("Unknown feat type:{}".format(self.feat_type))

            if return_class_feat:
                return feat[:, :1, :].reshape(feat.shape[0], 1, 1, -1).permute(0, 3, 1, 2)

        if self.proj_type is not None:
            code = self.cluster1(self.dropout(image_feat))
            if self.proj_type == "nonlinear":
                code += self.cluster2(self.dropout(image_feat))
        else:
            code = image_feat

        if self.cfg.dropout:
            return self.dropout(image_feat), code
        else:
            return image_feat, code


class ResizeAndClassify(nn.Module):

    def __init__(self, dim: int, size: int, n_classes: int):
        super(ResizeAndClassify, self).__init__()
        self.size = size
        self.predictor = torch.nn.Sequential(
            torch.nn.Conv2d(dim, n_classes, (1, 1)),
            torch.nn.LogSoftmax(1))

    def forward(self, x):
        return F.interpolate(self.predictor.forward(x), self.size, mode="bilinear", align_corners=False)


class ClusterLookup(nn.Module):

    def __init__(self, dim: int, n_classes: int):
        super(ClusterLookup, self).__init__()
        self.n_classes = n_classes
        self.dim = dim
        self.clusters = torch.nn.Parameter(torch.randn(n_classes, dim))

    def reset_parameters(self):
        with torch.no_grad():
            self.clusters.copy_(torch.randn(self.n_classes, self.dim))

    def forward(self, x, alpha, log_probs=False):
        normed_clusters = F.normalize(self.clusters, dim=1)
        normed_features = F.normalize(x, dim=1)
        inner_products = torch.einsum("bchw,nc->bnhw", normed_features, normed_clusters)

        if alpha is None:
            cluster_probs = F.one_hot(torch.argmax(inner_products, dim=1), self.clusters.shape[0]) \
                .permute(0, 3, 1, 2).to(torch.float32)
        else:
            cluster_probs = nn.functional.softmax(inner_products * alpha, dim=1)

        cluster_loss = -(cluster_probs * inner_products).sum(1).mean()
        if log_probs:
            return nn.functional.log_softmax(inner_products * alpha, dim=1)
        else:
            return cluster_loss, cluster_probs


class FeaturePyramidNet(nn.Module):

    @staticmethod
    def _helper(x):
        # TODO remove this hard coded 56
        return F.interpolate(x, 56, mode="bilinear", align_corners=False).unsqueeze(-1)

    def make_clusterer(self, in_channels):
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, self.dim, (1, 1)),
            LambdaLayer(FeaturePyramidNet._helper))

    def make_nonlinear_clusterer(self, in_channels):
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, in_channels, (1, 1)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels, in_channels, (1, 1)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels, self.dim, (1, 1)),
            LambdaLayer(FeaturePyramidNet._helper))

    def __init__(self, granularity, cut_model, dim, continuous):
        super(FeaturePyramidNet, self).__init__()
        self.layer_nums = [5, 6, 7]
        self.spatial_resolutions = [7, 14, 28, 56]
        self.feat_channels = [2048, 1024, 512, 3]
        self.extra_channels = [128, 64, 32, 32]
        self.granularity = granularity
        self.encoder = NetWithActivations(cut_model, self.layer_nums)
        self.dim = dim
        self.continuous = continuous
        self.n_feats = self.dim

        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        assert granularity in {1, 2, 3, 4}
        self.cluster1 = self.make_clusterer(self.feat_channels[0])
        self.cluster1_nl = self.make_nonlinear_clusterer(self.feat_channels[0])

        if granularity >= 2:
            # self.conv1 = DoubleConv(self.feat_channels[0], self.extra_channels[0])
            # self.conv2 = DoubleConv(self.extra_channels[0] + self.feat_channels[1], self.extra_channels[1])
            self.conv2 = DoubleConv(self.feat_channels[0] + self.feat_channels[1], self.extra_channels[1])
            self.cluster2 = self.make_clusterer(self.extra_channels[1])
        if granularity >= 3:
            self.conv3 = DoubleConv(self.extra_channels[1] + self.feat_channels[2], self.extra_channels[2])
            self.cluster3 = self.make_clusterer(self.extra_channels[2])
        if granularity >= 4:
            self.conv4 = DoubleConv(self.extra_channels[2] + self.feat_channels[3], self.extra_channels[3])
            self.cluster4 = self.make_clusterer(self.extra_channels[3])

    def c(self, x, y):
        return torch.cat([x, y], dim=1)

    def forward(self, x):
        with torch.no_grad():
            feats = self.encoder(x)
            low_res_feats = feats[self.layer_nums[-1]]

        all_clusters = []

        # all_clusters.append(self.cluster1(low_res_feats) + self.cluster1_nl(low_res_feats))
        all_clusters.append(self.cluster1(low_res_feats))

        if self.granularity >= 2:
            # f1 = self.conv1(low_res_feats)
            # f1_up = self.up(f1)
            f1_up = self.up(low_res_feats)
            f2 = self.conv2(self.c(f1_up, feats[self.layer_nums[-2]]))
            all_clusters.append(self.cluster2(f2))
        if self.granularity >= 3:
            f2_up = self.up(f2)
            f3 = self.conv3(self.c(f2_up, feats[self.layer_nums[-3]]))
            all_clusters.append(self.cluster3(f3))
        if self.granularity >= 4:
            f3_up = self.up(f3)
            final_size = self.spatial_resolutions[-1]
            f4 = self.conv4(self.c(f3_up, F.interpolate(
                x, (final_size, final_size), mode="bilinear", align_corners=False)))
            all_clusters.append(self.cluster4(f4))

        avg_code = torch.cat(all_clusters, 4).mean(4)

        if self.continuous:
            clusters = avg_code
        else:
            clusters = torch.log_softmax(avg_code, 1)

        return low_res_feats, clusters


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        return self.double_conv(x)


def norm(t):
    return F.normalize(t, dim=1, eps=1e-10)


def average_norm(t):
    return t / t.square().sum(1, keepdim=True).sqrt().mean()


def tensor_correlation(a, b):
    return torch.einsum("nchw,ncij->nhwij", a, b)


def sample(t: torch.Tensor, coords: torch.Tensor):
    return F.grid_sample(t, coords.permute(0, 2, 1, 3), padding_mode='border', align_corners=True)


@torch.jit.script
def super_perm(size: int, device: torch.device):
    perm = torch.randperm(size, device=device, dtype=torch.long)
    perm[perm == torch.arange(size, device=device)] += 1
    return perm % size


def sample_nonzero_locations(t, target_size):
    nonzeros = torch.nonzero(t)
    coords = torch.zeros(target_size, dtype=nonzeros.dtype, device=nonzeros.device)
    n = target_size[1] * target_size[2]
    for i in range(t.shape[0]):
        selected_nonzeros = nonzeros[nonzeros[:, 0] == i]
        if selected_nonzeros.shape[0] == 0:
            selected_coords = torch.randint(t.shape[1], size=(n, 2), device=nonzeros.device)
        else:
            selected_coords = selected_nonzeros[torch.randint(len(selected_nonzeros), size=(n,)), 1:]
        coords[i, :, :, :] = selected_coords.reshape(target_size[1], target_size[2], 2)
    coords = coords.to(torch.float32) / t.shape[1]
    coords = coords * 2 - 1
    return torch.flip(coords, dims=[-1])


class ContrastiveCorrelationLoss(nn.Module):

    def __init__(self, cfg, ):
        super(ContrastiveCorrelationLoss, self).__init__()
        self.cfg = cfg

    def standard_scale(self, t):
        t1 = t - t.mean()
        t2 = t1 / t1.std()
        return t2

    def helper(self, f1, f2, c1, c2, shift):
        with torch.no_grad():
            # Comes straight from backbone which is currently frozen. this saves mem.
            fd = tensor_correlation(norm(f1), norm(f2))

            if self.cfg.pointwise:
                old_mean = fd.mean()
                fd -= fd.mean([3, 4], keepdim=True)
                fd = fd - fd.mean() + old_mean

        cd = tensor_correlation(norm(c1), norm(c2))

        if self.cfg.zero_clamp:
            min_val = 0.0
        else:
            min_val = -9999.0

        if self.cfg.stabalize:
            loss = - cd.clamp(min_val, .8) * (fd - shift)
        else:
            loss = - cd.clamp(min_val) * (fd - shift)

        return loss, cd

    def forward(self,
                orig_feats: torch.Tensor, orig_feats_pos: torch.Tensor,
                orig_salience: torch.Tensor, orig_salience_pos: torch.Tensor,
                orig_code: torch.Tensor, orig_code_pos: torch.Tensor,
                ):

        coord_shape = [orig_feats.shape[0], self.cfg.feature_samples, self.cfg.feature_samples, 2]

        if self.cfg.use_salience:
            coords1_nonzero = sample_nonzero_locations(orig_salience, coord_shape)
            coords2_nonzero = sample_nonzero_locations(orig_salience_pos, coord_shape)
            coords1_reg = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1
            coords2_reg = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1
            mask = (torch.rand(coord_shape[:-1], device=orig_feats.device) > .1).unsqueeze(-1).to(torch.float32)
            coords1 = coords1_nonzero * mask + coords1_reg * (1 - mask)
            coords2 = coords2_nonzero * mask + coords2_reg * (1 - mask)
        else:
            coords1 = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1
            coords2 = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1

        feats = sample(orig_feats, coords1)
        code = sample(orig_code, coords1)

        feats_pos = sample(orig_feats_pos, coords2)
        code_pos = sample(orig_code_pos, coords2)

        pos_intra_loss, pos_intra_cd = self.helper(
            feats, feats, code, code, self.cfg.pos_intra_shift)
        pos_inter_loss, pos_inter_cd = self.helper(
            feats, feats_pos, code, code_pos, self.cfg.pos_inter_shift)

        neg_losses = []
        neg_cds = []
        for i in range(self.cfg.neg_samples):
            perm_neg = super_perm(orig_feats.shape[0], orig_feats.device)
            feats_neg = sample(orig_feats[perm_neg], coords2)
            code_neg = sample(orig_code[perm_neg], coords2)
            neg_inter_loss, neg_inter_cd = self.helper(
                feats, feats_neg, code, code_neg, self.cfg.neg_inter_shift)
            neg_losses.append(neg_inter_loss)
            neg_cds.append(neg_inter_cd)
        neg_inter_loss = torch.cat(neg_losses, axis=0)
        neg_inter_cd = torch.cat(neg_cds, axis=0)

        return (pos_intra_loss.mean(),
                pos_intra_cd,
                pos_inter_loss.mean(),
                pos_inter_cd,
                neg_inter_loss,
                neg_inter_cd)


class Decoder(nn.Module):
    def __init__(self, code_channels, feat_channels):
        super().__init__()
        self.linear = torch.nn.Conv2d(code_channels, feat_channels, (1, 1))
        self.nonlinear = torch.nn.Sequential(
            torch.nn.Conv2d(code_channels, code_channels, (1, 1)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(code_channels, code_channels, (1, 1)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(code_channels, feat_channels, (1, 1)))

    def forward(self, x):
        return self.linear(x) + self.nonlinear(x)


class NetWithActivations(torch.nn.Module):
    def __init__(self, model, layer_nums):
        super(NetWithActivations, self).__init__()
        self.layers = nn.ModuleList(model.children())
        self.layer_nums = []
|
| 424 |
+
for l in layer_nums:
|
| 425 |
+
if l < 0:
|
| 426 |
+
self.layer_nums.append(len(self.layers) + l)
|
| 427 |
+
else:
|
| 428 |
+
self.layer_nums.append(l)
|
| 429 |
+
self.layer_nums = set(sorted(self.layer_nums))
|
| 430 |
+
|
| 431 |
+
def forward(self, x):
|
| 432 |
+
activations = {}
|
| 433 |
+
for ln, l in enumerate(self.layers):
|
| 434 |
+
x = l(x)
|
| 435 |
+
if ln in self.layer_nums:
|
| 436 |
+
activations[ln] = x
|
| 437 |
+
return activations
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
class ContrastiveCRFLoss(nn.Module):
|
| 441 |
+
|
| 442 |
+
def __init__(self, n_samples, alpha, beta, gamma, w1, w2, shift):
|
| 443 |
+
super(ContrastiveCRFLoss, self).__init__()
|
| 444 |
+
self.alpha = alpha
|
| 445 |
+
self.beta = beta
|
| 446 |
+
self.gamma = gamma
|
| 447 |
+
self.w1 = w1
|
| 448 |
+
self.w2 = w2
|
| 449 |
+
self.n_samples = n_samples
|
| 450 |
+
self.shift = shift
|
| 451 |
+
|
| 452 |
+
def forward(self, guidance, clusters):
|
| 453 |
+
device = clusters.device
|
| 454 |
+
assert (guidance.shape[0] == clusters.shape[0])
|
| 455 |
+
assert (guidance.shape[2:] == clusters.shape[2:])
|
| 456 |
+
h = guidance.shape[2]
|
| 457 |
+
w = guidance.shape[3]
|
| 458 |
+
|
| 459 |
+
coords = torch.cat([
|
| 460 |
+
torch.randint(0, h, size=[1, self.n_samples], device=device),
|
| 461 |
+
torch.randint(0, w, size=[1, self.n_samples], device=device)], 0)
|
| 462 |
+
|
| 463 |
+
selected_guidance = guidance[:, :, coords[0, :], coords[1, :]]
|
| 464 |
+
coord_diff = (coords.unsqueeze(-1) - coords.unsqueeze(1)).square().sum(0).unsqueeze(0)
|
| 465 |
+
guidance_diff = (selected_guidance.unsqueeze(-1) - selected_guidance.unsqueeze(2)).square().sum(1)
|
| 466 |
+
|
| 467 |
+
sim_kernel = self.w1 * torch.exp(- coord_diff / (2 * self.alpha) - guidance_diff / (2 * self.beta)) + \
|
| 468 |
+
self.w2 * torch.exp(- coord_diff / (2 * self.gamma)) - self.shift
|
| 469 |
+
|
| 470 |
+
selected_clusters = clusters[:, :, coords[0, :], coords[1, :]]
|
| 471 |
+
cluster_sims = torch.einsum("nka,nkb->nab", selected_clusters, selected_clusters)
|
| 472 |
+
return -(cluster_sims * sim_kernel)
|
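The correlation and sampling helpers above are easiest to sanity-check on dummy tensors. The following is a minimal sketch (not part of the commit), assuming biomap/modules.py is importable as `modules`; the tensor shapes are purely illustrative:

import torch
from modules import norm, sample_nonzero_locations, super_perm, tensor_correlation

# A small batch of "features" and "codes" already reduced to an 11x11 grid of samples.
feats = torch.randn(4, 384, 11, 11)   # e.g. backbone features after grid sampling
code = torch.randn(4, 70, 11, 11)     # projected segmentation codes

# Feature-space and code-space correlation volumes, shape (N, H, W, h, w).
fd = tensor_correlation(norm(feats), norm(feats))
cd = tensor_correlation(norm(code), norm(code))
print(fd.shape, cd.shape)  # torch.Size([4, 11, 11, 11, 11]) for both

# Negative pairs come from permuting the batch so an image is compared with a different one.
perm = super_perm(feats.shape[0], feats.device)
print(perm)  # indices used to build mismatched (negative) pairs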
biomap/output/img.png
ADDED
biomap/output/img_6.png
ADDED
biomap/output/label.png
ADDED
biomap/output/labeled_img.png
ADDED
biomap/plot_functions.py
ADDED
@@ -0,0 +1,778 @@
from PIL import Image

import hydra
import matplotlib as mpl
from utils import prep_for_plot

import torch.multiprocessing
import torchvision.transforms as T
# import matplotlib.pyplot as plt
from model import LitUnsupervisedSegmenter

colors = ('red', 'palegreen', 'green', 'steelblue', 'blue', 'yellow', 'lightgrey')
class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
cmap = mpl.colors.ListedColormap(colors)
# from train_segmentation import LitUnsupervisedSegmenter, cmap

from utils_gee import extract_img, transform_ee_img

import plotly.graph_objects as go
import plotly.express as px
import numpy as np
from plotly.subplots import make_subplots

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


colors = ('red', 'palegreen', 'green', 'steelblue', 'blue', 'yellow', 'lightgrey')
class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
scores_init = [2, 3, 4, 3, 1, 4, 0]

# Import model configs
hydra.initialize(config_path="configs", job_name="corine")
cfg = hydra.compose(config_name="my_train_config.yml")

nbclasses = cfg.dir_dataset_n_classes

# Load model
model_path = "checkpoint/model/model.pt"
saved_state_dict = torch.load(model_path, map_location=torch.device('cpu'))

model = LitUnsupervisedSegmenter(nbclasses, cfg)
model.load_state_dict(saved_state_dict)

from PIL import Image

import hydra

from utils import prep_for_plot

import torch.multiprocessing
import torchvision.transforms as T
# import matplotlib.pyplot as plt

from model import LitUnsupervisedSegmenter

from utils_gee import extract_img, transform_ee_img

import plotly.graph_objects as go
import plotly.express as px
import numpy as np
from plotly.subplots import make_subplots

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


colors = ('red', 'palegreen', 'green', 'steelblue', 'blue', 'yellow', 'lightgrey')
cmap = mpl.colors.ListedColormap(colors)
class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
scores_init = [2, 3, 4, 3, 1, 4, 0]

# Import model configs
# hydra.initialize(config_path="configs", job_name="corine")
cfg = hydra.compose(config_name="my_train_config.yml")

nbclasses = cfg.dir_dataset_n_classes

# Load model
model_path = "checkpoint/model/model.pt"
saved_state_dict = torch.load(model_path, map_location=torch.device('cpu'))

model = LitUnsupervisedSegmenter(nbclasses, cfg)
model.load_state_dict(saved_state_dict)


# Normalize img
preprocess = T.Compose([
    T.ToPILImage(),
    T.Resize((320, 320)),
    # T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# Function that looks for an img on EE and segments it
# -- 3 ways possible to avoid a cloudy environment -- monthly / bi-monthly / yearly meaned img

def segment_loc(location, month, year, how="month", month_end='12', year_end=None):
    if how == 'month':
        img = extract_img(location, year + '-' + month + '-01', year + '-' + month + '-28')
    elif how == 'year':
        if year_end == None:
            img = extract_img(location, year + '-' + month + '-01', year + '-' + month_end + '-28', width=0.04, len=0.04)
        else:
            img = extract_img(location, year + '-' + month + '-01', year_end + '-' + month_end + '-28', width=0.04, len=0.04)

    img_test = transform_ee_img(img, max=0.25)

    # Preprocess opened img
    x = preprocess(img_test)
    x = torch.unsqueeze(x, dim=0).cpu()
    # model=model.cpu()

    with torch.no_grad():
        feats, code = model.net(x)
        linear_preds = model.linear_probe(x, code)
        linear_preds = linear_preds.argmax(1)
        outputs = {
            'img': x[:model.cfg.n_images].detach().cpu(),
            'linear_preds': linear_preds[:model.cfg.n_images].detach().cpu()
        }
    return outputs


# Function that looks for all imgs on EE and extracts all segments, with the date as first output arg

def segment_group(location, start_date, end_date, how='month'):
    outputs = []
    st_month = int(start_date[5:7])
    end_month = int(end_date[5:7])

    st_year = int(start_date[0:4])
    end_year = int(end_date[0:4])

    for year in range(st_year, end_year + 1):

        if year != end_year:
            last = 12
        else:
            last = end_month

        if year != st_year:
            start = 1
        else:
            start = st_month

        if how == 'month':
            for month in range(start, last + 1):
                month_str = f"{month:0>2d}"
                year_str = str(year)

                outputs.append((year_str + '-' + month_str, segment_loc(location, month_str, year_str)))

        elif how == 'year':
            outputs.append((str(year) + '-' + f"{start:0>2d}", segment_loc(location, f"{start:0>2d}", str(year), how='year', month_end=f"{last:0>2d}")))

        elif how == '2months':
            for month in range(start, last + 1):
                month_str = f"{month:0>2d}"
                year_str = str(year)
                month_end = (month) % 12 + 1
                if month_end < month:
                    year_end = year + 1
                else:
                    year_end = year
                month_end = f"{month_end:0>2d}"
                year_end = str(year_end)

                outputs.append((year_str + '-' + month_str, segment_loc(location, month_str, year_str, how='year', month_end=month_end, year_end=year_end)))

    return outputs


# Function that transforms an output into PIL images

def transform_to_pil(outputs, alpha=0.3):
    # Transform img with torch
    img = torch.moveaxis(prep_for_plot(outputs['img'][0]), -1, 0)
    img = T.ToPILImage()(img)

    # Transform label by saving it then opening it
    # label = outputs['linear_preds'][0]
    # plt.imsave('label.png',label,cmap=cmap)
    # label = Image.open('label.png')

    cmaplist = np.array([np.array(cmap(i)) for i in range(cmap.N)])
    labels = np.array(outputs['linear_preds'][0]) - 1
    label = T.ToPILImage()((cmaplist[labels] * 255).astype(np.uint8))

    # Overlay labels with img with alpha
    background = img.convert("RGBA")
    overlay = label.convert("RGBA")

    labeled_img = Image.blend(background, overlay, alpha)

    return img, label, labeled_img


# Function that extracts labeled_img (PIL), nb_values (number of pixels for each class) and the score for each observation

def values_from_output(output):
    imgs = transform_to_pil(output, alpha=0.3)

    img = imgs[0]
    img = np.array(img.convert('RGB'))

    labeled_img = imgs[2]
    labeled_img = np.array(labeled_img.convert('RGB'))

    nb_values = []
    for i in range(7):
        nb_values.append(np.count_nonzero(output['linear_preds'][0] == i + 1))

    score = sum(x * y for x, y in zip(scores_init, nb_values)) / sum(nb_values) / max(scores_init)

    return img, labeled_img, nb_values, score


# Function that extracts all dates / all images from outputs (from the segment_group function)
def values_from_outputs(outputs):
    months = []
    imgs = []
    imgs_label = []
    nb_values = []
    scores = []

    for output in outputs:
        img, labeled_img, nb_value, score = values_from_output(output[1])
        months.append(output[0])
        imgs.append(img)
        imgs_label.append(labeled_img)
        nb_values.append(nb_value)
        scores.append(score)

    return months, imgs, imgs_label, nb_values, scores


def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores):

    fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
    fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)

    # Scores
    scatters = []
    temp = []
    for score in scores:
        temp_score = []
        temp_date = []
        temp.append(score)
        text_temp = ["" for i in temp]
        text_temp[-1] = str(round(score, 2))
        scatters.append(go.Scatter(x=text_temp, y=temp, mode="lines+markers+text", marker_color="black", text=text_temp, textposition="top center"))

    # Scores
    fig = make_subplots(
        rows=1, cols=4,
        # specs=[[{"rowspan": 2}, {"rowspan": 2}, {"type": "pie"}, None]]
        # row_heights=[0.8, 0.2],
        column_widths=[0.6, 0.6, 0.3, 0.3],
        subplot_titles=("Localisation visualization", "labeled visualisation", "Segments repartition", "Biodiversity scores")
    )

    fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
    fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)

    fig.add_trace(go.Pie(labels=class_names,
                         values=nb_values[0],
                         marker_colors=colors,
                         name="Segment repartition",
                         textposition='inside',
                         texttemplate="%{percent:.0%}",
                         textfont_size=14
                         ),
                  row=1, col=3)

    fig.add_trace(scatters[0], row=1, col=4)
    # fig.add_annotation(text='score:' + str(scores[0]),
    #                    showarrow=False,
    #                    row=2, col=2)

    number_frames = len(imgs)
    frames = [dict(
        name=k,
        data=[fig2["frames"][k]["data"][0],
              fig3["frames"][k]["data"][0],
              go.Pie(labels=class_names,
                     values=nb_values[k],
                     marker_colors=colors,
                     name="Segment repartition",
                     textposition='inside',
                     texttemplate="%{percent:.0%}",
                     textfont_size=14
                     ),
              scatters[k]
              ],
        traces=[0, 1, 2, 3]  # the elements of the list [0,1,2] give info on the traces in fig.data
                             # that are updated by the above three go.Scatter instances
    ) for k in range(number_frames)]

    updatemenus = [dict(type='buttons',
                        buttons=[dict(label='Play',
                                      method='animate',
                                      args=[[f'{k}' for k in range(number_frames)],
                                            dict(frame=dict(duration=500, redraw=False),
                                                 transition=dict(duration=0),
                                                 easing='linear',
                                                 fromcurrent=True,
                                                 mode='immediate'
                                                 )])],
                        direction='left',
                        pad=dict(r=10, t=85),
                        showactive=True, x=0.1, y=0.13, xanchor='right', yanchor='top')
                   ]

    sliders = [{'yanchor': 'top',
                'xanchor': 'left',
                'currentvalue': {'font': {'size': 16}, 'prefix': 'Frame: ', 'visible': False, 'xanchor': 'right'},
                'transition': {'duration': 500.0, 'easing': 'linear'},
                'pad': {'b': 10, 't': 50},
                'len': 0.9, 'x': 0.1, 'y': 0,
                'steps': [{'args': [[k], {'frame': {'duration': 500.0, 'easing': 'linear', 'redraw': False},
                                          'transition': {'duration': 0, 'easing': 'linear'}}],
                           'label': months[k], 'method': 'animate'} for k in range(number_frames)
                          ]}]

    fig.update(frames=frames)

    for i, fr in enumerate(fig["frames"]):
        fr.update(
            layout={
                "xaxis": {
                    "range": [0, imgs[0].shape[1] + i / 100000]
                },
                "yaxis": {
                    "range": [imgs[0].shape[0] + i / 100000, 0]
                },
            })

        fr.update(layout_title_text=months[i])

    fig.update(layout_title_text='tot')
    fig.update(
        layout={
            "xaxis": {
                "range": [0, imgs[0].shape[1] + i / 100000],
                'showgrid': False,  # thin lines in the background
                'zeroline': False,  # thick line at x=0
                'visible': False,   # numbers below
            },

            "yaxis": {
                "range": [imgs[0].shape[0] + i / 100000, 0],
                'showgrid': False,  # thin lines in the background
                'zeroline': False,  # thick line at y=0
                'visible': False, },

            "xaxis3": {
                "range": [0, len(scores) + 1],
                'autorange': False,
                'showgrid': False,  # thin lines in the background
                'zeroline': False,  # thick line at y=0
                'visible': False
            },

            "yaxis3": {
                "range": [0, 1.5],
                'autorange': False,
                'showgrid': False,  # thin lines in the background
                'zeroline': False,  # thick line at y=0
                'visible': False
            }
        },
        legend=dict(
            yanchor="bottom",
            y=0.99,
            xanchor="center",
            x=0.01
        )
    )

    fig.update_layout(updatemenus=updatemenus,
                      sliders=sliders)

    fig.update_layout(margin=dict(b=0, r=0))

    # fig.show() #in jupyter notebook

    return fig


# Last function (global one)
# how = 'month' or '2months' or 'year'

def segment_region(location, start_date, end_date, how='month'):

    # extract the outputs for each image
    outputs = segment_group(location, start_date, end_date, how=how)

    # extract the interesting values from each image
    months, imgs, imgs_label, nb_values, scores = values_from_outputs(outputs)

    # create the figure
    fig = plot_imgs_labels(months, imgs, imgs_label, nb_values, scores)

    return fig


# Normalize img
preprocess = T.Compose([
    T.ToPILImage(),
    T.Resize((320, 320)),
    # T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# Function that looks for an img on EE and segments it
# -- 3 ways possible to avoid a cloudy environment -- monthly / bi-monthly / yearly meaned img

def segment_loc(location, month, year, how="month", month_end='12', year_end=None):
    if how == 'month':
        img = extract_img(location, year + '-' + month + '-01', year + '-' + month + '-28')
    elif how == 'year':
        if year_end == None:
            img = extract_img(location, year + '-' + month + '-01', year + '-' + month_end + '-28', width=0.04, len=0.04)
        else:
            img = extract_img(location, year + '-' + month + '-01', year_end + '-' + month_end + '-28', width=0.04, len=0.04)

    img_test = transform_ee_img(img, max=0.25)

    # Preprocess opened img
    x = preprocess(img_test)
    x = torch.unsqueeze(x, dim=0).cpu()
    # model=model.cpu()

    with torch.no_grad():
        feats, code = model.net(x)
        linear_preds = model.linear_probe(x, code)
        linear_preds = linear_preds.argmax(1)
        outputs = {
            'img': x[:model.cfg.n_images].detach().cpu(),
            'linear_preds': linear_preds[:model.cfg.n_images].detach().cpu()
        }
    return outputs


# Function that looks for all imgs on EE and extracts all segments, with the date as first output arg

def segment_group(location, start_date, end_date, how='month'):
    outputs = []
    st_month = int(start_date[5:7])
    end_month = int(end_date[5:7])

    st_year = int(start_date[0:4])
    end_year = int(end_date[0:4])

    for year in range(st_year, end_year + 1):

        if year != end_year:
            last = 12
        else:
            last = end_month

        if year != st_year:
            start = 1
        else:
            start = st_month

        if how == 'month':
            for month in range(start, last + 1):
                month_str = f"{month:0>2d}"
                year_str = str(year)

                outputs.append((year_str + '-' + month_str, segment_loc(location, month_str, year_str)))

        elif how == 'year':
            outputs.append((str(year) + '-' + f"{start:0>2d}", segment_loc(location, f"{start:0>2d}", str(year), how='year', month_end=f"{last:0>2d}")))

        elif how == '2months':
            for month in range(start, last + 1):
                month_str = f"{month:0>2d}"
                year_str = str(year)
                month_end = (month) % 12 + 1
                if month_end < month:
                    year_end = year + 1
                else:
                    year_end = year
                month_end = f"{month_end:0>2d}"
                year_end = str(year_end)

                outputs.append((year_str + '-' + month_str, segment_loc(location, month_str, year_str, how='year', month_end=month_end, year_end=year_end)))

    return outputs


# Function that transforms an output into PIL images

def transform_to_pil(outputs, alpha=0.3):
    # Transform img with torch
    img = torch.moveaxis(prep_for_plot(outputs['img'][0]), -1, 0)
    img = T.ToPILImage()(img)

    # Transform label by saving it then opening it
    # label = outputs['linear_preds'][0]
    # plt.imsave('label.png',label,cmap=cmap)
    # label = Image.open('label.png')

    cmaplist = np.array([np.array(cmap(i)) for i in range(cmap.N)])
    labels = np.array(outputs['linear_preds'][0]) - 1
    label = T.ToPILImage()((cmaplist[labels] * 255).astype(np.uint8))

    # Overlay labels with img with alpha
    background = img.convert("RGBA")
    overlay = label.convert("RGBA")

    labeled_img = Image.blend(background, overlay, alpha)

    return img, label, labeled_img


def values_from_output(output):
    imgs = transform_to_pil(output, alpha=0.3)

    img = imgs[0]
    img = np.array(img.convert('RGB'))

    labeled_img = imgs[2]
    labeled_img = np.array(labeled_img.convert('RGB'))

    nb_values = []
    for i in range(7):
        nb_values.append(np.count_nonzero(output['linear_preds'][0] == i + 1))

    score = sum(x * y for x, y in zip(scores_init, nb_values)) / sum(nb_values) / max(scores_init)

    return img, labeled_img, nb_values, score


# Function that extracts labeled_img (PIL), nb_values (number of pixels for each class) and the score for each observation


# Function that extracts all dates / all images from outputs (from the segment_group function)
def values_from_outputs(outputs):
    months = []
    imgs = []
    imgs_label = []
    nb_values = []
    scores = []

    for output in outputs:
        img, labeled_img, nb_value, score = values_from_output(output[1])
        months.append(output[0])
        imgs.append(img)
        imgs_label.append(labeled_img)
        nb_values.append(nb_value)
        scores.append(score)

    return months, imgs, imgs_label, nb_values, scores


def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores):

    fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
    fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)

    # Scores
    scatters = []
    temp = []
    for score in scores:
        temp_score = []
        temp_date = []
        # score = scores[i]
        temp.append(score)
        n = len(temp)
        text_temp = ["" for i in temp]
        text_temp[-1] = str(round(score, 2))
        scatters.append(go.Scatter(x=[0, 1], y=temp, mode="lines+markers+text", marker_color="black", text=text_temp, textposition="top center"))
        print(text_temp)

    # Scores
    fig = make_subplots(
        rows=1, cols=4,
        specs=[[{"type": "image"}, {"type": "image"}, {"type": "pie"}, {"type": "scatter"}]],
        subplot_titles=("Localisation visualization", "Labeled visualisation", "Segments repartition", "Biodiversity scores")
    )

    fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
    fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)

    fig.add_trace(go.Pie(labels=class_names,
                         values=nb_values[0],
                         marker_colors=colors,
                         name="Segment repartition",
                         textposition='inside',
                         texttemplate="%{percent:.0%}",
                         textfont_size=14
                         ),
                  row=1, col=3)

    fig.add_trace(scatters[0], row=1, col=4)
    fig.update_traces(showlegend=False, selector=dict(type='scatter'))
    # fig.update_traces(, selector=dict(type='scatter'))
    # fig.add_annotation(text='score:' + str(scores[0]),
    #                    showarrow=False,
    #                    row=2, col=2)

    number_frames = len(imgs)
    frames = [dict(
        name=k,
        data=[fig2["frames"][k]["data"][0],
              fig3["frames"][k]["data"][0],
              go.Pie(labels=class_names,
                     values=nb_values[k],
                     marker_colors=colors,
                     name="Segment repartition",
                     textposition='inside',
                     texttemplate="%{percent:.0%}",
                     textfont_size=14
                     ),
              scatters[k]
              ],
        traces=[0, 1, 2, 3]  # the elements of the list [0,1,2] give info on the traces in fig.data
                             # that are updated by the above three go.Scatter instances
    ) for k in range(number_frames)]

    updatemenus = [dict(type='buttons',
                        buttons=[dict(label='Play',
                                      method='animate',
                                      args=[[f'{k}' for k in range(number_frames)],
                                            dict(frame=dict(duration=500, redraw=False),
                                                 transition=dict(duration=0),
                                                 easing='linear',
                                                 fromcurrent=True,
                                                 mode='immediate'
                                                 )])],
                        direction='left',
                        pad=dict(r=10, t=85),
                        showactive=True, x=0.1, y=0.13, xanchor='right', yanchor='top')
                   ]

    sliders = [{'yanchor': 'top',
                'xanchor': 'left',
                'currentvalue': {'font': {'size': 16}, 'prefix': 'Frame: ', 'visible': False, 'xanchor': 'right'},
                'transition': {'duration': 500.0, 'easing': 'linear'},
                'pad': {'b': 10, 't': 50},
                'len': 0.9, 'x': 0.1, 'y': 0,
                'steps': [{'args': [[k], {'frame': {'duration': 500.0, 'easing': 'linear', 'redraw': False},
                                          'transition': {'duration': 0, 'easing': 'linear'}}],
                           'label': months[k], 'method': 'animate'} for k in range(number_frames)
                          ]}]

    fig.update(frames=frames)

    for i, fr in enumerate(fig["frames"]):
        fr.update(
            layout={
                "xaxis": {
                    "range": [0, imgs[0].shape[1] + i / 100000]
                },
                "yaxis": {
                    "range": [imgs[0].shape[0] + i / 100000, 0]
                },
            })

        fr.update(layout_title_text=months[i])

    fig.update(layout_title_text=months[0])
    fig.update(
        layout={
            "xaxis": {
                "range": [0, imgs[0].shape[1] + i / 100000],
                'showgrid': False,  # thin lines in the background
                'zeroline': False,  # thick line at x=0
                'visible': False,   # numbers below
            },

            "yaxis": {
                "range": [imgs[0].shape[0] + i / 100000, 0],
                'showgrid': False,  # thin lines in the background
                'zeroline': False,  # thick line at y=0
                'visible': False, },

            "xaxis2": {
                "range": [0, imgs[0].shape[1] + i / 100000],
                'showgrid': False,  # thin lines in the background
                'zeroline': False,  # thick line at x=0
                'visible': False,   # numbers below
            },

            "yaxis2": {
                "range": [imgs[0].shape[0] + i / 100000, 0],
                'showgrid': False,  # thin lines in the background
                'zeroline': False,  # thick line at y=0
                'visible': False, },

            "xaxis3": {
                "range": [0, len(scores) + 1],
                'autorange': False,
                'showgrid': False,  # thin lines in the background
                'zeroline': False,  # thick line at y=0
                'visible': False
            },

            "yaxis3": {
                "range": [0, 1.5],
                'autorange': False,
                'showgrid': False,  # thin lines in the background
                'zeroline': False,  # thick line at y=0
                'visible': False
            }
        }
    )

    fig.update_layout(updatemenus=updatemenus,
                      sliders=sliders,
                      legend=dict(
                          yanchor='top',
                          xanchor='left',
                          orientation="h")
                      )

    fig.update_layout(margin=dict(b=0, r=0))

    # fig.show() #in jupyter notebook

    return fig


# Last function (global one)
# how = 'month' or '2months' or 'year'

def segment_region(latitude, longitude, start_date, end_date, how='month'):
    location = [float(latitude), float(longitude)]
    how = how[0]
    # extract the outputs for each image
    outputs = segment_group(location, start_date, end_date, how=how)

    # extract the interesting values from each image
    months, imgs, imgs_label, nb_values, scores = values_from_outputs(outputs)

    # create the figure
    fig = plot_imgs_labels(months, imgs, imgs_label, nb_values, scores)

    return fig
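A minimal usage sketch (not part of the commit) for the final segment_region defined above. It assumes Earth Engine credentials are configured for utils_gee.extract_img, that checkpoint/model/model.pt exists (the module loads the model at import time), and the coordinates and dates are placeholders; note that `how` is indexed with [0], so it is passed as a one-element list:

from plot_functions import segment_region

fig = segment_region(latitude="43.61", longitude="3.88",
                     start_date="2020-01-01", end_date="2020-06-01",
                     how=["month"])
fig.write_html("segmentation_timeline.html")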
biomap/train.py
ADDED
@@ -0,0 +1,267 @@
from utils import *
from modules import *
from data import *
from torch.utils.data import DataLoader
import torch.nn.functional as F
from datetime import datetime
import hydra
from omegaconf import DictConfig, OmegaConf
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.seed import seed_everything
import torch.multiprocessing
import seaborn as sns
from pytorch_lightning.callbacks import ModelCheckpoint
import sys
import pdb
import matplotlib as mpl
from skimage import measure
from scipy.stats import mode as statsmode
from collections import OrderedDict
import unet

torch.multiprocessing.set_sharing_strategy("file_system")
colors = ("red", "palegreen", "green", "steelblue", "blue", "yellow", "lightgrey")
class_names = (
    "Buildings",
    "Cultivation",
    "Natural green",
    "Wetland",
    "Water",
    "Infrastructure",
    "Background",
)
bounds = list(np.arange(len(class_names) + 1) + 1)
cmap = mpl.colors.ListedColormap(colors)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)


def retouch_label(pred_label, true_label):
    retouched_label = pred_label + 0
    blobs = measure.label(retouched_label)
    for idx in np.unique(blobs):
        # most frequent label class in this blob
        retouched_label[blobs == idx] = statsmode(true_label[blobs == idx])[0][0]
    return retouched_label


def get_class_labels(dataset_name):
    if dataset_name.startswith("cityscapes"):
        return [
            "road",
            "sidewalk",
            "parking",
            "rail track",
            "building",
            "wall",
            "fence",
            "guard rail",
            "bridge",
            "tunnel",
            "pole",
            "polegroup",
            "traffic light",
            "traffic sign",
            "vegetation",
            "terrain",
            "sky",
            "person",
            "rider",
            "car",
            "truck",
            "bus",
            "caravan",
            "trailer",
            "train",
            "motorcycle",
            "bicycle",
        ]
    elif dataset_name == "cocostuff27":
        return [
            "electronic",
            "appliance",
            "food",
            "furniture",
            "indoor",
            "kitchen",
            "accessory",
            "animal",
            "outdoor",
            "person",
            "sports",
            "vehicle",
            "ceiling",
            "floor",
            "food",
            "furniture",
            "rawmaterial",
            "textile",
            "wall",
            "window",
            "building",
            "ground",
            "plant",
            "sky",
            "solid",
            "structural",
            "water",
        ]
    elif dataset_name == "voc":
        return [
            "background",
            "aeroplane",
            "bicycle",
            "bird",
            "boat",
            "bottle",
            "bus",
            "car",
            "cat",
            "chair",
            "cow",
            "diningtable",
            "dog",
            "horse",
            "motorbike",
            "person",
            "pottedplant",
            "sheep",
            "sofa",
            "train",
            "tvmonitor",
        ]
    elif dataset_name == "potsdam":
        return ["roads and cars", "buildings and clutter", "trees and vegetation"]
    else:
        raise ValueError("Unknown Dataset {}".format(dataset_name))


@hydra.main(config_path="configs", config_name="train_config.yml")
def my_app(cfg: DictConfig) -> None:
    OmegaConf.set_struct(cfg, False)
    print(OmegaConf.to_yaml(cfg))
    pytorch_data_dir = cfg.pytorch_data_dir
    data_dir = join(cfg.output_root, "data")
    log_dir = join(cfg.output_root, "logs")
    checkpoint_dir = join(cfg.output_root, "checkpoints")

    prefix = "{}/{}_{}".format(cfg.log_dir, cfg.dataset_name, cfg.experiment_name)
    name = "{}_date_{}".format(prefix, datetime.now().strftime("%b%d_%H-%M-%S"))
    cfg.full_name = prefix

    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    seed_everything(seed=0)

    print(data_dir)
    print(cfg.output_root)

    geometric_transforms = T.Compose(
        [T.RandomHorizontalFlip(), T.RandomResizedCrop(size=cfg.res, scale=(0.8, 1.0))]
    )
    photometric_transforms = T.Compose(
        [
            T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
            T.RandomGrayscale(0.2),
            T.RandomApply([T.GaussianBlur((5, 5))]),
        ]
    )

    sys.stdout.flush()

    train_dataset = ContrastiveSegDataset(
        pytorch_data_dir=pytorch_data_dir,
        dataset_name=cfg.dataset_name,
        crop_type=cfg.crop_type,
        image_set="train",
        transform=get_transform(cfg.res, False, cfg.loader_crop_type),
        target_transform=get_transform(cfg.res, True, cfg.loader_crop_type),
        cfg=cfg,
        aug_geometric_transform=geometric_transforms,
        aug_photometric_transform=photometric_transforms,
        num_neighbors=cfg.num_neighbors,
        mask=True,
        pos_images=True,
        pos_labels=True,
    )

    if cfg.dataset_name == "voc":
        val_loader_crop = None
    else:
        val_loader_crop = "center"

    val_dataset = ContrastiveSegDataset(
        pytorch_data_dir=pytorch_data_dir,
        dataset_name=cfg.dataset_name,
        crop_type=None,
        image_set="val",
        transform=get_transform(320, False, val_loader_crop),
        target_transform=get_transform(320, True, val_loader_crop),
        mask=True,
        cfg=cfg,
    )

    # val_dataset = MaterializedDataset(val_dataset)
    train_loader = DataLoader(
        train_dataset,
        cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )

    if cfg.submitting_to_aml:
        val_batch_size = 16
    else:
        val_batch_size = cfg.batch_size

    val_loader = DataLoader(
        val_dataset,
        val_batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )

    model = LitUnsupervisedSegmenter(train_dataset.n_classes, cfg)

    tb_logger = TensorBoardLogger(join(log_dir, name), default_hp_metric=False)

    if cfg.submitting_to_aml:
        gpu_args = dict(gpus=1, val_check_interval=250)

        if gpu_args["val_check_interval"] > len(train_loader):
            gpu_args.pop("val_check_interval")

    else:
        gpu_args = dict(gpus=-1, accelerator="ddp", val_check_interval=cfg.val_freq)
        # gpu_args = dict(gpus=1, accelerator='ddp', val_check_interval=cfg.val_freq)

        if gpu_args["val_check_interval"] > len(train_loader) // 4:
            gpu_args.pop("val_check_interval")

    trainer = Trainer(
        log_every_n_steps=cfg.scalar_log_freq,
        logger=tb_logger,
        max_steps=cfg.max_steps,
        callbacks=[
            ModelCheckpoint(
                dirpath=join(checkpoint_dir, name),
                every_n_train_steps=400,
                save_top_k=2,
                monitor="test/cluster/mIoU",
                mode="max",
            )
        ],
        **gpu_args
    )
    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    prep_args()
    my_app()
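A standalone sketch (not from the commit) of what retouch_label does: every connected blob in the predicted map is overwritten with the majority true label inside it. The snippet mirrors the logic on a tiny array; keepdims=True is used so it also works with recent SciPy, whereas the training script relies on the older default return shape of scipy.stats.mode:

import numpy as np
from skimage import measure
from scipy.stats import mode as statsmode

pred_label = np.array([[1, 1, 0],
                       [1, 1, 0],
                       [0, 0, 0]])
true_label = np.array([[2, 2, 5],
                       [2, 3, 5],
                       [5, 5, 5]])

retouched = pred_label + 0
blobs = measure.label(retouched)          # connected components of the prediction
for idx in np.unique(blobs):
    # majority vote of the true labels inside this blob
    retouched[blobs == idx] = statsmode(true_label[blobs == idx], keepdims=True)[0][0]
print(retouched)
# [[2 2 5]
#  [2 2 5]
#  [5 5 5]]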
biomap/unet.py
ADDED
@@ -0,0 +1,80 @@
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torch.nn as nn
from collections import defaultdict
import torchvision
import torch.nn.functional as F
from torch.utils.data.sampler import Sampler


class Block(nn.Module):
    def __init__(self, in_ch, out_ch, padding='same'):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=padding)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=padding)

    def forward(self, x):
        return self.conv2(self.relu(self.conv1(x)))


class Encoder(nn.Module):
    def __init__(self, chs=(3, 32, 64, 128, 256)):
        super().__init__()
        self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        ftrs = []
        for block in self.enc_blocks:
            x = block(x)
            ftrs.append(x)
            x = self.pool(x)
        return ftrs


class Decoder(nn.Module):
    def __init__(self, chs=(256, 128, 64, 32), aux_ch=70):
        super().__init__()
        upchs = tuple([chs[i] + aux_ch if i == 0 else chs[i] for i in range(len(chs))])
        self.chs = chs
        self.upchs = upchs
        self.upconvs = nn.ModuleList([nn.ConvTranspose2d(upchs[i], upchs[i+1], 2, 2) for i in range(len(upchs)-1)])
        self.dec_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])

    def forward(self, x, encoder_features):
        for i in range(len(self.chs) - 1):
            # pdb.set_trace()
            x = self.upconvs[i](x)
            enc_ftrs = self.crop(encoder_features[i], x)
            x = torch.cat([x, enc_ftrs], dim=1)
            x = self.dec_blocks[i](x)
        return x

    def crop(self, enc_ftrs, x):
        _, _, H, W = x.shape
        enc_ftrs = torchvision.transforms.CenterCrop([H, W])(enc_ftrs)
        return enc_ftrs


class AuxUNet(nn.Module):
    # UNet with auxiliary feature at the bottom
    def __init__(self, enc_chs=(3, 32, 64, 128, 256), dec_chs=(256, 128, 64, 32), aux_ch=70, num_class=7, retain_dim=False, out_sz=(224, 224)):
        super().__init__()
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs, aux_ch)
        self.head = nn.Conv2d(dec_chs[-1], num_class, 1)
        self.retain_dim = retain_dim
        self.out_sz = out_sz  # stored on the module so forward() can resize when retain_dim is set

    def forward(self, x, aux):
        # aux: auxiliary feature at the bottom
        enc_ftrs = self.encoder(x)
        enc_ftrs[-1] = torch.cat((enc_ftrs[-1], aux), 1)
        out = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
        out = self.head(out)
        if self.retain_dim:
            out = F.interpolate(out, self.out_sz)
        return out
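A quick shape check (not from the commit) for AuxUNet with its default channel widths, assuming biomap/ is on the Python path: a 224x224 RGB input and a 70-channel auxiliary map at the 28x28 bottleneck resolution.

import torch
from unet import AuxUNet

net = AuxUNet(enc_chs=(3, 32, 64, 128, 256), dec_chs=(256, 128, 64, 32),
              aux_ch=70, num_class=7)
x = torch.randn(1, 3, 224, 224)
aux = torch.randn(1, 70, 28, 28)   # same spatial size as the deepest encoder feature
out = net(x, aux)
print(out.shape)  # torch.Size([1, 7, 224, 224])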
biomap/utils.py
ADDED
@@ -0,0 +1,390 @@
import collections
import os
from os.path import join
import io

import matplotlib.pyplot as plt
import numpy as np
import torch.multiprocessing
import torch.nn as nn
import torch.nn.functional as F
import wget
from PIL import Image
from scipy.optimize import linear_sum_assignment
from torch._six import string_classes
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
from torchmetrics import Metric
from torchvision import models
from torchvision import transforms as T
from torch.utils.tensorboard.summary import hparams
import matplotlib as mpl

torch.multiprocessing.set_sharing_strategy("file_system")

colors = ("red", "palegreen", "green", "steelblue", "blue", "yellow", "lightgrey")
class_names = (
    "Buildings",
    "Cultivation",
    "Natural green",
    "Wetland",
    "Water",
    "Infrastructure",
    "Background",
)
bounds = list(np.arange(len(class_names) + 1) + 1)
cmap = mpl.colors.ListedColormap(colors)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
def compute_biodiv_score(image):
    """Compute the biodiversity score of an image

    Args:
        image (_type_): _description_

    Returns:
        biodiversity_score: the biodiversity score associated to the landscape of the image
    """
    pix = np.array(image.getdata())
    return np.mean(pix)
import cv2

def create_video(array_images, output_path="output.mp4"):
    height, width, layers = array_images[0].shape
    size = (width, height)

    fourcc = cv2.VideoWriter_fourcc(*'VP90')
    # write to the requested path (the original hard-coded 'output.mp4' here)
    out = cv2.VideoWriter(output_path, fourcc, 2, size)

    for i in range(len(array_images)):
        out.write(array_images[i])
    out.release()
    return out
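A minimal usage sketch for create_video, assuming the frames are same-shaped uint8 arrays in OpenCV's BGR channel order; the frame contents below are synthetic placeholders:

import numpy as np

# three synthetic 128x128 BGR frames; any list of same-shaped uint8 arrays works
frames = [np.full((128, 128, 3), i * 80, dtype=np.uint8) for i in range(3)]
create_video(frames, output_path="output.mp4")  # writes a 2-fps clip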
def transform_to_pil(outputs, alpha=0.3):
    """Turn an output into a PIL

    Args:
        outputs (_type_): _description_
        alpha (float, optional): _description_. Defaults to 0.3.

    Returns:
        _type_: _description_
    """
    # Transform img with torch
    img = torch.moveaxis(prep_for_plot(outputs["img"][0]), -1, 0)
    img = T.ToPILImage()(img)
    # Transform label by saving it then opening it
    label = outputs["linear_preds"][0].numpy()
    # image_label = Image.fromarray(label, mode="P")
    plt.imsave("output/label.png", label, cmap=cmap)
    image_label = Image.open("output/label.png")
    # Overlay labels with img with alpha
    background = img.convert("RGBA")
    overlay = image_label.convert("RGBA")
    labeled_img = Image.blend(background, overlay, alpha)
    labeled_img = labeled_img.convert("RGB")
    return img, image_label, labeled_img
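A sketch of the dictionary transform_to_pil expects, with random tensors standing in for real model outputs; the output/ directory must already exist because the label image is round-tripped through output/label.png:

import os
import torch

os.makedirs("output", exist_ok=True)
outputs = {
    "img": torch.rand(1, 3, 320, 320),                   # normalized image batch (B, C, H, W)
    "linear_preds": torch.randint(0, 7, (1, 320, 320)),  # per-pixel class indices
}
img, label, overlay = transform_to_pil(outputs, alpha=0.3)
overlay.save("output/labeled_img.png")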
def prep_for_plot(img, rescale=True, resize=None):
    if resize is not None:
        img = F.interpolate(img.unsqueeze(0), resize, mode="bilinear")
    else:
        img = img.unsqueeze(0)

    plot_img = unnorm(img).squeeze(0).cpu().permute(1, 2, 0)
    if rescale:
        plot_img = (plot_img - plot_img.min()) / (plot_img.max() - plot_img.min())
    return plot_img


def add_plot(writer, name, step):
    buf = io.BytesIO()
    plt.savefig(buf, format='jpeg', dpi=100)
    buf.seek(0)
    image = Image.open(buf)
    image = T.ToTensor()(image)
    writer.add_image(name, image, step)
    plt.clf()
    plt.close()


@torch.jit.script
def shuffle(x):
    return x[torch.randperm(x.shape[0])]


def add_hparams_fixed(writer, hparam_dict, metric_dict, global_step):
    exp, ssi, sei = hparams(hparam_dict, metric_dict)
    writer.file_writer.add_summary(exp)
    writer.file_writer.add_summary(ssi)
    writer.file_writer.add_summary(sei)
    for k, v in metric_dict.items():
        writer.add_scalar(k, v, global_step)


@torch.jit.script
def resize(classes: torch.Tensor, size: int):
    return F.interpolate(classes, (size, size), mode="bilinear", align_corners=False)


def one_hot_feats(labels, n_classes):
    return F.one_hot(labels, n_classes).permute(0, 3, 1, 2).to(torch.float32)
def load_model(model_type, data_dir):
    if model_type == "robust_resnet50":
        model = models.resnet50(pretrained=False)
        model_file = join(data_dir, 'imagenet_l2_3_0.pt')
        if not os.path.exists(model_file):
            wget.download("http://6.869.csail.mit.edu/fa19/psets19/pset6/imagenet_l2_3_0.pt",
                          model_file)
        model_weights = torch.load(model_file)
        model_weights_modified = {name.split('model.')[1]: value for name, value in model_weights['model'].items() if
                                  'model' in name}
        model.load_state_dict(model_weights_modified)
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "densecl":
        model = models.resnet50(pretrained=False)
        model_file = join(data_dir, 'densecl_r50_coco_1600ep.pth')
        if not os.path.exists(model_file):
            wget.download("https://cloudstor.aarnet.edu.au/plus/s/3GapXiWuVAzdKwJ/download",
                          model_file)
        model_weights = torch.load(model_file)
        # model_weights_modified = {name.split('model.')[1]: value for name, value in model_weights['model'].items() if
        #                           'model' in name}
        model.load_state_dict(model_weights['state_dict'], strict=False)
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "resnet50":
        model = models.resnet50(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "mocov2":
        model = models.resnet50(pretrained=False)
        model_file = join(data_dir, 'moco_v2_800ep_pretrain.pth.tar')
        if not os.path.exists(model_file):
            wget.download("https://dl.fbaipublicfiles.com/moco/moco_checkpoints/"
                          "moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar", model_file)
        checkpoint = torch.load(model_file)
        # rename moco pre-trained keys
        state_dict = checkpoint['state_dict']
        for k in list(state_dict.keys()):
            # retain only encoder_q up to before the embedding layer
            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                # remove prefix
                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
            # delete renamed or unused k
            del state_dict[k]
        msg = model.load_state_dict(state_dict, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "densenet121":
        model = models.densenet121(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))])
    elif model_type == "vgg11":
        model = models.vgg11(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))])
    else:
        raise ValueError("No model: {} found".format(model_type))

    model.eval()
    model.cuda()
    return model
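A usage sketch for load_model; the checkpoint-based variants download their weights into data_dir on first use, while the call below uses the torchvision-pretrained resnet50 and assumes a CUDA device is available:

import torch

backbone = load_model("resnet50", data_dir="checkpoints")
with torch.no_grad():
    feats = backbone(torch.rand(1, 3, 224, 224).cuda())
print(feats.shape)  # torch.Size([1, 2048, 1, 1]) for the ResNet-50 variants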
class UnNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        image2 = torch.clone(image)
        for t, m, s in zip(image2, self.mean, self.std):
            t.mul_(s).add_(m)
        return image2


normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])


class ToTargetTensor(object):
    def __call__(self, target):
        return torch.as_tensor(np.array(target), dtype=torch.int64).unsqueeze(0)
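A quick round-trip check showing that unnorm inverts normalize channel-wise on a single (C, H, W) image, which is what prep_for_plot relies on before display:

import torch

x = torch.rand(3, 64, 64)                    # an RGB image in [0, 1]
x_norm = normalize(x)                        # ImageNet-style standardization
x_back = unnorm(x_norm)                      # multiply by std, add mean, per channel
print(torch.allclose(x, x_back, atol=1e-5))  # True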
def prep_args():
    import sys

    old_args = sys.argv
    new_args = [old_args.pop(0)]
    while len(old_args) > 0:
        arg = old_args.pop(0)
        if len(arg.split("=")) == 2:
            new_args.append(arg)
        elif arg.startswith("--"):
            new_args.append(arg[2:] + "=" + old_args.pop(0))
        else:
            raise ValueError("Unexpected arg style {}".format(arg))
    sys.argv = new_args


def get_transform(res, is_label, crop_type):
    if crop_type == "center":
        cropper = T.CenterCrop(res)
    elif crop_type == "random":
        cropper = T.RandomCrop(res)
    elif crop_type is None:
        cropper = T.Lambda(lambda x: x)
        res = (res, res)
    else:
        raise ValueError("Unknown Cropper {}".format(crop_type))
    if is_label:
        return T.Compose([T.Resize(res, Image.NEAREST),
                          cropper,
                          ToTargetTensor()])
    else:
        return T.Compose([T.Resize(res, Image.NEAREST),
                          cropper,
                          T.ToTensor(),
                          normalize])


def _remove_axes(ax):
    ax.xaxis.set_major_formatter(plt.NullFormatter())
    ax.yaxis.set_major_formatter(plt.NullFormatter())
    ax.set_xticks([])
    ax.set_yticks([])


def remove_axes(axes):
    if len(axes.shape) == 2:
        for ax1 in axes:
            for ax in ax1:
                _remove_axes(ax)
    else:
        for ax in axes:
            _remove_axes(ax)
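A sketch of building paired transforms with get_transform so image and label stay pixel-aligned; the PIL images below are blank placeholders:

from PIL import Image

img_transform = get_transform(res=320, is_label=False, crop_type="center")
label_transform = get_transform(res=320, is_label=True, crop_type="center")

img = Image.new("RGB", (512, 512))   # placeholder photo
mask = Image.new("L", (512, 512))    # placeholder label mask
x = img_transform(img)               # float tensor, shape (3, 320, 320), normalized
y = label_transform(mask)            # int64 tensor, shape (1, 320, 320)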
class UnsupervisedMetrics(Metric):
    def __init__(self, prefix: str, n_classes: int, extra_clusters: int, compute_hungarian: bool,
                 dist_sync_on_step=True):
        # call `self.add_state` for every internal state that is needed for the metrics computations
        # dist_reduce_fx indicates the function that should be used to reduce
        # state from multiple processes
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.n_classes = n_classes
        self.extra_clusters = extra_clusters
        self.compute_hungarian = compute_hungarian
        self.prefix = prefix
        self.add_state("stats",
                       default=torch.zeros(n_classes + self.extra_clusters, n_classes, dtype=torch.int64),
                       dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        with torch.no_grad():
            actual = target.reshape(-1)
            preds = preds.reshape(-1)
            mask = (actual >= 0) & (actual < self.n_classes) & (preds >= 0) & (preds < self.n_classes)
            actual = actual[mask]
            preds = preds[mask]
            self.stats += torch.bincount(
                (self.n_classes + self.extra_clusters) * actual + preds,
                minlength=self.n_classes * (self.n_classes + self.extra_clusters)) \
                .reshape(self.n_classes, self.n_classes + self.extra_clusters).t().to(self.stats.device)

    def map_clusters(self, clusters):
        if self.extra_clusters == 0:
            return torch.tensor(self.assignments[1])[clusters]
        else:
            missing = sorted(list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0])))
            cluster_to_class = self.assignments[1]
            for missing_entry in missing:
                if missing_entry == cluster_to_class.shape[0]:
                    cluster_to_class = np.append(cluster_to_class, -1)
                else:
                    cluster_to_class = np.insert(cluster_to_class, missing_entry + 1, -1)
            cluster_to_class = torch.tensor(cluster_to_class)
            return cluster_to_class[clusters]

    def compute(self):
        if self.compute_hungarian:
            self.assignments = linear_sum_assignment(self.stats.detach().cpu(), maximize=True)
            # print(self.assignments)
            if self.extra_clusters == 0:
                self.histogram = self.stats[np.argsort(self.assignments[1]), :]
            if self.extra_clusters > 0:
                self.assignments_t = linear_sum_assignment(self.stats.detach().cpu().t(), maximize=True)
                histogram = self.stats[self.assignments_t[1], :]
                missing = list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0]))
                new_row = self.stats[missing, :].sum(0, keepdim=True)
                histogram = torch.cat([histogram, new_row], axis=0)
                new_col = torch.zeros(self.n_classes + 1, 1, device=histogram.device)
                self.histogram = torch.cat([histogram, new_col], axis=1)
        else:
            self.assignments = (torch.arange(self.n_classes).unsqueeze(1),
                                torch.arange(self.n_classes).unsqueeze(1))
            self.histogram = self.stats

        tp = torch.diag(self.histogram)
        fp = torch.sum(self.histogram, dim=0) - tp
        fn = torch.sum(self.histogram, dim=1) - tp

        iou = tp / (tp + fp + fn)
        prc = tp / (tp + fn)
        opc = torch.sum(tp) / torch.sum(self.histogram)

        metric_dict = {self.prefix + "mIoU": iou[~torch.isnan(iou)].mean().item(),
                       self.prefix + "Accuracy": opc.item()}
        return {k: 100 * v for k, v in metric_dict.items()}
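A small usage sketch for UnsupervisedMetrics with no extra clusters and Hungarian matching enabled; the random maps below stand in for predicted cluster ids and ground-truth labels:

import torch

metric = UnsupervisedMetrics(prefix="test/", n_classes=7, extra_clusters=0, compute_hungarian=True)
preds = torch.randint(0, 7, (2, 320, 320))    # predicted cluster ids per pixel
target = torch.randint(0, 7, (2, 320, 320))   # ground-truth class ids per pixel
metric.update(preds, target)
print(metric.compute())  # e.g. {'test/mIoU': ..., 'test/Accuracy': ...}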
def flexible_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        try:
            return torch.stack(batch, 0, out=out)
        except RuntimeError:
            return batch
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return flexible_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: flexible_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(flexible_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [flexible_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
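A sketch of plugging flexible_collate into a DataLoader; fields whose tensors share a shape are stacked as default_collate would, while fields with mismatched shapes fall back to a plain list instead of raising:

import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        # the images share a shape, the metadata tensors do not
        return {"img": torch.rand(3, 32, 32), "meta": torch.rand(idx + 1)}

loader = DataLoader(ToyDataset(), batch_size=4, collate_fn=flexible_collate)
batch = next(iter(loader))
print(batch["img"].shape)   # torch.Size([4, 3, 32, 32])
print(type(batch["meta"]))  # list, because the per-item sizes differ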
biomap/utils_gee.py
ADDED
@@ -0,0 +1,157 @@
import io
import requests
import ee
import numpy as np
import matplotlib.pyplot as plt

# Initialize
service_account = '[email protected]'
credentials = ee.ServiceAccountCredentials(service_account, '.private-key.json')
ee.Initialize(credentials)

# delete clouds
def maskS2clouds(image):
    qa = image.select('QA60')

    # Bits 10 and 11 are clouds and cirrus, respectively.
    cloudBitMask = 1 << 10
    cirrusBitMask = 1 << 11

    # Both flags should be set to zero, indicating clear conditions.
    # Combine the two tests with the Earth Engine `And` (the original used the
    # Python `and` operator, which does not apply element-wise to ee.Images).
    mask = qa.bitwiseAnd(cloudBitMask).eq(0).And(qa.bitwiseAnd(cirrusBitMask).eq(0))

    return image.updateMask(mask).divide(10000)
# find ee_img
def extract_ee_img(location, start_date, end_date, width=0.01, len=0.01):
    """Extract the earth engine image

    Args:
        location (list[float]):
        start_date (str): the start date for finding an image
        end_date (str): the end date for finding an image
        width (float, optional): _description_. Defaults to 0.01.
        len (float, optional): _description_. Defaults to 0.01.

    Returns:
        _type_: _description_
    """
    # define the polygon
    polygone = [[[float(location[0]) - 0.01, float(location[1]) + 0.01],
                 [float(location[0]) - 0.01, float(location[1]) - 0.01],
                 [float(location[0]) + 0.01, float(location[1]) - 0.01],
                 [float(location[0]) + 0.01, float(location[1]) + 0.01],
                 ]]

    # define the ee geometry
    geometry = ee.Geometry.Polygon(polygone, None, False)

    # extract the dataset
    dataset = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')\
        .filterDate(start_date, end_date)\
        .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 1))\
        .map(maskS2clouds)
    return dataset.mean(), geometry
# Get URL
def get_url(ee_img, geometry, scale=5):
    """Get the url of a dataset and a geometry

    Args:
        ee_img (ee.ImageCollection): meta data on the image
        geometry (ee.Geometry.Polygon): geometry of the desired landscape
        scale (int, optional): _description_. Defaults to 5.

    Returns:
        str: the url to use to ask the server
    """
    region = geometry

    # collectionList = ee_img.toList(ee_img.size())
    # collectionSize = collectionList.size().getInfo()
    # for i in xrange(collectionSize):
    #     ee.batch.Export.image.toDrive(
    #         image = ee.Image(collectionList.get(i)).clip(rectangle),
    #         fileNamePrefix = 'foo' + str(i + 1),
    #         dimensions = '128x128').start()

    url = ee_img.getDownloadURL({
        # 'min': 0.0,
        # 'max': 0.3,
        'bands': ['B4', 'B3', 'B2'],
        'region': region,
        'scale': scale,
        'format': 'NPY'
    })

    return url
def extract_np_from_url(url):
    """extract a numpy array based on a url

    Args:
        url (str): _description_

    Returns:
        numpyarray: response from earth engine as numpy
    """
    # get the response from url
    response = requests.get(url)

    # transform it into numpy
    data = np.load(io.BytesIO(response.content))

    # transform numpy of tuples to 3D numpy
    temp1 = []

    for x in data:
        temp2 = []
        for y in x:
            temp2.append([z for z in y])
        temp1.append(temp2)

    data = np.array(temp1)

    return data
# Top-level function
def extract_img(location, start_date, end_date, width=0.01, len=0.01, scale=5):
    """Extract an image of the landscape at the selected longitude and latitude with the selected width and length

    Args:
        location (list[float]): [latitude of the center of the landscape, longitude of the center of the landscape]
        start_date (str): the start date
        end_date (str): _description_
        width (float, optional): _description_. Defaults to 0.01.
        len (float, optional): _description_. Defaults to 0.01.
        scale (int, optional): _description_. Defaults to 5.

    Returns:
        img: image as numpy array
    """
    # pass the arguments in the order declared by extract_ee_img
    # (the original call passed width where start_date was expected)
    ee_img, geometry = extract_ee_img(location, start_date, end_date, width, len)
    url = get_url(ee_img, geometry, scale)
    img = extract_np_from_url(url)

    return img
# transform img from numpy to a displayable array
def transform_ee_img(img, min=0, max=0.3):
    """Rescale an Earth Engine reflectance array for display

    Args:
        img (numpy array): the original image as a numpy array
        min (int, optional): _description_. Defaults to 0.
        max (float, optional): _description_. Defaults to 0.3.

    Returns:
        img_test: the rescaled image as a uint8 numpy array (also shown with plt.imshow)
    """
    img_test = img
    img_test = np.minimum(img_test * 255 / max, np.ones(img.shape) * 255)
    img_test = np.uint8((np.rint(img_test)).astype(int))
    plt.imshow(img_test)
    return img_test
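An end-to-end sketch tying these helpers together: fetch a small Sentinel-2 RGB patch around a point and rescale it for display. The coordinates and dates are placeholders, and the calls only succeed once ee.Initialize has been given valid service-account credentials:

location = [2.35, 48.85]                       # placeholder center coordinates
arr = extract_img(location, "2022-06-01", "2022-08-31", scale=5)
rgb = transform_ee_img(arr, max=0.3)           # reflectance rescaled to uint8
print(rgb.shape, rgb.dtype)                    # (H, W, 3) uint8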
poetry.lock
ADDED
@@ -0,0 +1,1625 @@
| 1 |
+
[[package]]
|
| 2 |
+
name = "absl-py"
|
| 3 |
+
version = "1.4.0"
|
| 4 |
+
description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py."
|
| 5 |
+
category = "main"
|
| 6 |
+
optional = false
|
| 7 |
+
python-versions = ">=3.6"
|
| 8 |
+
|
| 9 |
+
[[package]]
|
| 10 |
+
name = "aiofiles"
|
| 11 |
+
version = "23.1.0"
|
| 12 |
+
description = "File support for asyncio."
|
| 13 |
+
category = "main"
|
| 14 |
+
optional = false
|
| 15 |
+
python-versions = ">=3.7,<4.0"
|
| 16 |
+
|
| 17 |
+
[[package]]
|
| 18 |
+
name = "aiohttp"
|
| 19 |
+
version = "3.8.4"
|
| 20 |
+
description = "Async http client/server framework (asyncio)"
|
| 21 |
+
category = "main"
|
| 22 |
+
optional = false
|
| 23 |
+
python-versions = ">=3.6"
|
| 24 |
+
|
| 25 |
+
[package.dependencies]
|
| 26 |
+
aiosignal = ">=1.1.2"
|
| 27 |
+
async-timeout = ">=4.0.0a3,<5.0"
|
| 28 |
+
attrs = ">=17.3.0"
|
| 29 |
+
charset-normalizer = ">=2.0,<4.0"
|
| 30 |
+
frozenlist = ">=1.1.1"
|
| 31 |
+
multidict = ">=4.5,<7.0"
|
| 32 |
+
yarl = ">=1.0,<2.0"
|
| 33 |
+
|
| 34 |
+
[package.extras]
|
| 35 |
+
speedups = ["aiodns", "brotli", "cchardet"]
|
| 36 |
+
|
| 37 |
+
[[package]]
|
| 38 |
+
name = "aiosignal"
|
| 39 |
+
version = "1.3.1"
|
| 40 |
+
description = "aiosignal: a list of registered asynchronous callbacks"
|
| 41 |
+
category = "main"
|
| 42 |
+
optional = false
|
| 43 |
+
python-versions = ">=3.7"
|
| 44 |
+
|
| 45 |
+
[package.dependencies]
|
| 46 |
+
frozenlist = ">=1.1.0"
|
| 47 |
+
|
| 48 |
+
[[package]]
|
| 49 |
+
name = "altair"
|
| 50 |
+
version = "4.2.2"
|
| 51 |
+
description = "Altair: A declarative statistical visualization library for Python."
|
| 52 |
+
category = "main"
|
| 53 |
+
optional = false
|
| 54 |
+
python-versions = ">=3.7"
|
| 55 |
+
|
| 56 |
+
[package.dependencies]
|
| 57 |
+
entrypoints = "*"
|
| 58 |
+
jinja2 = "*"
|
| 59 |
+
jsonschema = ">=3.0"
|
| 60 |
+
numpy = "*"
|
| 61 |
+
pandas = ">=0.18"
|
| 62 |
+
toolz = "*"
|
| 63 |
+
|
| 64 |
+
[package.extras]
|
| 65 |
+
dev = ["black", "docutils", "ipython", "flake8", "pytest", "sphinx", "mistune (<2.0.0)", "m2r", "vega-datasets", "recommonmark"]
|
| 66 |
+
|
| 67 |
+
[[package]]
|
| 68 |
+
name = "antlr4-python3-runtime"
|
| 69 |
+
version = "4.9.3"
|
| 70 |
+
description = "ANTLR 4.9.3 runtime for Python 3.7"
|
| 71 |
+
category = "main"
|
| 72 |
+
optional = false
|
| 73 |
+
python-versions = "*"
|
| 74 |
+
|
| 75 |
+
[[package]]
|
| 76 |
+
name = "anyio"
|
| 77 |
+
version = "3.6.2"
|
| 78 |
+
description = "High level compatibility layer for multiple asynchronous event loop implementations"
|
| 79 |
+
category = "main"
|
| 80 |
+
optional = false
|
| 81 |
+
python-versions = ">=3.6.2"
|
| 82 |
+
|
| 83 |
+
[package.dependencies]
|
| 84 |
+
idna = ">=2.8"
|
| 85 |
+
sniffio = ">=1.1"
|
| 86 |
+
|
| 87 |
+
[package.extras]
|
| 88 |
+
doc = ["packaging", "sphinx-rtd-theme", "sphinx-autodoc-typehints (>=1.2.0)"]
|
| 89 |
+
test = ["coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "contextlib2", "uvloop (<0.15)", "mock (>=4)", "uvloop (>=0.15)"]
|
| 90 |
+
trio = ["trio (>=0.16,<0.22)"]
|
| 91 |
+
|
| 92 |
+
[[package]]
|
| 93 |
+
name = "async-timeout"
|
| 94 |
+
version = "4.0.2"
|
| 95 |
+
description = "Timeout context manager for asyncio programs"
|
| 96 |
+
category = "main"
|
| 97 |
+
optional = false
|
| 98 |
+
python-versions = ">=3.6"
|
| 99 |
+
|
| 100 |
+
[[package]]
|
| 101 |
+
name = "attrs"
|
| 102 |
+
version = "19.3.0"
|
| 103 |
+
description = "Classes Without Boilerplate"
|
| 104 |
+
category = "main"
|
| 105 |
+
optional = false
|
| 106 |
+
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
|
| 107 |
+
|
| 108 |
+
[package.extras]
|
| 109 |
+
azure-pipelines = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "pytest-azurepipelines"]
|
| 110 |
+
dev = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "sphinx", "pre-commit"]
|
| 111 |
+
docs = ["sphinx", "zope.interface"]
|
| 112 |
+
tests = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface"]
|
| 113 |
+
|
| 114 |
+
[[package]]
|
| 115 |
+
name = "cachetools"
|
| 116 |
+
version = "5.3.0"
|
| 117 |
+
description = "Extensible memoizing collections and decorators"
|
| 118 |
+
category = "main"
|
| 119 |
+
optional = false
|
| 120 |
+
python-versions = "~=3.7"
|
| 121 |
+
|
| 122 |
+
[[package]]
|
| 123 |
+
name = "certifi"
|
| 124 |
+
version = "2022.12.7"
|
| 125 |
+
description = "Python package for providing Mozilla's CA Bundle."
|
| 126 |
+
category = "main"
|
| 127 |
+
optional = false
|
| 128 |
+
python-versions = ">=3.6"
|
| 129 |
+
|
| 130 |
+
[[package]]
|
| 131 |
+
name = "charset-normalizer"
|
| 132 |
+
version = "3.1.0"
|
| 133 |
+
description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
|
| 134 |
+
category = "main"
|
| 135 |
+
optional = false
|
| 136 |
+
python-versions = ">=3.7.0"
|
| 137 |
+
|
| 138 |
+
[[package]]
|
| 139 |
+
name = "click"
|
| 140 |
+
version = "8.1.3"
|
| 141 |
+
description = "Composable command line interface toolkit"
|
| 142 |
+
category = "main"
|
| 143 |
+
optional = false
|
| 144 |
+
python-versions = ">=3.7"
|
| 145 |
+
|
| 146 |
+
[package.dependencies]
|
| 147 |
+
colorama = {version = "*", markers = "platform_system == \"Windows\""}
|
| 148 |
+
|
| 149 |
+
[[package]]
|
| 150 |
+
name = "colorama"
|
| 151 |
+
version = "0.4.6"
|
| 152 |
+
description = "Cross-platform colored terminal text."
|
| 153 |
+
category = "main"
|
| 154 |
+
optional = false
|
| 155 |
+
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
|
| 156 |
+
|
| 157 |
+
[[package]]
|
| 158 |
+
name = "contourpy"
|
| 159 |
+
version = "1.0.7"
|
| 160 |
+
description = "Python library for calculating contours of 2D quadrilateral grids"
|
| 161 |
+
category = "main"
|
| 162 |
+
optional = false
|
| 163 |
+
python-versions = ">=3.8"
|
| 164 |
+
|
| 165 |
+
[package.dependencies]
|
| 166 |
+
numpy = ">=1.16"
|
| 167 |
+
|
| 168 |
+
[package.extras]
|
| 169 |
+
bokeh = ["bokeh", "chromedriver", "selenium"]
|
| 170 |
+
docs = ["furo", "sphinx-copybutton"]
|
| 171 |
+
mypy = ["contourpy", "docutils-stubs", "mypy (==0.991)", "types-pillow"]
|
| 172 |
+
test = ["matplotlib", "pillow", "pytest"]
|
| 173 |
+
test-no-images = ["pytest"]
|
| 174 |
+
|
| 175 |
+
[[package]]
|
| 176 |
+
name = "cycler"
|
| 177 |
+
version = "0.11.0"
|
| 178 |
+
description = "Composable style cycles"
|
| 179 |
+
category = "main"
|
| 180 |
+
optional = false
|
| 181 |
+
python-versions = ">=3.6"
|
| 182 |
+
|
| 183 |
+
[[package]]
|
| 184 |
+
name = "earthengine-api"
|
| 185 |
+
version = "0.1.338"
|
| 186 |
+
description = "Earth Engine Python API"
|
| 187 |
+
category = "main"
|
| 188 |
+
optional = false
|
| 189 |
+
python-versions = "*"
|
| 190 |
+
|
| 191 |
+
[package.dependencies]
|
| 192 |
+
google-api-python-client = ">=1.12.1"
|
| 193 |
+
google-auth = ">=1.4.1"
|
| 194 |
+
google-auth-httplib2 = ">=0.0.3"
|
| 195 |
+
google-cloud-storage = "*"
|
| 196 |
+
httplib2 = ">=0.9.2,<1dev"
|
| 197 |
+
requests = "*"
|
| 198 |
+
|
| 199 |
+
[[package]]
|
| 200 |
+
name = "ee-extra"
|
| 201 |
+
version = "0.0.15"
|
| 202 |
+
description = "A ninja Python package behind rgee, rgeeExtra and eemont."
|
| 203 |
+
category = "main"
|
| 204 |
+
optional = false
|
| 205 |
+
python-versions = "*"
|
| 206 |
+
|
| 207 |
+
[package.dependencies]
|
| 208 |
+
earthengine-api = "*"
|
| 209 |
+
|
| 210 |
+
[[package]]
|
| 211 |
+
name = "entrypoints"
|
| 212 |
+
version = "0.4"
|
| 213 |
+
description = "Discover and load entry points from installed packages."
|
| 214 |
+
category = "main"
|
| 215 |
+
optional = false
|
| 216 |
+
python-versions = ">=3.6"
|
| 217 |
+
|
| 218 |
+
[[package]]
|
| 219 |
+
name = "fastapi"
|
| 220 |
+
version = "0.95.1"
|
| 221 |
+
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
|
| 222 |
+
category = "main"
|
| 223 |
+
optional = false
|
| 224 |
+
python-versions = ">=3.7"
|
| 225 |
+
|
| 226 |
+
[package.dependencies]
|
| 227 |
+
pydantic = ">=1.6.2,<1.7 || >1.7,<1.7.1 || >1.7.1,<1.7.2 || >1.7.2,<1.7.3 || >1.7.3,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0"
|
| 228 |
+
starlette = ">=0.26.1,<0.27.0"
|
| 229 |
+
|
| 230 |
+
[package.extras]
|
| 231 |
+
all = ["email-validator (>=1.1.1)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
|
| 232 |
+
dev = ["pre-commit (>=2.17.0,<3.0.0)", "ruff (==0.0.138)", "uvicorn[standard] (>=0.12.0,<0.21.0)"]
|
| 233 |
+
doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs-markdownextradata-plugin (>=0.1.7,<0.3.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "pyyaml (>=5.3.1,<7.0.0)", "typer-cli (>=0.0.13,<0.0.14)", "typer[all] (>=0.6.1,<0.8.0)"]
|
| 234 |
+
test = ["anyio[trio] (>=3.2.1,<4.0.0)", "black (==23.1.0)", "coverage[toml] (>=6.5.0,<8.0)", "databases[sqlite] (>=0.3.2,<0.7.0)", "email-validator (>=1.1.1,<2.0.0)", "flask (>=1.1.2,<3.0.0)", "httpx (>=0.23.0,<0.24.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.982)", "orjson (>=3.2.1,<4.0.0)", "passlib[bcrypt] (>=1.7.2,<2.0.0)", "peewee (>=3.13.3,<4.0.0)", "pytest (>=7.1.3,<8.0.0)", "python-jose[cryptography] (>=3.3.0,<4.0.0)", "python-multipart (>=0.0.5,<0.0.7)", "pyyaml (>=5.3.1,<7.0.0)", "ruff (==0.0.138)", "sqlalchemy (>=1.3.18,<1.4.43)", "types-orjson (==3.6.2)", "types-ujson (==5.7.0.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0,<6.0.0)"]
|
| 235 |
+
|
| 236 |
+
[[package]]
|
| 237 |
+
name = "ffmpy"
|
| 238 |
+
version = "0.3.0"
|
| 239 |
+
description = "A simple Python wrapper for ffmpeg"
|
| 240 |
+
category = "main"
|
| 241 |
+
optional = false
|
| 242 |
+
python-versions = "*"
|
| 243 |
+
|
| 244 |
+
[[package]]
|
| 245 |
+
name = "filelock"
|
| 246 |
+
version = "3.11.0"
|
| 247 |
+
description = "A platform independent file lock."
|
| 248 |
+
category = "main"
|
| 249 |
+
optional = false
|
| 250 |
+
python-versions = ">=3.7"
|
| 251 |
+
|
| 252 |
+
[package.extras]
|
| 253 |
+
docs = ["furo (>=2023.3.27)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)", "sphinx (>=6.1.3)"]
|
| 254 |
+
testing = ["covdefaults (>=2.3)", "coverage (>=7.2.2)", "diff-cover (>=7.5)", "pytest-cov (>=4)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)", "pytest (>=7.2.2)"]
|
| 255 |
+
|
| 256 |
+
[[package]]
|
| 257 |
+
name = "fonttools"
|
| 258 |
+
version = "4.39.3"
|
| 259 |
+
description = "Tools to manipulate font files"
|
| 260 |
+
category = "main"
|
| 261 |
+
optional = false
|
| 262 |
+
python-versions = ">=3.8"
|
| 263 |
+
|
| 264 |
+
[package.extras]
|
| 265 |
+
all = ["fs (>=2.2.0,<3)", "lxml (>=4.0,<5)", "zopfli (>=0.1.4)", "lz4 (>=1.7.4.2)", "matplotlib", "sympy", "skia-pathops (>=0.5.0)", "uharfbuzz (>=0.23.0)", "brotlicffi (>=0.8.0)", "scipy", "brotli (>=1.0.1)", "munkres", "unicodedata2 (>=15.0.0)", "xattr"]
|
| 266 |
+
graphite = ["lz4 (>=1.7.4.2)"]
|
| 267 |
+
interpolatable = ["scipy", "munkres"]
|
| 268 |
+
lxml = ["lxml (>=4.0,<5)"]
|
| 269 |
+
pathops = ["skia-pathops (>=0.5.0)"]
|
| 270 |
+
plot = ["matplotlib"]
|
| 271 |
+
repacker = ["uharfbuzz (>=0.23.0)"]
|
| 272 |
+
symfont = ["sympy"]
|
| 273 |
+
type1 = ["xattr"]
|
| 274 |
+
ufo = ["fs (>=2.2.0,<3)"]
|
| 275 |
+
unicode = ["unicodedata2 (>=15.0.0)"]
|
| 276 |
+
woff = ["zopfli (>=0.1.4)", "brotlicffi (>=0.8.0)", "brotli (>=1.0.1)"]
|
| 277 |
+
|
| 278 |
+
[[package]]
|
| 279 |
+
name = "frozenlist"
|
| 280 |
+
version = "1.3.3"
|
| 281 |
+
description = "A list-like structure which implements collections.abc.MutableSequence"
|
| 282 |
+
category = "main"
|
| 283 |
+
optional = false
|
| 284 |
+
python-versions = ">=3.7"
|
| 285 |
+
|
| 286 |
+
[[package]]
|
| 287 |
+
name = "fsspec"
|
| 288 |
+
version = "2023.4.0"
|
| 289 |
+
description = "File-system specification"
|
| 290 |
+
category = "main"
|
| 291 |
+
optional = false
|
| 292 |
+
python-versions = ">=3.8"
|
| 293 |
+
|
| 294 |
+
[package.dependencies]
|
| 295 |
+
aiohttp = {version = "<4.0.0a0 || >4.0.0a0,<4.0.0a1 || >4.0.0a1", optional = true, markers = "extra == \"http\""}
|
| 296 |
+
requests = {version = "*", optional = true, markers = "extra == \"http\""}
|
| 297 |
+
|
| 298 |
+
[package.extras]
|
| 299 |
+
abfs = ["adlfs"]
|
| 300 |
+
adl = ["adlfs"]
|
| 301 |
+
arrow = ["pyarrow (>=1)"]
|
| 302 |
+
dask = ["dask", "distributed"]
|
| 303 |
+
devel = ["pytest", "pytest-cov"]
|
| 304 |
+
dropbox = ["dropboxdrivefs", "requests", "dropbox"]
|
| 305 |
+
full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"]
|
| 306 |
+
fuse = ["fusepy"]
|
| 307 |
+
gcs = ["gcsfs"]
|
| 308 |
+
git = ["pygit2"]
|
| 309 |
+
github = ["requests"]
|
| 310 |
+
gs = ["gcsfs"]
|
| 311 |
+
gui = ["panel"]
|
| 312 |
+
hdfs = ["pyarrow (>=1)"]
|
| 313 |
+
http = ["requests", "aiohttp (!=4.0.0a0,!=4.0.0a1)"]
|
| 314 |
+
libarchive = ["libarchive-c"]
|
| 315 |
+
oci = ["ocifs"]
|
| 316 |
+
s3 = ["s3fs"]
|
| 317 |
+
sftp = ["paramiko"]
|
| 318 |
+
smb = ["smbprotocol"]
|
| 319 |
+
ssh = ["paramiko"]
|
| 320 |
+
tqdm = ["tqdm"]
|
| 321 |
+
|
| 322 |
+
[[package]]
|
| 323 |
+
name = "google-api-core"
|
| 324 |
+
version = "2.11.0"
|
| 325 |
+
description = "Google API client core library"
|
| 326 |
+
category = "main"
|
| 327 |
+
optional = false
|
| 328 |
+
python-versions = ">=3.7"
|
| 329 |
+
|
| 330 |
+
[package.dependencies]
|
| 331 |
+
google-auth = ">=2.14.1,<3.0dev"
|
| 332 |
+
googleapis-common-protos = ">=1.56.2,<2.0dev"
|
| 333 |
+
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
|
| 334 |
+
requests = ">=2.18.0,<3.0.0dev"
|
| 335 |
+
|
| 336 |
+
[package.extras]
|
| 337 |
+
grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio-status (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.49.1,<2.0dev)"]
|
| 338 |
+
grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0dev)"]
|
| 339 |
+
grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0dev)"]
|
| 340 |
+
|
| 341 |
+
[[package]]
|
| 342 |
+
name = "google-api-python-client"
|
| 343 |
+
version = "2.85.0"
|
| 344 |
+
description = "Google API Client Library for Python"
|
| 345 |
+
category = "main"
|
| 346 |
+
optional = false
|
| 347 |
+
python-versions = ">=3.7"
|
| 348 |
+
|
| 349 |
+
[package.dependencies]
|
| 350 |
+
google-api-core = ">=1.31.5,<2.0.0 || >2.3.0,<3.0.0dev"
|
| 351 |
+
google-auth = ">=1.19.0,<3.0.0dev"
|
| 352 |
+
google-auth-httplib2 = ">=0.1.0"
|
| 353 |
+
httplib2 = ">=0.15.0,<1dev"
|
| 354 |
+
uritemplate = ">=3.0.1,<5"
|
| 355 |
+
|
| 356 |
+
[[package]]
|
| 357 |
+
name = "google-auth"
|
| 358 |
+
version = "2.17.3"
|
| 359 |
+
description = "Google Authentication Library"
|
| 360 |
+
category = "main"
|
| 361 |
+
optional = false
|
| 362 |
+
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*"
|
| 363 |
+
|
| 364 |
+
[package.dependencies]
|
| 365 |
+
cachetools = ">=2.0.0,<6.0"
|
| 366 |
+
pyasn1-modules = ">=0.2.1"
|
| 367 |
+
rsa = {version = ">=3.1.4,<5", markers = "python_version >= \"3.6\""}
|
| 368 |
+
six = ">=1.9.0"
|
| 369 |
+
|
| 370 |
+
[package.extras]
|
| 371 |
+
aiohttp = ["requests (>=2.20.0,<3.0.0dev)", "aiohttp (>=3.6.2,<4.0.0dev)"]
|
| 372 |
+
enterprise_cert = ["cryptography (==36.0.2)", "pyopenssl (==22.0.0)"]
|
| 373 |
+
pyopenssl = ["pyopenssl (>=20.0.0)", "cryptography (>=38.0.3)"]
|
| 374 |
+
reauth = ["pyu2f (>=0.1.5)"]
|
| 375 |
+
requests = ["requests (>=2.20.0,<3.0.0dev)"]
|
| 376 |
+
|
| 377 |
+
[[package]]
|
| 378 |
+
name = "google-auth-httplib2"
|
| 379 |
+
version = "0.1.0"
|
| 380 |
+
description = "Google Authentication Library: httplib2 transport"
|
| 381 |
+
category = "main"
|
| 382 |
+
optional = false
|
| 383 |
+
python-versions = "*"
|
| 384 |
+
|
| 385 |
+
[package.dependencies]
|
| 386 |
+
google-auth = "*"
|
| 387 |
+
httplib2 = ">=0.15.0"
|
| 388 |
+
six = "*"
|
| 389 |
+
|
| 390 |
+
[[package]]
|
| 391 |
+
name = "google-auth-oauthlib"
|
| 392 |
+
version = "0.4.6"
|
| 393 |
+
description = "Google Authentication Library"
|
| 394 |
+
category = "main"
|
| 395 |
+
optional = false
|
| 396 |
+
python-versions = ">=3.6"
|
| 397 |
+
|
| 398 |
+
[package.dependencies]
|
| 399 |
+
google-auth = ">=1.0.0"
|
| 400 |
+
requests-oauthlib = ">=0.7.0"
|
| 401 |
+
|
| 402 |
+
[package.extras]
|
| 403 |
+
tool = ["click (>=6.0.0)"]
|
| 404 |
+
|
| 405 |
+
[[package]]
|
| 406 |
+
name = "google-cloud-core"
|
| 407 |
+
version = "2.3.2"
|
| 408 |
+
description = "Google Cloud API client core library"
|
| 409 |
+
category = "main"
|
| 410 |
+
optional = false
|
| 411 |
+
python-versions = ">=3.7"
|
| 412 |
+
|
| 413 |
+
[package.dependencies]
|
| 414 |
+
google-api-core = ">=1.31.6,<2.0.0 || >2.3.0,<3.0.0dev"
|
| 415 |
+
google-auth = ">=1.25.0,<3.0dev"
|
| 416 |
+
|
| 417 |
+
[package.extras]
|
| 418 |
+
grpc = ["grpcio (>=1.38.0,<2.0dev)"]
|
| 419 |
+
|
| 420 |
+
[[package]]
|
| 421 |
+
name = "google-cloud-storage"
|
| 422 |
+
version = "2.8.0"
|
| 423 |
+
description = "Google Cloud Storage API client library"
|
| 424 |
+
category = "main"
|
| 425 |
+
optional = false
|
| 426 |
+
python-versions = ">=3.7"
|
| 427 |
+
|
| 428 |
+
[package.dependencies]
|
| 429 |
+
google-api-core = ">=1.31.5,<2.0.0 || >2.3.0,<3.0.0dev"
|
| 430 |
+
google-auth = ">=1.25.0,<3.0dev"
|
| 431 |
+
google-cloud-core = ">=2.3.0,<3.0dev"
|
| 432 |
+
google-resumable-media = ">=2.3.2"
|
| 433 |
+
requests = ">=2.18.0,<3.0.0dev"
|
| 434 |
+
|
| 435 |
+
[package.extras]
|
| 436 |
+
protobuf = ["protobuf (<5.0.0dev)"]
|
| 437 |
+
|
| 438 |
+
[[package]]
|
| 439 |
+
name = "google-crc32c"
|
| 440 |
+
version = "1.5.0"
|
| 441 |
+
description = "A python wrapper of the C library 'Google CRC32C'"
|
| 442 |
+
category = "main"
|
| 443 |
+
optional = false
|
| 444 |
+
python-versions = ">=3.7"
|
| 445 |
+
|
| 446 |
+
[package.extras]
|
| 447 |
+
testing = ["pytest"]
|
| 448 |
+
|
| 449 |
+
[[package]]
|
| 450 |
+
name = "google-resumable-media"
|
| 451 |
+
version = "2.4.1"
|
| 452 |
+
description = "Utilities for Google Media Downloads and Resumable Uploads"
|
| 453 |
+
category = "main"
|
| 454 |
+
optional = false
|
| 455 |
+
python-versions = ">= 3.7"
|
| 456 |
+
|
| 457 |
+
[package.dependencies]
|
| 458 |
+
google-crc32c = ">=1.0,<2.0dev"
|
| 459 |
+
|
| 460 |
+
[package.extras]
|
| 461 |
+
aiohttp = ["aiohttp (>=3.6.2,<4.0.0dev)"]
|
| 462 |
+
requests = ["requests (>=2.18.0,<3.0.0dev)"]
|
| 463 |
+
|
| 464 |
+
[[package]]
|
| 465 |
+
name = "googleapis-common-protos"
|
| 466 |
+
version = "1.59.0"
|
| 467 |
+
description = "Common protobufs used in Google APIs"
|
| 468 |
+
category = "main"
|
| 469 |
+
optional = false
|
| 470 |
+
python-versions = ">=3.7"
|
| 471 |
+
|
| 472 |
+
[package.dependencies]
|
| 473 |
+
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
|
| 474 |
+
|
| 475 |
+
[package.extras]
|
| 476 |
+
grpc = ["grpcio (>=1.44.0,<2.0.0dev)"]
|
| 477 |
+
|
| 478 |
+
[[package]]
|
| 479 |
+
name = "gradio"
|
| 480 |
+
version = "3.27.0"
|
| 481 |
+
description = "Python library for easily interacting with trained machine learning models"
|
| 482 |
+
category = "main"
|
| 483 |
+
optional = false
|
| 484 |
+
python-versions = ">=3.7"
|
| 485 |
+
|
| 486 |
+
[package.dependencies]
|
| 487 |
+
aiofiles = "*"
|
| 488 |
+
aiohttp = "*"
|
| 489 |
+
altair = ">=4.2.0"
|
| 490 |
+
fastapi = "*"
|
| 491 |
+
ffmpy = "*"
|
| 492 |
+
gradio-client = ">=0.1.3"
|
| 493 |
+
httpx = "*"
|
| 494 |
+
huggingface-hub = ">=0.13.0"
|
| 495 |
+
jinja2 = "*"
|
| 496 |
+
markdown-it-py = {version = ">=2.0.0", extras = ["linkify"]}
|
| 497 |
+
markupsafe = "*"
|
| 498 |
+
matplotlib = "*"
|
| 499 |
+
mdit-py-plugins = "<=0.3.3"
|
| 500 |
+
numpy = "*"
|
| 501 |
+
orjson = "*"
|
pandas = "*"
pillow = "*"
pydantic = "*"
pydub = "*"
python-multipart = "*"
pyyaml = "*"
requests = "*"
semantic-version = "*"
typing-extensions = "*"
uvicorn = "*"
websockets = ">=10.0"

[[package]]
name = "gradio-client"
version = "0.1.3"
description = "Python library for easily interacting with trained machine learning models"
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
fsspec = "*"
httpx = "*"
huggingface-hub = ">=0.13.0"
packaging = "*"
requests = "*"
typing-extensions = "*"
websockets = "*"

[[package]]
name = "grpcio"
version = "1.53.0"
description = "HTTP/2-based RPC framework"
category = "main"
optional = false
python-versions = ">=3.7"

[package.extras]
protobuf = ["grpcio-tools (>=1.53.0)"]

[[package]]
name = "h11"
version = "0.14.0"
description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1"
category = "main"
optional = false
python-versions = ">=3.7"

[[package]]
name = "httpcore"
version = "0.17.0"
description = "A minimal low-level HTTP client."
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
anyio = ">=3.0,<5.0"
certifi = "*"
h11 = ">=0.13,<0.15"
sniffio = ">=1.0.0,<2.0.0"

[package.extras]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (>=1.0.0,<2.0.0)"]

[[package]]
name = "httplib2"
version = "0.22.0"
description = "A comprehensive HTTP client library."
category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"

[package.dependencies]
pyparsing = {version = ">=2.4.2,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.0.2 || >3.0.2,<3.0.3 || >3.0.3,<4", markers = "python_version > \"3.0\""}

[[package]]
name = "httpx"
version = "0.24.0"
description = "The next generation HTTP client."
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
certifi = "*"
httpcore = ">=0.15.0,<0.18.0"
idna = "*"
sniffio = "*"

[package.extras]
brotli = ["brotli", "brotlicffi"]
cli = ["click (>=8.0.0,<9.0.0)", "pygments (>=2.0.0,<3.0.0)", "rich (>=10,<14)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (>=1.0.0,<2.0.0)"]

[[package]]
name = "huggingface-hub"
version = "0.13.4"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
category = "main"
optional = false
python-versions = ">=3.7.0"

[package.dependencies]
filelock = "*"
packaging = ">=20.9"
pyyaml = ">=5.1"
requests = "*"
tqdm = ">=4.42.1"
typing-extensions = ">=3.7.4.3"

[package.extras]
all = ["InquirerPy (==0.3.4)", "jedi", "jinja2", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile", "pillow", "black (>=23.1,<24.0)", "ruff (>=0.0.241)", "mypy (==0.982)", "types-pyyaml", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
cli = ["InquirerPy (==0.3.4)"]
dev = ["InquirerPy (==0.3.4)", "jedi", "jinja2", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile", "pillow", "black (>=23.1,<24.0)", "ruff (>=0.0.241)", "mypy (==0.982)", "types-pyyaml", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
fastai = ["toml", "fastai (>=2.4)", "fastcore (>=1.3.27)"]
quality = ["black (>=23.1,<24.0)", "ruff (>=0.0.241)", "mypy (==0.982)"]
tensorflow = ["tensorflow", "pydot", "graphviz"]
testing = ["InquirerPy (==0.3.4)", "jedi", "jinja2", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile", "pillow"]
torch = ["torch"]
typing = ["types-pyyaml", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]

[[package]]
name = "hydra-client"
version = "0.5.1"
description = "Client library for ORY Hydra (OAuth2 and OpenID Connect provider)"
category = "main"
optional = false
python-versions = ">=3.7,<4.0"

[package.dependencies]
attrs = ">=19.2,<20.0"
python-dateutil = ">=2.8,<3.0"
requests = ">=2.21,<3.0"
requests-oauthlib = ">=1.0,<2.0"

[[package]]
name = "hydra-core"
version = "1.3.1"
description = "A framework for elegantly configuring complex applications"
category = "main"
optional = false
python-versions = "*"

[package.dependencies]
antlr4-python3-runtime = ">=4.9.0,<4.10.0"
omegaconf = ">=2.2,<2.4"
packaging = "*"

[[package]]
name = "idna"
version = "3.4"
description = "Internationalized Domain Names in Applications (IDNA)"
category = "main"
optional = false
python-versions = ">=3.5"

[[package]]
name = "jinja2"
version = "3.1.2"
description = "A very fast and expressive template engine."
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
MarkupSafe = ">=2.0"

[package.extras]
i18n = ["Babel (>=2.7)"]

[[package]]
name = "jsonschema"
version = "4.17.3"
description = "An implementation of JSON Schema validation for Python"
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
attrs = ">=17.4.0"
pyrsistent = ">=0.14.0,<0.17.0 || >0.17.0,<0.17.1 || >0.17.1,<0.17.2 || >0.17.2"

[package.extras]
format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"]
format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=1.11)"]

[[package]]
name = "kiwisolver"
version = "1.4.4"
description = "A fast implementation of the Cassowary constraint solver"
category = "main"
optional = false
python-versions = ">=3.7"

[[package]]
name = "lightning-utilities"
version = "0.8.0"
description = "PyTorch Lightning Sample project."
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
packaging = ">=17.1"
typing-extensions = "*"

[package.extras]
cli = ["fire"]
docs = ["sphinx (>=4.0,<5.0)"]
test = ["coverage (==6.5.0)"]
typing = ["mypy (>=1.0.0)"]

[[package]]
name = "linkify-it-py"
version = "2.0.0"
description = "Links recognition library with FULL unicode support."
category = "main"
optional = false
python-versions = ">=3.6"

[package.dependencies]
uc-micro-py = "*"

[package.extras]
benchmark = ["pytest", "pytest-benchmark"]
dev = ["pre-commit", "isort", "flake8", "black"]
doc = ["sphinx", "sphinx-book-theme", "myst-parser"]
test = ["coverage", "pytest", "pytest-cov"]

[[package]]
name = "markdown"
version = "3.4.3"
description = "Python implementation of John Gruber's Markdown."
category = "main"
optional = false
python-versions = ">=3.7"

[package.extras]
testing = ["coverage", "pyyaml"]

[[package]]
name = "markdown-it-py"
version = "2.2.0"
description = "Python port of markdown-it. Markdown parsing, done right!"
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
linkify-it-py = {version = ">=1,<3", optional = true, markers = "extra == \"linkify\""}
mdurl = ">=0.1,<1.0"

[package.extras]
benchmarking = ["psutil", "pytest", "pytest-benchmark"]
code_style = ["pre-commit (>=3.0,<4.0)"]
compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
linkify = ["linkify-it-py (>=1,<3)"]
plugins = ["mdit-py-plugins"]
profiling = ["gprof2dot"]
rtd = ["attrs", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx-book-theme"]
testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]

[[package]]
name = "markupsafe"
version = "2.1.2"
description = "Safely add untrusted strings to HTML/XML markup."
category = "main"
optional = false
python-versions = ">=3.7"

[[package]]
name = "matplotlib"
version = "3.7.1"
description = "Python plotting package"
category = "main"
optional = false
python-versions = ">=3.8"

[package.dependencies]
contourpy = ">=1.0.1"
cycler = ">=0.10"
fonttools = ">=4.22.0"
kiwisolver = ">=1.0.1"
numpy = ">=1.20"
packaging = ">=20.0"
pillow = ">=6.2.0"
pyparsing = ">=2.3.1"
python-dateutil = ">=2.7"
setuptools_scm = ">=7"

[[package]]
name = "mdit-py-plugins"
version = "0.3.3"
description = "Collection of plugins for markdown-it-py"
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
markdown-it-py = ">=1.0.0,<3.0.0"

[package.extras]
code_style = ["pre-commit"]
rtd = ["attrs", "myst-parser (>=0.16.1,<0.17.0)", "sphinx-book-theme (>=0.1.0,<0.2.0)"]
testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]

[[package]]
name = "mdurl"
version = "0.1.2"
description = "Markdown URL utilities"
category = "main"
optional = false
python-versions = ">=3.7"

[[package]]
name = "multidict"
version = "6.0.4"
description = "multidict implementation"
category = "main"
optional = false
python-versions = ">=3.7"

[[package]]
name = "numpy"
version = "1.24.2"
description = "Fundamental package for array computing in Python"
category = "main"
optional = false
python-versions = ">=3.8"

[[package]]
name = "nvidia-cublas-cu11"
version = "11.10.3.66"
description = "CUBLAS native runtime libraries"
category = "main"
optional = false
python-versions = ">=3"

[[package]]
name = "nvidia-cuda-nvrtc-cu11"
version = "11.7.99"
description = "NVRTC native runtime libraries"
category = "main"
optional = false
python-versions = ">=3"

[[package]]
name = "nvidia-cuda-runtime-cu11"
version = "11.7.99"
description = "CUDA Runtime native Libraries"
category = "main"
optional = false
python-versions = ">=3"

[[package]]
name = "nvidia-cudnn-cu11"
version = "8.5.0.96"
description = "cuDNN runtime libraries"
category = "main"
optional = false
python-versions = ">=3"

[[package]]
name = "oauthlib"
version = "3.2.2"
description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic"
category = "main"
optional = false
python-versions = ">=3.6"

[package.extras]
rsa = ["cryptography (>=3.0.0)"]
signals = ["blinker (>=1.4.0)"]
signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]

[[package]]
name = "omegaconf"
version = "2.3.0"
description = "A flexible configuration library"
category = "main"
optional = false
python-versions = ">=3.6"

[package.dependencies]
antlr4-python3-runtime = ">=4.9.0,<4.10.0"
PyYAML = ">=5.1.0"

[[package]]
name = "opencv-python"
version = "4.7.0.72"
description = "Wrapper package for OpenCV python bindings."
category = "main"
optional = false
python-versions = ">=3.6"

[package.dependencies]
numpy = [
    {version = ">=1.21.2", markers = "python_version >= \"3.10\""},
    {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\""},
    {version = ">=1.22.0", markers = "python_version >= \"3.11\""},
    {version = ">=1.19.3", markers = "python_version >= \"3.6\" and platform_system == \"Linux\" and platform_machine == \"aarch64\" or python_version >= \"3.9\""},
    {version = ">=1.17.0", markers = "python_version >= \"3.7\""},
    {version = ">=1.17.3", markers = "python_version >= \"3.8\""},
]

[[package]]
name = "orjson"
version = "3.8.10"
description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy"
category = "main"
optional = false
python-versions = ">= 3.7"

[[package]]
name = "packaging"
version = "23.1"
description = "Core utilities for Python packages"
category = "main"
optional = false
python-versions = ">=3.7"

[[package]]
name = "pandas"
version = "2.0.0"
description = "Powerful data structures for data analysis, time series, and statistics"
category = "main"
optional = false
python-versions = ">=3.8"

[package.dependencies]
numpy = [
    {version = ">=1.21.0", markers = "python_version >= \"3.10\""},
    {version = ">=1.23.2", markers = "python_version >= \"3.11\""},
]
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
tzdata = ">=2022.1"

[package.extras]
all = ["beautifulsoup4 (>=4.9.3)", "bottleneck (>=1.3.2)", "brotlipy (>=0.7.0)", "fastparquet (>=0.6.3)", "fsspec (>=2021.07.0)", "gcsfs (>=2021.07.0)", "html5lib (>=1.1)", "hypothesis (>=6.34.2)", "jinja2 (>=3.0.0)", "lxml (>=4.6.3)", "matplotlib (>=3.6.1)", "numba (>=0.53.1)", "numexpr (>=2.7.3)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pandas-gbq (>=0.15.0)", "psycopg2 (>=2.8.6)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "PyQt5 (>=5.15.1)", "pyreadstat (>=1.1.2)", "pytest (>=7.0.0)", "pytest-xdist (>=2.2.0)", "pytest-asyncio (>=0.17.0)", "python-snappy (>=0.6.0)", "pyxlsb (>=1.0.8)", "qtpy (>=2.2.0)", "scipy (>=1.7.1)", "s3fs (>=2021.08.0)", "SQLAlchemy (>=1.4.16)", "tables (>=3.6.1)", "tabulate (>=0.8.9)", "xarray (>=0.21.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)", "zstandard (>=0.15.2)"]
aws = ["s3fs (>=2021.08.0)"]
clipboard = ["PyQt5 (>=5.15.1)", "qtpy (>=2.2.0)"]
compression = ["brotlipy (>=0.7.0)", "python-snappy (>=0.6.0)", "zstandard (>=0.15.2)"]
computation = ["scipy (>=1.7.1)", "xarray (>=0.21.0)"]
excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pyxlsb (>=1.0.8)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)"]
feather = ["pyarrow (>=7.0.0)"]
fss = ["fsspec (>=2021.07.0)"]
gcp = ["gcsfs (>=2021.07.0)", "pandas-gbq (>=0.15.0)"]
hdf5 = ["tables (>=3.6.1)"]
html = ["beautifulsoup4 (>=4.9.3)", "html5lib (>=1.1)", "lxml (>=4.6.3)"]
mysql = ["SQLAlchemy (>=1.4.16)", "pymysql (>=1.0.2)"]
output_formatting = ["jinja2 (>=3.0.0)", "tabulate (>=0.8.9)"]
parquet = ["pyarrow (>=7.0.0)"]
performance = ["bottleneck (>=1.3.2)", "numba (>=0.53.1)", "numexpr (>=2.7.1)"]
plot = ["matplotlib (>=3.6.1)"]
postgresql = ["SQLAlchemy (>=1.4.16)", "psycopg2 (>=2.8.6)"]
spss = ["pyreadstat (>=1.1.2)"]
sql-other = ["SQLAlchemy (>=1.4.16)"]
test = ["hypothesis (>=6.34.2)", "pytest (>=7.0.0)", "pytest-xdist (>=2.2.0)", "pytest-asyncio (>=0.17.0)"]
xml = ["lxml (>=4.6.3)"]

[[package]]
name = "pillow"
version = "8.4.0"
description = "Python Imaging Library (Fork)"
category = "main"
optional = false
python-versions = ">=3.6"

[[package]]
name = "plotly"
version = "5.14.1"
description = "An open-source, interactive data visualization library for Python"
category = "main"
optional = false
python-versions = ">=3.6"

[package.dependencies]
packaging = "*"
tenacity = ">=6.2.0"

[[package]]
name = "protobuf"
version = "3.20.3"
description = "Protocol Buffers"
category = "main"
optional = false
python-versions = ">=3.7"

[[package]]
name = "pyasn1"
version = "0.4.8"
description = "ASN.1 types and codecs"
category = "main"
optional = false
python-versions = "*"

[[package]]
name = "pyasn1-modules"
version = "0.2.8"
description = "A collection of ASN.1-based protocols modules."
category = "main"
optional = false
python-versions = "*"

[package.dependencies]
pyasn1 = ">=0.4.6,<0.5.0"

[[package]]
name = "pydantic"
version = "1.10.7"
description = "Data validation and settings management using python type hints"
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
typing-extensions = ">=4.2.0"

[package.extras]
dotenv = ["python-dotenv (>=0.10.4)"]
email = ["email-validator (>=1.0.3)"]

[[package]]
name = "pydub"
version = "0.25.1"
description = "Manipulate audio with an simple and easy high level interface"
category = "main"
optional = false
python-versions = "*"

[[package]]
name = "pyparsing"
version = "3.0.9"
description = "pyparsing module - Classes and methods to define and execute parsing grammars"
category = "main"
optional = false
python-versions = ">=3.6.8"

[package.extras]
diagrams = ["railroad-diagrams", "jinja2"]

[[package]]
name = "pyrsistent"
version = "0.19.3"
description = "Persistent/Functional/Immutable data structures"
category = "main"
optional = false
python-versions = ">=3.7"

[[package]]
name = "python-dateutil"
version = "2.8.2"
description = "Extensions to the standard Python datetime module"
category = "main"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"

[package.dependencies]
six = ">=1.5"

[[package]]
name = "python-multipart"
version = "0.0.6"
description = "A streaming multipart parser for Python"
category = "main"
optional = false
python-versions = ">=3.7"

[package.extras]
dev = ["atomicwrites (==1.2.1)", "attrs (==19.2.0)", "coverage (==6.5.0)", "hatch", "invoke (==1.7.3)", "more-itertools (==4.3.0)", "pbr (==4.3.0)", "pluggy (==1.0.0)", "py (==1.11.0)", "pytest-cov (==4.0.0)", "pytest-timeout (==2.1.0)", "pytest (==7.2.0)", "pyyaml (==5.1)"]

[[package]]
name = "pytorch-lightning"
version = "1.9.0"
description = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate."
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
fsspec = {version = ">2021.06.0", extras = ["http"]}
lightning-utilities = ">=0.4.2"
numpy = ">=1.17.2"
packaging = ">=17.1"
PyYAML = ">=5.4"
torch = ">=1.10.0"
torchmetrics = ">=0.7.0"
tqdm = ">=4.57.0"
typing-extensions = ">=4.0.0"

[package.extras]
all = ["matplotlib (>3.1)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.18.0)", "rich (>=10.14.0,!=10.15.0.a)", "tensorboardX (>=2.2)", "fairscale (>=0.4.5)", "deepspeed (>=0.6.0)", "horovod (>=0.21.2,!=0.24.0)", "torchvision (>=0.11.1)", "gym[classic_control] (>=0.17.0)", "ipython[all] (<8.7.1)", "hivemind (==1.1.5)"]
deepspeed = ["deepspeed (>=0.6.0)"]
dev = ["matplotlib (>3.1)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.18.0)", "rich (>=10.14.0,!=10.15.0.a)", "tensorboardX (>=2.2)", "fairscale (>=0.4.5)", "deepspeed (>=0.6.0)", "horovod (>=0.21.2,!=0.24.0)", "torchvision (>=0.11.1)", "gym[classic_control] (>=0.17.0)", "ipython[all] (<8.7.1)", "coverage (==6.5.0)", "codecov (==2.1.12)", "pytest (==7.2.0)", "pytest-cov (==4.0.0)", "pytest-forked (==1.4.0)", "pytest-rerunfailures (==10.3)", "pre-commit (==2.20.0)", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime (<1.14.0)", "psutil (<5.9.5)", "pandas (>1.0)", "fastapi (<0.87.0)", "uvicorn (<0.19.1)", "tensorboard (>=2.9.1)", "protobuf (<=3.20.1)", "hivemind (==1.1.5)"]
examples = ["torchvision (>=0.11.1)", "gym[classic_control] (>=0.17.0)", "ipython[all] (<8.7.1)"]
extra = ["matplotlib (>3.1)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.18.0)", "rich (>=10.14.0,!=10.15.0.a)", "tensorboardX (>=2.2)"]
fairscale = ["fairscale (>=0.4.5)"]
hivemind = ["hivemind (==1.1.5)"]
horovod = ["horovod (>=0.21.2,!=0.24.0)"]
strategies = ["fairscale (>=0.4.5)", "deepspeed (>=0.6.0)", "horovod (>=0.21.2,!=0.24.0)", "hivemind (==1.1.5)"]
test = ["coverage (==6.5.0)", "codecov (==2.1.12)", "pytest (==7.2.0)", "pytest-cov (==4.0.0)", "pytest-forked (==1.4.0)", "pytest-rerunfailures (==10.3)", "pre-commit (==2.20.0)", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime (<1.14.0)", "psutil (<5.9.5)", "pandas (>1.0)", "fastapi (<0.87.0)", "uvicorn (<0.19.1)", "tensorboard (>=2.9.1)", "protobuf (<=3.20.1)"]

[[package]]
name = "pytz"
version = "2023.3"
description = "World timezone definitions, modern and historical"
category = "main"
optional = false
python-versions = "*"

[[package]]
name = "pyyaml"
version = "6.0"
description = "YAML parser and emitter for Python"
category = "main"
optional = false
python-versions = ">=3.6"

[[package]]
name = "requests"
version = "2.28.2"
description = "Python HTTP for Humans."
category = "main"
optional = false
python-versions = ">=3.7, <4"

[package.dependencies]
certifi = ">=2017.4.17"
charset-normalizer = ">=2,<4"
idna = ">=2.5,<4"
urllib3 = ">=1.21.1,<1.27"

[package.extras]
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use_chardet_on_py3 = ["chardet (>=3.0.2,<6)"]

[[package]]
name = "requests-oauthlib"
version = "1.3.1"
description = "OAuthlib authentication support for Requests."
category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"

[package.dependencies]
oauthlib = ">=3.0.0"
requests = ">=2.0.0"

[package.extras]
rsa = ["oauthlib[signedtoken] (>=3.0.0)"]

[[package]]
name = "rsa"
version = "4.9"
description = "Pure-Python RSA implementation"
category = "main"
optional = false
python-versions = ">=3.6,<4"

[package.dependencies]
pyasn1 = ">=0.1.3"

[[package]]
name = "scipy"
version = "1.10.1"
description = "Fundamental algorithms for scientific computing in Python"
category = "main"
optional = false
python-versions = "<3.12,>=3.8"

[package.dependencies]
numpy = ">=1.19.5,<1.27.0"

[package.extras]
test = ["pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "asv", "mpmath", "gmpy2", "threadpoolctl", "scikit-umfpack", "pooch"]
doc = ["sphinx (!=4.1.0)", "pydata-sphinx-theme (==0.9.0)", "sphinx-design (>=0.2.0)", "matplotlib (>2)", "numpydoc"]
dev = ["mypy", "typing-extensions", "pycodestyle", "flake8", "rich-click", "click", "doit (>=0.36.0)", "pydevtool"]

[[package]]
name = "seaborn"
version = "0.12.2"
description = "Statistical data visualization"
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
matplotlib = ">=3.1,<3.6.1 || >3.6.1"
numpy = ">=1.17,<1.24.0 || >1.24.0"
pandas = ">=0.25"

[package.extras]
dev = ["pytest", "pytest-cov", "pytest-xdist", "flake8", "mypy", "pandas-stubs", "pre-commit", "flit"]
docs = ["numpydoc", "nbconvert", "ipykernel", "sphinx-copybutton", "sphinx-issues", "sphinx-design", "pyyaml", "pydata_sphinx_theme (==0.10.0rc2)"]
stats = ["scipy (>=1.3)", "statsmodels (>=0.10)"]

[[package]]
name = "semantic-version"
version = "2.10.0"
description = "A library implementing the 'SemVer' scheme."
category = "main"
optional = false
python-versions = ">=2.7"

[package.extras]
dev = ["Django (>=1.11)", "nose2", "tox", "check-manifest", "coverage", "flake8", "wheel", "zest.releaser", "readme-renderer (<25.0)", "colorama (<=0.4.1)"]
doc = ["sphinx", "sphinx-rtd-theme"]

[[package]]
name = "setuptools-scm"
version = "7.1.0"
description = "the blessed package to manage your versions by scm tags"
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
packaging = ">=20.0"
tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
typing-extensions = "*"

[package.extras]
test = ["pytest (>=6.2)", "virtualenv (>20)"]
toml = ["setuptools (>=42)"]

[[package]]
name = "six"
version = "1.16.0"
description = "Python 2 and 3 compatibility utilities"
category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"

[[package]]
name = "sniffio"
version = "1.3.0"
description = "Sniff out which async library your code is running under"
category = "main"
optional = false
python-versions = ">=3.7"

[[package]]
name = "starlette"
version = "0.26.1"
description = "The little ASGI library that shines."
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
anyio = ">=3.4.0,<5"

[package.extras]
full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyaml"]

[[package]]
name = "tenacity"
version = "8.2.2"
description = "Retry code until it succeeds"
category = "main"
optional = false
python-versions = ">=3.6"

[package.extras]
doc = ["reno", "sphinx", "tornado (>=4.5)"]

[[package]]
name = "tensorboard"
version = "2.11.2"
description = "TensorBoard lets you watch Tensors Flow"
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
absl-py = ">=0.4"
google-auth = ">=1.6.3,<3"
google-auth-oauthlib = ">=0.4.1,<0.5"
grpcio = ">=1.24.3"
markdown = ">=2.6.8"
numpy = ">=1.12.0"
protobuf = ">=3.9.2,<4"
requests = ">=2.21.0,<3"
tensorboard-data-server = ">=0.6.0,<0.7.0"
tensorboard-plugin-wit = ">=1.6.0"
werkzeug = ">=1.0.1"

[[package]]
name = "tensorboard-data-server"
version = "0.6.1"
description = "Fast data loading for TensorBoard"
category = "main"
optional = false
python-versions = ">=3.6"

[[package]]
name = "tensorboard-plugin-wit"
version = "1.8.1"
description = "What-If Tool TensorBoard plugin."
category = "main"
optional = false
python-versions = "*"

[[package]]
name = "tomli"
version = "2.0.1"
description = "A lil' TOML parser"
category = "main"
optional = false
python-versions = ">=3.7"

[[package]]
name = "toolz"
version = "0.12.0"
description = "List processing tools and functional utilities"
category = "main"
optional = false
python-versions = ">=3.5"

[[package]]
name = "torch"
version = "1.13.1"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
category = "main"
optional = false
python-versions = ">=3.7.0"

[package.dependencies]
nvidia-cublas-cu11 = {version = "11.10.3.66", markers = "platform_system == \"Linux\""}
nvidia-cuda-nvrtc-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\""}
nvidia-cuda-runtime-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\""}
nvidia-cudnn-cu11 = {version = "8.5.0.96", markers = "platform_system == \"Linux\""}
typing-extensions = "*"

[package.extras]
opt-einsum = ["opt-einsum (>=3.3)"]

[[package]]
name = "torchmetrics"
version = "0.11.0"
description = "PyTorch native Metrics"
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
numpy = ">=1.17.2"
packaging = "*"
torch = ">=1.8.1"

[package.extras]
all = ["pystoi", "torchvision (>=0.8)", "pycocotools", "scipy", "torch-fidelity", "torchvision", "lpips", "pytorch-lightning (>=1.5)", "transformers (>=4.10.0)", "regex (>=2021.9.24)", "nltk (>=3.6)", "tqdm (>=4.41.0)"]
audio = ["pystoi"]
detection = ["torchvision (>=0.8)", "pycocotools"]
docs = ["sphinx-autodoc-typehints (>=1.0)", "nbsphinx (>=0.8)", "docutils (>=0.16)", "sphinx-togglebutton (>=0.2)", "pandoc (>=1.0)", "myst-parser", "sphinx-paramlinks (>=0.5.1)", "sphinxcontrib-fulltoc (>=1.0)", "sphinxcontrib-mockautodoc", "sphinx-copybutton (>=0.3)", "sphinx (>=4.0,<5.0)"]
image = ["scipy", "torch-fidelity", "torchvision", "lpips"]
integrate = ["pytorch-lightning (>=1.5)"]
multimodal = ["transformers (>=4.10.0)"]
test = ["types-protobuf", "rouge-score (>=0.0.4)", "bert-score (==0.3.10)", "requests", "mir-eval (>=0.6)", "jiwer (>=2.3.0)", "scikit-learn (>1.0,<1.1.1)", "check-manifest", "types-tabulate", "pytest-timeout", "types-emoji", "pycocotools", "coverage (>5.2)", "pytest (>=6.0.0,<7.0.0)", "types-six", "kornia (>=0.6.7)", "phmdoctest (>=1.1.1)", "pandas", "pytest-cov (>2.10)", "cloudpickle (>=1.3)", "pre-commit (>=1.0)", "scipy", "psutil", "mypy (==0.982)", "types-requests", "pytest-rerunfailures (>=10.0)", "types-pyyaml", "types-setuptools", "sacrebleu (>=2.0.0)", "netcal", "pytorch-msssim (==0.2.1)", "transformers (>4.4.0)", "fast-bss-eval (>=0.1.0)", "fire", "scikit-image (>0.17.1)", "dython", "torch-complex", "pytest-doctestplus (>=0.9.0)", "huggingface-hub (<0.7)", "pypesq (>1.2)"]
text = ["regex (>=2021.9.24)", "nltk (>=3.6)", "tqdm (>=4.41.0)"]

[[package]]
name = "torchvision"
version = "0.14.1"
description = "image and video datasets and models for torch deep learning"
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
numpy = "*"
pillow = ">=5.3.0,<8.3.0 || >=8.4.0"
requests = "*"
torch = "1.13.1"
typing-extensions = "*"

[package.extras]
scipy = ["scipy"]

[[package]]
name = "tqdm"
version = "4.65.0"
description = "Fast, Extensible Progress Meter"
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
colorama = {version = "*", markers = "platform_system == \"Windows\""}

[package.extras]
dev = ["py-make (>=0.1.0)", "twine", "wheel"]
notebook = ["ipywidgets (>=6)"]
slack = ["slack-sdk"]
telegram = ["requests"]

[[package]]
name = "typing-extensions"
version = "4.5.0"
description = "Backported and Experimental Type Hints for Python 3.7+"
category = "main"
optional = false
python-versions = ">=3.7"

[[package]]
name = "tzdata"
version = "2023.3"
description = "Provider of IANA time zone data"
category = "main"
optional = false
python-versions = ">=2"

[[package]]
name = "uc-micro-py"
version = "1.0.1"
description = "Micro subset of unicode data files for linkify-it-py projects."
category = "main"
optional = false
python-versions = ">=3.6"

[package.extras]
test = ["coverage", "pytest", "pytest-cov"]

[[package]]
name = "uritemplate"
version = "4.1.1"
description = "Implementation of RFC 6570 URI Templates"
category = "main"
optional = false
python-versions = ">=3.6"

[[package]]
name = "urllib3"
version = "1.26.15"
description = "HTTP library with thread-safe connection pooling, file post, and more."
category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"

[package.extras]
brotli = ["brotlicffi (>=0.8.0)", "brotli (>=1.0.9)", "brotlipy (>=0.6.0)"]
secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "urllib3-secure-extra", "ipaddress"]
socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]

[[package]]
name = "uvicorn"
version = "0.21.1"
description = "The lightning-fast ASGI server."
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
click = ">=7.0"
h11 = ">=0.8"

[package.extras]
standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"]

[[package]]
name = "websockets"
version = "11.0.1"
description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)"
category = "main"
optional = false
python-versions = ">=3.7"

[[package]]
name = "werkzeug"
version = "2.2.3"
description = "The comprehensive WSGI web application library."
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
MarkupSafe = ">=2.1.1"

[package.extras]
watchdog = ["watchdog"]

[[package]]
name = "wget"
version = "3.2"
description = "pure python download utility"
category = "main"
optional = false
python-versions = "*"

[[package]]
name = "yarl"
version = "1.8.2"
description = "Yet another URL library"
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
idna = ">=2.0"
multidict = ">=4.0"

[metadata]
lock-version = "1.1"
python-versions = ">=3.10,<3.12"
content-hash = "17cec1f61fed3b070c0b744eeecc9dbaed1ea06d758238ac84f108545ab14a21"

[metadata.files]
absl-py = []
aiofiles = []
aiohttp = []
aiosignal = []
altair = []
antlr4-python3-runtime = []
anyio = []
async-timeout = []
attrs = []
cachetools = []
certifi = []
charset-normalizer = []
click = []
colorama = []
contourpy = []
cycler = []
earthengine-api = []
ee-extra = []
entrypoints = []
fastapi = []
ffmpy = []
filelock = []
fonttools = []
frozenlist = []
fsspec = []
google-api-core = []
google-api-python-client = []
google-auth = []
google-auth-httplib2 = []
google-auth-oauthlib = []
google-cloud-core = []
google-cloud-storage = []
google-crc32c = []
google-resumable-media = []
googleapis-common-protos = []
gradio = []
gradio-client = []
grpcio = []
h11 = []
httpcore = []
httplib2 = []
httpx = []
huggingface-hub = []
hydra-client = []
hydra-core = []
idna = []
jinja2 = []
jsonschema = []
kiwisolver = []
lightning-utilities = []
linkify-it-py = []
markdown = []
markdown-it-py = []
markupsafe = []
matplotlib = []
mdit-py-plugins = []
mdurl = []
multidict = []
numpy = []
nvidia-cublas-cu11 = []
nvidia-cuda-nvrtc-cu11 = []
nvidia-cuda-runtime-cu11 = []
nvidia-cudnn-cu11 = []
oauthlib = []
omegaconf = []
opencv-python = []
orjson = []
packaging = []
pandas = []
pillow = []
plotly = []
protobuf = []
pyasn1 = []
pyasn1-modules = []
pydantic = []
pydub = []
pyparsing = []
pyrsistent = []
python-dateutil = []
python-multipart = []
pytorch-lightning = []
pytz = []
pyyaml = []
requests = []
requests-oauthlib = []
rsa = []
scipy = []
seaborn = []
semantic-version = []
setuptools-scm = []
six = []
sniffio = []
starlette = []
tenacity = []
tensorboard = []
tensorboard-data-server = []
tensorboard-plugin-wit = []
tomli = []
toolz = []
torch = []
torchmetrics = []
torchvision = []
tqdm = []
typing-extensions = []
tzdata = []
uc-micro-py = []
uritemplate = []
urllib3 = []
uvicorn = []
websockets = []
werkzeug = []
wget = []
yarl = []
pyproject.toml
ADDED
@@ -0,0 +1,31 @@
[tool.poetry]
name = "cv_app"
version = "0.1.0"
description = ""
authors = ["Your Name <[email protected]>"]

[tool.poetry.dependencies]
python = ">=3.10,<3.12"
torch = "1.13.1"
tensorboard = "2.11.2"
pytorch-lightning = "1.9.0"
torchmetrics = "0.11.0"
Pillow = "8.4.0"
torchvision = "0.14.1"
matplotlib = "^3.7.1"
hydra-client = "0.5.1"
hydra-core = "1.3.1"
wget = "^3.2"
scipy = "^1.10.1"
seaborn = "^0.12.2"
earthengine-api = "0.1.338"
ee-extra = "0.0.15"
gradio = "^3.27.0"
opencv-python = "^4.7.0"
plotly = "^5.14.1"

[tool.poetry.dev-dependencies]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
requirements.txt
ADDED
@@ -0,0 +1,133 @@

aiohttp==3.8.3
aiosignal==1.3.1
antlr4-python3-runtime==4.9.3
appdirs==1.4.4
argh==0.26.2
async-timeout==4.0.2
atomicwrites==1.4.0
attrs==19.3.0
backports.weakref==1.0.post1
bkcharts==0.2
black==19.10b0
boto==2.49.0
bqplot==0.12.36
branca==0.6.0
brotlipy==0.7.0
cachetools==5.3.0
certifi==2021.10.8
click==8.0.3
colour==0.1.5
comtypes==1.1.10
cycler==0.10.0
cytoolz==0.11.0
daal4py==2021.3.0
dask==2021.10.0
earthengine-api==0.1.338
ee-extra==0.0.15
eerepr==0.0.4
entrypoints==0.3
et-xmlfile==1.1.0
export==0.2.0
ffmpeg-python==0.2.0
folium==0.14.0
fonttools==4.25.0
frozenlist==1.3.3
gdown==4.6.0
geeadd==0.5.6
geemap==0.19.6
geocoder==1.38.1
geojson==3.0.0
google-api-core==2.11.0
google-api-python-client==2.74.0
google-auth==2.16.0
google-auth-httplib2==0.1.0
google-auth-oauthlib==0.4.6
google-cloud-core==2.3.2
google-cloud-storage==2.7.0
google-crc32c==1.5.0
google-resumable-media==2.4.1
googleapis-common-protos==1.58.0
grpcio==1.51.1
httplib2==0.21.0
hydra-client==0.5.1
hydra-core==1.3.1
inflection==0.5.1
ipyevents==2.0.1
ipyfilechooser==0.6.0
ipyleaflet==0.17.2
ipytree==0.2.2
lightning-utilities==0.6.0.post0
llvmlite==0.37.0
locket==0.2.1
logzero==1.7.0
Markdown==3.4.1
mccabe==0.6.1
mkl-fft==1.3.1
mkl-service==2.4.0
mpmath==1.2.1
multidict==6.0.4
munkres==1.1.4
mypy-extensions==0.4.3
nltk==3.6.5
oauthlib==3.2.2
omegaconf==2.3.0
pathspec==0.7.0
patsy==0.5.2
pep8==1.7.1
Pillow==8.4.0
pkginfo==1.7.1
plotly==5.13.0
ply==3.11
protobuf==3.20.3
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycosat==0.6.3
PyCRS==1.0.2
pycurl==7.44.1
pyls-spyder==0.4.0
pyperclip==1.8.2
pyreadline==2.1
pyshp==2.3.1
pytest==6.2.4
python-box==6.1.0
python-lsp-jsonrpc==1.0.0
python-lsp-server==1.2.4
pytorch-lightning==1.9.0
pytz==2021.3
PyYAML==6.0
ratelim==0.1.6
requests-oauthlib==1.3.1
rsa==4.9
sankee==0.2.1
scikit-image==0.18.3
scooby==0.7.1
simplegeneric==0.8.1
Sphinx==4.2.0
statsmodels==0.12.2
tables==3.6.1
tenacity==8.1.0
tensorboard==2.11.2
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
terminado==0.9.4
torch==1.13.1
torchaudio==0.13.1
torchmetrics==0.11.0
torchvision==0.14.1
traittypes==0.2.1
typing_extensions==4.4.0
unicodecsv==0.14.1
uritemplate==4.1.1
urllib3==1.26.7
webencodings==0.5.1
wget==3.2
whitebox==2.2.0
whiteboxgui==2.2.0
win-unicode-console==0.5
wincertstore==0.2
xlwt==1.3.0
xyzservices==2022.9.0
yarl==1.8.2
zict==2.0.0
zope.event==4.5.0
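
Both dependency specifications above pin the same core versions (torch 1.13.1, pytorch-lightning 1.9.0, torchmetrics 0.11.0, torchvision 0.14.1). A minimal sketch for sanity-checking that an installed environment matches those pins; the package names and expected versions are copied from the files above, and it assumes it runs inside the environment built from them (standard-library importlib.metadata only, Python 3.8+):

from importlib.metadata import PackageNotFoundError, version

# Pins copied from pyproject.toml / requirements.txt above.
pins = {
    "torch": "1.13.1",
    "pytorch-lightning": "1.9.0",
    "torchmetrics": "0.11.0",
    "torchvision": "0.14.1",
}

for name, expected in pins.items():
    try:
        installed = version(name)  # version reported by the installed distribution
    except PackageNotFoundError:
        installed = None
    status = "ok" if installed == expected else "mismatch"
    print(f"{name}: expected {expected}, found {installed} ({status})")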