Spaces:
Runtime error
Runtime error
| import requests | |
| import os | |
| import sys | |
| from pathlib import Path | |
| import gradio as gr | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| import geopandas as gpd | |
| from pyproj.transformer import Transformer | |
| import cv2 | |
| import asyncio | |
| from matplotlib import patches as mpatches | |
| from matplotlib import gridspec | |
| sys.path.append(os.path.dirname(os.path.realpath(__file__))) | |
| from MapItAnywhere.mia.bev import get_bev | |
| from MapItAnywhere.mia.fpv import get_fpv | |
| from MapItAnywhere.mia.fpv import filters | |
| from MapItAnywhere.mia import logger | |
def get_city_boundary(query, fetch_shape=False, timeout=30):
    """Look up *query* with the Nominatim search API and return its bounding box.

    Args:
        query: Free-text place name, e.g. "Pittsburgh, PA, United States".
        fetch_shape: Also request and return the place's GeoJSON geometry.
        timeout: Seconds to wait for Nominatim before giving up.

    Returns:
        ``None`` if the request fails or Nominatim finds no match (this holds
        even when *fetch_shape* is True). Otherwise a dict with float
        ``west``/``south``/``east``/``north`` keys, or, when *fetch_shape* is
        True, a ``(bbox, feature_collection)`` tuple whose collection wraps the
        returned geometry in a single Feature.
    """
    # Use Nominatim API to get the boundary of the city
    base_url = "https://nominatim.openstreetmap.org/search"
    params = {
        'q': query,
        'format': 'json',
        'limit': 1,
        'polygon_geojson': 1 if fetch_shape else 0,
    }
    # Nominatim's usage policy asks clients to identify themselves via User-Agent.
    headers = {
        'User-Agent': f'mapperceptionnet_{query}'
    }
    try:
        response = requests.get(base_url, params=params, headers=headers, timeout=timeout)
    except requests.RequestException as e:
        logger.error(f"Nominatim request failed for {query}: {e!r}")
        return None
    if response.status_code != 200:
        logger.error(f"Nominatim error when fetching boundary data for {query}.\n"
                     f"Status code: {response.status_code}. Content: {response.content}")
        return None
    data = response.json()
    # A query without a match yields an empty JSON list, so test emptiness, not just None.
    if not data:
        logger.warning(f"No data returned by Nominatim for {query}")
        return None
    place = data[0]
    # Nominatim's boundingbox is [south, north, west, east], given as strings.
    bbox_data = place['boundingbox']
    bbox = {
        'west': float(bbox_data[2]),
        'south': float(bbox_data[0]),
        'east': float(bbox_data[3]),
        'north': float(bbox_data[1]),
    }
    if not fetch_shape:
        return bbox
    # Wrap the bare GeoJSON geometry in a FeatureCollection with one Feature.
    boundary_geojson = {
        "type": "FeatureCollection",
        "features": [
            {"type": "Feature",
             "properties": {},
             "geometry": place['geojson']}],
    }
    return bbox, boundary_geojson
def split_dataframe(df, chunk_size = 100):
    """Split *df* into consecutive slices of at most *chunk_size* rows.

    Works for any sliceable sequence with ``len`` (DataFrames, lists). Empty
    input gives an empty list, and no empty trailing chunk is ever produced,
    even when ``len(df)`` is an exact multiple of *chunk_size*.

    Raises:
        ValueError: if *chunk_size* is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    return [df[start:start + chunk_size] for start in range(0, len(df), chunk_size)]
# Shared Mapillary API client; the access token is read from the MLY_TOKEN environment variable.
_mly_token = os.getenv("MLY_TOKEN")
if not _mly_token:
    logger.warning("MLY_TOKEN is not set; Mapillary requests will likely fail.")
downloader = get_fpv.MapillaryDownloader(_mly_token)
# Event loop used to drive the async downloader from the synchronous Gradio
# callback (via loop.run_until_complete). asyncio.get_event_loop() is deprecated
# when no loop is running, so create and register one explicitly.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
def generate_error_plot(error_message):
    """Render *error_message* as a text-only figure and return it as an image array.

    The figure is saved to ``fpv_bev.png``, closed, and read back with
    ``plt.imread`` so the Gradio image output receives an array.
    """
    fig, ax = plt.subplots()
    ax.text(0.5, 0.5, error_message, fontsize=12, va='center', ha='center', wrap=True)
    ax.axis('off')
    fig_img_path = 'fpv_bev.png'
    try:
        fig.savefig(fig_img_path)
    finally:
        # Close the figure so repeated error renders don't accumulate open figures.
        plt.close(fig)
    return plt.imread(fig_img_path)
def fetch(location, num_images, filter_undistort, disable_cam_filter, map_length, mpp):
    """Sample first-person images near *location*, render matching BEV maps and plot the pairs.

    Args:
        location: Free-text place name resolved with Nominatim.
        num_images: Number of FPV/BEV pairs to produce. Gradio's Number widget
            may deliver a float, so it is coerced to a positive int.
        filter_undistort: Run candidate metadata through the filter pipeline
            and undistort the images with ``get_fpv.process_sequence``.
        disable_cam_filter: With *filter_undistort*, use
            ``no_cam_filter_pipeline`` instead of ``filter_pipeline``.
        map_length: BEV side length in pixels.
        mpp: Metres per BEV pixel.

    Returns:
        An image array (read back from ``fpv_bev.png``) with one row per pair,
        or an error figure from ``generate_error_plot``.
    """
    TOTAL_LOOKED_INTO_LIMIT = 10000
    try:
        num_images = max(1, int(num_images))
    except (TypeError, ValueError):
        return generate_error_plot("Please enter the number of data pairs to generate.")

    ################ FPV
    bbox = get_city_boundary(query=location)
    if bbox is None:
        return generate_error_plot(f"Could not find a boundary for '{location}'."
                                   "\nPlease check the location name and try again.")
    df_meta, error_message = _collect_fpv_metadata(
        bbox, num_images, filter_undistort, disable_cam_filter, TOTAL_LOOKED_INTO_LIMIT)
    if df_meta is None:
        return generate_error_plot(error_message)
    # Tiles may run out before num_images rows are found; never ask for more than we have.
    df_meta = df_meta.sample(min(num_images, len(df_meta)))

    # Calc derivative attributes
    df_meta["loc_discrepancy"] = filters.haversine_np(
        lon1=df_meta["geometry.long"], lat1=df_meta["geometry.lat"],
        lon2=df_meta["computed_geometry.long"], lat2=df_meta["computed_geometry.lat"]
    )
    df_meta["angle_discrepancy"] = filters.angle_dist(
        df_meta["compass_angle"],
        df_meta["computed_compass_angle"]
    )

    # Downloading, undistortion and BEV rendering cover all sampled images at
    # once, so run them a single time rather than once per displayed row.
    _prepare_fpv_images(df_meta, bbox, filter_undistort)
    ################ BEV
    _render_bevs(df_meta, map_length, mpp)

    panels = []
    for _, row in df_meta.iterrows():
        image_id = row["id"]
        print("Processing image", image_id)
        metadata_fmt = _format_metadata(row)
        if filter_undistort:
            fpv_path = out_image_dir / f"{image_id}_undistorted.jpg"
        else:
            print("Loading raw image")
            fpv_path = raw_image_dir / f"{image_id}.jpg"
        try:
            fpv = plt.imread(fpv_path)
        except OSError as e:
            # e.g. its sequence failed to undistort; show the remaining pairs instead of crashing.
            logger.error(f"Could not load first-person image for {image_id}, skipping it. Error: {e!r}.")
            continue
        bev = cv2.imread(str(rendered_mask_dir / f"{image_id}.png"))
        if bev is None:  # cv2.imread signals a missing/unreadable file with None
            logger.error(f"Could not load BEV map for {image_id}, skipping it.")
            continue
        bev = cv2.cvtColor(bev, cv2.COLOR_BGR2RGB)
        print("BEV shape", bev.shape)
        panels.append([fpv, bev, metadata_fmt])
    if not panels:
        return generate_error_plot("Could not produce any data pairs for the selected images."
                                   "\nPlease rerun or try a different location.")
    return _compose_figure(panels)


def _collect_fpv_metadata(bbox, num_images, filter_undistort, disable_cam_filter, looked_into_limit):
    """Collect Mapillary image metadata inside *bbox* until at least *num_images* rows survive.

    Zoom-14 tiles covering the box are visited in random order. Each tile's
    image ids are fetched in chunks of 100, ``make``/``model`` are lower-cased
    with spaces and underscores removed, and, when *filter_undistort* is set,
    rows are passed through the selected filter pipeline.

    Returns:
        ``(df_meta, None)`` with every collected row, or ``(None, message)``
        when nothing usable was found or more than *looked_into_limit*
        candidate images were inspected.
    """
    tiles = get_fpv.get_tiles_from_boundary(boundary_info=dict(bound_type="auto_bbox", bbox=bbox), zoom=14)
    np.random.shuffle(tiles)
    total_looked_into = 0
    total_rows = 0
    dfs_meta = []
    for tile in tiles:
        image_points_response = loop.run_until_complete(downloader.get_tiles_image_points([tile]))
        if image_points_response is None:
            continue
        try:
            df_points = get_fpv.parse_image_points_json_data(image_points_response)
            if len(df_points) == 0:
                continue
            total_looked_into += len(df_points)
            for df_chunk in split_dataframe(df_points, chunk_size=100):
                image_infos, _ = loop.run_until_complete(
                    get_fpv.fetch_image_infos(df_chunk["id"], downloader, infos_dir))
                df_meta = get_fpv.geojson_feature_list_to_pandas(image_infos.values())
                # Some standardization of the data
                df_meta["model"] = df_meta["model"].str.lower().str.replace(' ', '').str.replace('_', '')
                df_meta["make"] = df_meta["make"].str.lower().str.replace(' ', '').str.replace('_', '')
                if filter_undistort:
                    fp = no_cam_filter_pipeline if disable_cam_filter else filter_pipeline
                    df_meta = fp(df_meta)
                dfs_meta.append(df_meta)
                total_rows += len(df_meta)
                if total_rows >= num_images:
                    break
                if total_looked_into > looked_into_limit:
                    return None, (f"Went through {total_looked_into} images and could not find images satisfying the filters."
                                  "\nPlease rerun or run the data engine locally for bulk time consuming operations.")
        except Exception as e:
            # One bad tile (network hiccup, malformed payload) shouldn't abort the request,
            # but the failure must be visible rather than silently swallowed.
            logger.warning(f"Skipping tile {tile} after error: {e!r}")
        if total_rows >= num_images:
            break
    if total_rows == 0:
        return None, ("Could not find any images satisfying the filters in this area."
                      "\nPlease rerun or try a different location.")
    return pd.concat(dfs_meta), None


# Metadata fields shown next to each pair, in display order.
_METADATA_KEYS = ["id", "geometry.long", "geometry.lat", "compass_angle",
                  "loc_discrepancy", "angle_discrepancy",
                  "make", "model", "camera_type",
                  "quality_score"]


def _format_metadata(row):
    """Return one "key: value" line per entry of _METADATA_KEYS, floats shown to 4 decimals."""
    lines = []
    for key in _METADATA_KEYS:
        value = row[key]
        if isinstance(value, float):
            value = f"{value:.4f}"
        lines.append(f"{key}: {value}")
    return "\n".join(lines)


def _prepare_fpv_images(df_meta, bbox, filter_undistort):
    """Download every image in *df_meta* and, if requested, undistort them per sequence.

    Raw images are fetched into ``raw_image_dir``; ``get_fpv.process_sequence``
    is given ``out_image_dir`` as its output directory. Sequences that fail are
    logged and skipped.
    """
    image_urls = list(df_meta.set_index("id")["thumb_2048_url"].items())
    num_fail = loop.run_until_complete(get_fpv.fetch_images_pixels(image_urls, downloader, raw_image_dir))
    if num_fail > 0:
        logger.error(f"Failed to download {num_fail} images.")
    if not filter_undistort:
        return
    seq_to_image_ids = df_meta.groupby('sequence')['id'].agg(list).to_dict()
    # Project around the centre of the queried area.
    lon_center = (bbox['east'] + bbox['west']) / 2
    lat_center = (bbox['north'] + bbox['south']) / 2
    projection = get_fpv.Projection(lat_center, lon_center, max_extent=200e3)
    # Per-image records keyed by image id; the "id" column is kept in each record.
    image_infos = df_meta.set_index("id", drop=False).to_dict(orient="index")
    process_sequence_args = get_fpv.default_cfg
    for seq_id, seq_image_ids in seq_to_image_ids.items():
        try:
            d, pi = get_fpv.process_sequence(
                seq_image_ids,
                image_infos,
                projection,
                process_sequence_args,
                raw_image_dir,
                out_image_dir,
            )
            if d is None or pi is None:
                raise Exception("process_sequence returned None")
        except Exception as e:
            logger.error(f"Failed to process sequence {seq_id} skipping it. Error: {repr(e)}.")


def _render_bevs(df_meta, map_length, mpp):
    """Render a BEV semantic mask for every image in *df_meta* via get_bev's worker functions.

    Each OSM query box is centred on the image's computed position and sized
    so the map is not out of bounds after it is rotated.
    """
    # convert pandas dataframe to geopandas dataframe
    gdf = gpd.GeoDataFrame(df_meta,
                           geometry=gpd.points_from_xy(
                               df_meta['computed_geometry.long'],
                               df_meta['computed_geometry.lat']),
                           crs=4326)
    # convert to UTM so the offsets below are in metres
    utm_crs = gdf.estimate_utm_crs()
    gdf_utm = gdf.to_crs(utm_crs)
    # Without always_xy, EPSG:4326 output is (lat, lon); the x[0] = lat unpacking below relies on it.
    transformer = Transformer.from_crs(utm_crs, 4326)
    padding = 50
    # Use the diagonal of the square map as its side so a rotated crop stays inside the data.
    # Computed once: re-deriving it from an already padded value would keep growing it.
    padded_map_length = np.ceil(np.sqrt(map_length**2 + map_length**2))
    distance = padded_map_length * mpp
    # create bounding boxes for each point
    gdf_utm['bounding_box_utm_p1'] = gdf_utm.apply(lambda r: (
        r.geometry.x - distance - padding,
        r.geometry.y - distance - padding,
    ), axis=1)
    gdf_utm['bounding_box_utm_p2'] = gdf_utm.apply(lambda r: (
        r.geometry.x + distance + padding,
        r.geometry.y + distance + padding,
    ), axis=1)
    # convert the bounding box back to lat, long
    gdf_utm['bounding_box_lat_long_p1'] = gdf_utm.apply(lambda r: transformer.transform(*r['bounding_box_utm_p1']), axis=1)
    gdf_utm['bounding_box_lat_long_p2'] = gdf_utm.apply(lambda r: transformer.transform(*r['bounding_box_utm_p2']), axis=1)
    gdf_utm['bbox_min_lat'] = gdf_utm['bounding_box_lat_long_p1'].apply(lambda x: x[0])
    gdf_utm['bbox_min_long'] = gdf_utm['bounding_box_lat_long_p1'].apply(lambda x: x[1])
    gdf_utm['bbox_max_lat'] = gdf_utm['bounding_box_lat_long_p2'].apply(lambda x: x[0])
    gdf_utm['bbox_max_long'] = gdf_utm['bounding_box_lat_long_p2'].apply(lambda x: x[1])
    gdf_utm['bbox_formatted'] = gdf_utm.apply(lambda r: f"{r['bbox_min_long']},{r['bbox_min_lat']},{r['bbox_max_long']},{r['bbox_max_lat']}", axis=1)
    # the jobs only need the id, the formatted bbox and the heading
    jobs = gdf_utm[['id', 'bbox_formatted', 'computed_compass_angle']].to_dict(orient='records')
    get_bev.get_bev_from_bbox_worker_init(osm_cache_dir, bev_dir, semantic_mask_dir, rendered_mask_dir,
                                          "MapItAnywhere/mia/bev/styles/mia.yml", padded_map_length, mpp,
                                          None, True, False, True, True, 1)
    for job_dict in jobs:
        get_bev.get_bev_from_bbox_worker(job_dict)


# Legend entries for the rendered BEV masks: label -> RGB colour (0-255).
# NOTE(review): presumably these must match the palette in
# MapItAnywhere/mia/bev/styles/mia.yml; verify when the style changes.
_BEV_CLASS_COLORS = {
    'Road': (68, 68, 68),         # 0: dark grey
    'Crossing': (244, 162, 97),   # 1: orange
    'Sidewalk': (233, 196, 106),  # 2: sand yellow
    'Building': (231, 111, 81),   # 5: coral red
    'Terrain': (42, 157, 143),    # 7: teal
    'Parking': (204, 204, 204),   # 8: light grey
}


def _compose_figure(panels):
    """Lay out one (FPV, BEV, metadata) row per panel and return the figure as an image array.

    Column widths follow the first panel's image aspect ratios. The figure is
    saved to ``fpv_bev.png``, closed, and read back with ``plt.imread``.
    """
    n_rows = len(panels)
    print("plt_row", n_rows)
    first_fpv, first_bev, _ = panels[0]
    ratios = [img.shape[1] / img.shape[0] for img in (first_fpv, first_bev)]  # W / H
    ratios.append(0.5)  # Metadata
    # squeeze=False keeps `axes` 2-D, so a single row needs no special case.
    fig, axes = plt.subplots(
        n_rows, 3, figsize=[sum(ratios) * 4.5, 4.5 * n_rows], dpi=100,
        gridspec_kw={"width_ratios": ratios}, squeeze=False,
    )
    legend_handles = [
        mpatches.Patch(color=[c / 255.0 for c in color], label=label)
        for label, color in _BEV_CLASS_COLORS.items()
    ]
    for (ax_fpv, ax_bev, ax_meta), (fpv, bev, metadata_fmt) in zip(axes, panels):
        # Plot FPV image
        ax_fpv.imshow(fpv)
        ax_fpv.set_title("First Person View Image")
        ax_fpv.axis('off')
        # Plot BEV image with a white upward triangle at its centre (the image's position)
        ax_bev.imshow(bev)
        ax_bev.scatter(bev.shape[1] // 2, bev.shape[0] // 2, s=200, c='white', marker='^', edgecolors='black')
        ax_bev.set_title("Bird's Eye View Map")
        ax_bev.axis('off')
        ax_bev.legend(handles=legend_handles, loc='upper center', bbox_to_anchor=(0.5, -0.05), ncol=3)
        # Plot metadata text
        ax_meta.axis('off')
        ax_meta.text(0.1, 0.5, metadata_fmt, fontsize=12, va='center', ha='left', wrap=True)
        ax_meta.set_title("Metadata")
    fig.tight_layout(pad=2.0)
    # Save figure and then read
    fig_img_path = 'fpv_bev.png'
    try:
        fig.savefig(fig_img_path)
    finally:
        plt.close(fig)
    return plt.imread(fig_img_path)
# Filter pipelines used by fetch(): `filter_pipeline` (mia.yaml) by default and
# `no_cam_filter_pipeline` (mia_rural.yaml) when camera-model filtering is disabled.
filter_pipeline = filters.FilterPipeline.load_from_yaml("MapItAnywhere/mia/fpv/filter_pipelines/mia.yaml")
filter_pipeline.verbose = False
no_cam_filter_pipeline = filters.FilterPipeline.load_from_yaml("MapItAnywhere/mia/fpv/filter_pipelines/mia_rural.yaml")
no_cam_filter_pipeline.verbose = False

# Working directories, relative to the process's current working directory.
loc = Path(".")
infos_dir = loc / "infos_dir"
raw_image_dir = loc / "raw_images"
out_image_dir = loc / "images"
osm_cache_dir = loc / "osm_cache"
bev_dir = loc / "bev_raw"
semantic_mask_dir = loc / "semantic_masks"
rendered_mask_dir = loc / "rendered_semantic_masks"
# Create every working directory up front. infos_dir is passed to
# get_fpv.fetch_image_infos (which presumably caches metadata there), so it is included too.
all_dirs = [loc, infos_dir, osm_cache_dir, bev_dir, semantic_mask_dir, rendered_mask_dir, out_image_dir, raw_image_dir]
for d in all_dirs:
    os.makedirs(d, exist_ok=True)
logger.info(f"Current working directory: {os.getcwd()}, listdir: {os.listdir('.')}")
# HTML shown above the Gradio interface; every tag opened here is also closed.
description = """
<h2><center><a href="https://mapitanywhere.github.io" target="_blank">Project Page</a> | <a href="https://github.com/MapItAnywhere/MapItAnywhere" target="_blank">Repository</a><br>Use our Data Engine to sample first-person view images and bird's-eye view semantic map pairs from locations worldwide. Simply pick a location to see the results!</center></h2>
<h3 align="center">Please note that the Hugging Face demo runs much slower than running locally. If the curation takes longer than 1 minute, please restart the Space (see the dropdown menu at the top-right of the page). For faster bulk downloads and more stringent filtering, visit our repository and follow the data engine instructions to run the data curation locally.</h3>
"""
# Gradio UI: the inputs are passed positionally to fetch(location, num_images,
# filter_undistort, disable_cam_filter, map_length, mpp).
demo = gr.Interface(
    fn=fetch,
    inputs=[gr.Text("Pittsburgh, PA, United States", label="Location (City, {Optional: State,} Country)"),
            # precision=0 makes Gradio deliver a whole number instead of a float.
            gr.Number(value=1, label="Number of Data Pairs to Generate (Max: 3)", minimum=1, maximum=3, precision=0),
            gr.Checkbox(value=False, label="Filter out images with high pose discrepancy (Enabled in paper. Results in better robot position estimate, but slower.)"),
            gr.Checkbox(value=False, label="Disable camera model filtering (Enabled in paper. Results in better quality labels, but slower.)"),
            gr.Slider(minimum=64, maximum=512, step=1, label="BEV Dimension", value=224),
            gr.Slider(minimum=0.1, maximum=2, label="Meters Per Pixel", value=0.5)],
    outputs=[gr.Image(label="Data Pair")],
    title="MapItAnywhere (MIA) Data Engine",
    description=description,
)
# Listen on all interfaces at port 7860, without creating a Gradio share link.
launch_options = dict(server_name="0.0.0.0", server_port=7860, share=False)
logger.info("Starting server")
demo.launch(**launch_options)