Spaces:
Runtime error
Runtime error
| from datetime import datetime | |
| import ee | |
| from func_timeout import func_set_timeout | |
| import pandas as pd | |
| from PIL import Image | |
| import requests | |
| import tempfile | |
| import io | |
| from tqdm import tqdm | |
| import functools | |
| import re # Used in an eval statement | |
| from typing import List | |
| from typing import Union | |
| from typing import Any | |
class DataLoader:
    """
    Main class for loading and exploring data from satellite images.

    The goal is to load an ImageCollection and to filter that collection according to needs, with methods like
    filter, filterDate, filterBounds, select. These work just like earth engine's methods with the same names.

    This class, just like earth engine, works with lazy loading and compute. This means that running filterBounds
    will not actually filter the image collection until required, e.g. when counting the images by accessing
    the .count property.
    However, it will only load the information it needs once, unless additional filtering is made.
    This works thanks to the signal_change decorator, which invalidates every cached result. If you develop a
    new filtering method for this class, you will need to decorate your method with @signal_change.

    In addition, if you develop a new method that needs to run getInfo to actually load data from
    Google Earth Engine, you will need to use _get_timeout_info(your object before getInfo). This will run
    getInfo with a timeout (10 seconds by default).
    It is important to use a timeout to avoid unexpected run times.

    Usage:
    >>> dl = DataLoader(satellite_name="COPERNICUS/S2_SR",
    ...                 start_date='2021-01-01',
    ...                 end_date='2021-01-15',
    ...                 bands=["TCI_R", "TCI_G", "TCI_B"],
    ...                 geographic_bounds=ee.Geometry.Point(*[5.238728194366604, 44.474864056855935]).buffer(500)
    ...                 )

    Get a pandas dataframe with all pixel values as a timeseries:
    >>> dl.getRegion(dl.bounds, 500)
    >>> dl.region.head(2)
    [Out]
                                           id  longitude   latitude                    time    B1    B2    B3    B4    B5    B6  ...  WVP  SCL  TCI_R  TCI_G  TCI_B  MSK_CLDPRB  MSK_SNWPRB  QA10  QA20  QA60
    0  20210102T104441_20210102T104435_T31TFK   5.234932  44.473344 2021-01-02 10:48:36.299  6297  5955  5768  5773  5965  5883  ...  393    8    255    255    255           0          95     0     0  1024
    1  20210104T103329_20210104T103331_T31TFK   5.234932  44.473344 2021-01-04 10:38:38.304  5547  5355  5184  5090  5254  5229  ...  314    9    255    255    255          29           9     0     0  1024

    >>> dl.date_range
    [Out]
    {'max': datetime.datetime(2021, 1, 14, 11, 38, 39, 208000),
     'min': datetime.datetime(2021, 1, 2, 11, 48, 36, 299000)}

    >>> dl.count
    [Out]
    6

    >>> dl.collection_info  # contains a html description of the dataset in "description"

    >>> dl.image_ids
    [Out]
    ['COPERNICUS/S2_SR/20210102T104441_20210102T104435_T31TFK',
     'COPERNICUS/S2_SR/20210104T103329_20210104T103331_T31TFK',
     'COPERNICUS/S2_SR/20210107T104329_20210107T104328_T31TFK',
     'COPERNICUS/S2_SR/20210109T103421_20210109T103431_T31TFK',
     'COPERNICUS/S2_SR/20210112T104411_20210112T104438_T31TFK',
     'COPERNICUS/S2_SR/20210114T103309_20210114T103305_T31TFK']

    # Download the image
    >>> img = dl.download_image(dl.image_ids[3])

    # Download all images as a list
    >>> imgs = dl.download_all_images(scale=1)
    """

    def __init__(self,
                 satellite_name: str,
                 bands: Union[List, str] = None,
                 start_date: str = None,
                 end_date: str = None,
                 geographic_bounds: "ee.Geometry" = None,
                 scale: int = 10,
                 crs: str = "EPSG:32630"
                 ):
        """
        Args:
            satellite_name: satellite to use. Examples: COPERNICUS/S2_SR, COPERNICUS/CORINE/V20/100m.
                See https://developers.google.com/earth-engine/datasets for the full list.
            bands: band name, or list of band names, to load.
            start_date: lowest possible date. Might be lower than the actual date of the first picture.
            end_date: latest possible date.
            geographic_bounds: region of interest.
            scale: default scale used when reprojecting images and when loading a region.
            crs: default coordinate reference system used when reprojecting images.

        Raises:
            ValueError: if only one of start_date / end_date is provided.
        """
        # Dates are optional, but a half-open range is not supported.
        if (start_date is None) != (end_date is None):
            raise ValueError("start_date and end_date must both be provided")

        self.satellite_name = satellite_name
        if isinstance(bands, str):
            bands = [bands]
        self.bands = list(bands) if bands is not None else []
        self.start_date = start_date
        self.end_date = end_date
        self.bounds = geographic_bounds
        self.scale = scale
        self.crs = crs

        # Lazily computed results; None means "not loaded yet".
        self._available_images = None
        self.image_list = None
        self._df_image_list = None
        self.image_collection_info = None
        self._date_range = None
        self._count = None
        self._describe = None

        # Start describing the query to google cloud (nothing is fetched yet).
        self.image_collection = None
        if satellite_name:
            self.image_collection = ee.ImageCollection(self.satellite_name)
            if self.bounds is not None:
                self.filterBounds(self.bounds)
            if self.start_date is not None:
                self.filterDate(self.start_date, self.end_date)

        # Flags for caching, kept alongside the explicit cache invalidation.
        self.date_filter_change = False
        self.filter_change = True

    def signal_change(func):
        """Signals that additional filtering was performed. To be used
        as a decorator on every method that modifies the image collection."""
        @functools.wraps(func)
        def wrap(self, *args, **kwargs):
            self.filter_change = True
            self.date_filter_change = True
            # Drop every cached result so that no property can serve stale data,
            # whichever property happens to be read first after the change.
            self._invalidate_caches()
            return func(self, *args, **kwargs)
        return wrap

    def _invalidate_caches(self):
        """Forgets every result previously loaded from earth engine."""
        self._count = None
        self._available_images = None
        self.image_collection_info = None
        self._date_range = None
        self._df_image_list = None

    @staticmethod
    def _get_timeout_info(instance: Any, timeout: float = 10):
        """Runs getInfo on anything that is passed, with a timeout.

        Args:
            instance: any object exposing a getInfo() method.
            timeout: maximum number of seconds to wait for getInfo.

        Returns: whatever instance.getInfo() returns.
        """
        # func_set_timeout builds a decorator; apply it to the bound method at call time.
        return func_set_timeout(timeout)(instance.getInfo)()

    @staticmethod
    def _authenticate_gee():
        """Authenticates earth engine if needed, and initializes."""
        try:
            ee.Initialize()
        except Exception:
            # Initialization failed: trigger the authentication flow, then retry.
            ee.Authenticate()
            ee.Initialize()

    @signal_change
    def filter(self, ee_filter: "ee.Filter") -> "DataLoader":
        """Applies a filter to the image_collection attribute. This can be useful for example
        to filter out clouds.

        Args:
            ee_filter: Filter to apply, must be an instance of ee.Filter.

        Returns: self, for operation chaining as possible with the earth engine API.
        """
        self.image_collection = self.image_collection.filter(ee_filter)
        return self

    @property
    def count(self):
        """Number of images in the ImageCollection."""
        if self.filter_change or self._count is None:
            self._count = self._get_timeout_info(self.image_collection.size())
            self.filter_change = False
        return self._count

    @property
    def available_images(self):
        """Gets the ImageCollection info."""
        if self.filter_change or self._available_images is None:
            self._available_images = self._get_timeout_info(self.image_collection)
        return self._available_images

    @signal_change
    def filterDate(self, *args, **kwargs) -> "DataLoader":
        """Wrapper for the filterDate method in earth engine on the ImageCollection."""
        self.image_collection = self.image_collection.filterDate(*args, **kwargs)
        return self

    def getRegion(self, *args, **kwargs) -> "DataLoader":
        """Wrapper for the getRegion method in earth engine on the ImageCollection.

        Caveat! getRegion does not return an image collection, so the image_list attribute gets
        updated instead of the image_collection attribute. However, the instance of the DataLoader class
        is still returned, so this could be chained with another method on ImageCollection, which wouldn't be
        possible using earth engine.
        """
        self.image_list = self.image_collection.getRegion(*args, **kwargs)
        # A new region request makes the previously loaded dataframe obsolete.
        self._df_image_list = None
        return self

    @signal_change
    def filterBounds(self, geometry, *args, **kwargs) -> "DataLoader":
        """Wrapper for the filterBounds method in earth engine on the ImageCollection."""
        self.image_collection = self.image_collection.filterBounds(geometry, *args, **kwargs)
        self.bounds = geometry
        return self

    @signal_change
    def select(self, *bands, **kwargs) -> "DataLoader":
        """Wrapper for the select method in earth engine on the ImageCollection.

        The selected band names are appended to the bands attribute, without duplicates
        and keeping their order (band order matters when building RGB thumbnails).
        """
        self.image_collection = self.image_collection.select(*bands, **kwargs)
        # select accepts both select("a", "b") and select(["a", "b"]): flatten before merging.
        flat_bands = []
        for band in bands:
            if isinstance(band, (list, tuple)):
                flat_bands.extend(band)
            else:
                flat_bands.append(band)
        self.bands = list(dict.fromkeys([*self.bands, *flat_bands]))  # Unique bands, order kept
        return self

    @property
    def date_range(self):
        """Gets the actual date range of the images in the image collection."""
        if self.date_filter_change or self._date_range is None:
            min_max = self.image_collection.reduceColumns(ee.Reducer.minMax(), ["system:time_start"])
            date_range = self._get_timeout_info(min_max)
            # Earth engine timestamps are in milliseconds.
            self._date_range = {key: datetime.fromtimestamp(value / 1e3) for key, value in date_range.items()}
            self.date_filter_change = False
        return self._date_range

    @property
    def region(self):
        """Gets a time series as a pandas DataFrame of the band values for the specified region.

        If getRegion was never called, the instance bounds and scale are used.
        """
        if self.filter_change or self._df_image_list is None:
            if self.image_list is None:
                self.getRegion(self.bounds, self.scale)
            res_list = self._get_timeout_info(self.image_list)
            # First row holds the column names, the other rows hold the values.
            df = pd.DataFrame(res_list[1:], columns=res_list[0])
            df.loc[:, "time"] = pd.to_datetime(df.loc[:, "time"], unit="ms")
            self._df_image_list = df
            self.filter_change = False
        return self._df_image_list

    @property
    def collection_info(self):
        """Runs getInfo on the image collection (the first time; the next time the previously
        populated attribute will be returned).

        Raises:
            Exception: if the collection holds more than 5000 images.
        """
        if self.count > 5000:
            raise Exception("Too many images to load. Try filtering more")
        if self.filter_change or self.image_collection_info is None:
            self.image_collection_info = self._get_timeout_info(self.image_collection)
        return self.image_collection_info

    @property
    def image_ids(self):
        """List of names of available images in the image collection."""
        return [feature["id"] for feature in self.collection_info["features"]]

    def __repr__(self):
        try:
            return f"""
Size: {self.count}

Dataset date ranges:
From: {self.date_range["min"]}
To: {self.date_range["max"]}

Selected bands:
{self.bands}
"""
        except Exception as e:
            raise Exception(
                "Impossible to represent the dataset. Try filtering more. Error handling to do."
            ) from e

    def reproject(self, image, **kwargs):
        """Reprojects an image using the crs and scale found in kwargs, or else on the instance.

        Args:
            image: earth engine image to reproject.
            **kwargs: may contain "crs" and / or "scale" to override the instance values.

        Returns: the reprojected image, or the image untouched if neither crs nor scale is set.
        """
        def resolve(name: str):
            # Explicit keyword first, then the instance attribute (falsy values count as unset).
            if name in kwargs:
                return kwargs[name]
            return getattr(self, name, None) or None

        crs = resolve("crs")
        scale = resolve("scale")
        if crs is not None or scale is not None:
            image = image.reproject(crs, None, scale)
        return image

    def download_image(self, image_id: str, **kwargs):
        """Downloads an image based on its id / name. The additional arguments are passed
        to getThumbUrl, and could be scale, max, min...

        Args:
            image_id: full id of the image, as listed in image_ids.
            **kwargs: arguments to be passed to getThumbUrl.

        Returns: the downloaded thumbnail as a PIL image.

        Raises:
            requests.HTTPError: if the thumbnail could not be downloaded.
        """
        img = ee.Image(image_id)
        if self.bands:
            img = img.select(*self.bands)
        img = self.reproject(img, **kwargs)

        input_args = {'region': self.bounds}
        input_args.update(**kwargs)

        # Default the value range to the data type range of the first selected band.
        all_bands = self.collection_info["features"][0]["bands"]
        selected_bands = [band for band in all_bands if band["id"] in self.bands]
        if selected_bands:
            data_type = selected_bands[0]["data_type"]
            for key in ("min", "max"):
                if key not in input_args and key in data_type:
                    input_args[key] = data_type[key]

        url = img.getThumbUrl(input_args)
        response = requests.get(url, stream=True)
        # Fail loudly rather than returning something that is not a picture.
        response.raise_for_status()
        buffer = io.BytesIO()
        for chunk in response.iter_content(chunk_size=1024):
            buffer.write(chunk)
        buffer.seek(0)
        return Image.open(buffer)

    @staticmethod
    def _regex(regex: str, im_id_list: List[str], include: bool) -> list:
        """
        Filters the im_id_list based on a regular expression. This is useful before downloading
        a collection of images. For example, using (.*)TXT with include=True will only download images
        that end with TXT, which for Nantes means filtering out empty or half empty images.

        Args:
            regex: python regex as a string
            im_id_list: list, image id list
            include: whether to include or exclude elements that match the regex.

        Returns: filtered list.
        """
        # Compile once; matching directly avoids evaluating text built from ids and patterns.
        pattern = re.compile(regex)
        return [im_id for im_id in im_id_list if (pattern.match(im_id) is not None) == bool(include)]

    def download_all_images(self, regex_exclude: str = None, regex_include: str = None, **kwargs):
        """
        Runs download_image in a for loop around the available images.
        Makes it possible to filter images to download based on a regex.

        Args:
            regex_exclude: any image that matches this regex will be excluded.
            regex_include: any image that matches this regex will be included
            **kwargs: arguments to be passed to getThumbUrl

        Returns: list of PIL images
        """
        image_ids = self.image_ids
        if regex_exclude is not None:
            image_ids = self._regex(regex_exclude, image_ids, include=False)
        if regex_include is not None:
            image_ids = self._regex(regex_include, image_ids, include=True)
        return [self.download_image(image_id, **kwargs) for image_id in tqdm(image_ids)]