Spaces:
Sleeping
Sleeping
| ## Daniel Buscombe, Marda Science LLC 2023 | |
| # This file contains many functions originally from Doodleverse https://github.com/Doodleverse programs | |
| import gradio as gr | |
| import numpy as np | |
| import tensorflow as tf | |
| import matplotlib.pyplot as plt | |
| from skimage.transform import resize | |
| from skimage.io import imsave, imread | |
| from skimage.filters import threshold_otsu | |
| # from skimage.measure import EllipseModel, CircleModel, ransac | |
| from glob import glob | |
| import json | |
| from transformers import TFSegformerForSemanticSegmentation | |
| ##======================================================== | |
def segformer(
    id2label,
    num_classes=2,
):
    """Build a SegFormer semantic-segmentation model on the MiT-b0 encoder.

    The ``nvidia/mit-b0`` checkpoint is loaded with ``from_pretrained`` and the
    model is configured for ``num_classes`` output labels. Because the
    checkpoint's head does not match that size, ``ignore_mismatched_sizes=True``
    is passed so loading does not fail on the mismatched weights.

    Args:
        id2label: mapping of integer class index -> label string.
        num_classes: number of output labels for the segmentation head.

    Returns:
        A ``TFSegformerForSemanticSegmentation`` instance.

    References:
        https://keras.io/examples/vision/segformer/
        https://huggingface.co/nvidia/mit-b0
    """
    # Reverse lookup label -> index (local names avoid shadowing builtin ``id``).
    label2id = {name: index for index, name in id2label.items()}
    checkpoint_name = "nvidia/mit-b0"
    return TFSegformerForSemanticSegmentation.from_pretrained(
        checkpoint_name,
        num_labels=num_classes,
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )
| ##======================================================== | |
def fromhex(n):
    """Return the integer value of the hexadecimal string ``n`` (e.g. "ff" -> 255)."""
    radix = 16
    value = int(n, radix)
    return value
| ##======================================================== | |
def label_to_colors(
    img,
    mask,
    alpha,  # =128,
    colormap,  # =class_label_colormap, #px.colors.qualitative.G10,
    color_class_offset,  # =0,
    do_alpha,  # =True
):
    """Map an integer label image to an RGB (or RGBA) color image.

    Each label value ``c`` in ``img`` is replaced by
    ``colormap[(c + color_class_offset) % len(colormap)]``. The offset forces
    use of a particular range of colors: e.g. if 0 means "no class" but class 1
    should be drawn with ``colormap[0]``, use ``color_class_offset=-1``.
    Pixels where ``mask == 1`` are painted black.

    Args:
        img: 2-D array (M x N) of integer class labels.
        mask: M x N array; pixels equal to 1 (or ``True``) become (0, 0, 0).
        alpha: value written to the 4th channel when ``do_alpha`` is truthy.
        colormap: sequence of hex color strings like ``"#3366CC"``
            (plotly.express style; the ``#`` is optional).
        color_class_offset: integer added to each label before the lookup.
        do_alpha: if truthy, append a constant alpha channel (M x N x 4);
            otherwise return M x N x 3.

    Returns:
        ``uint8`` array of shape (M, N, 3) or (M, N, 4).

    Raises:
        ValueError: if ``colormap`` is empty.
    """
    if len(colormap) == 0:
        raise ValueError("colormap must contain at least one color")
    # Decode "#RRGGBB" strings into an (n_colors, 3) uint8 palette; each hex
    # pair is parsed with int(..., 16), the same conversion fromhex performs.
    palette = np.array(
        [
            tuple(int(h[s : s + 2], 16) for s in range(0, len(h), 2))
            for h in (c.replace("#", "") for c in colormap)
        ],
        dtype="uint8",
    )
    labels = np.asarray(img).astype(np.int64)
    # One vectorized palette lookup instead of a full-image pass per label
    # value; it also handles empty images (np.min/np.max would raise there).
    # NumPy's % follows Python's sign rules, so negative indices wrap the same
    # way as (c + offset) % len(colormap) in plain Python.
    cimg = palette[(labels + color_class_offset) % len(palette)]
    cimg[np.asarray(mask) == 1] = (0, 0, 0)
    if do_alpha:
        return np.concatenate(
            (cimg, alpha * np.ones(labels.shape[:2] + (1,), dtype="uint8")), axis=2
        )
    return cimg
| ##==================================== | |
def standardize(img):
    """Center ``img`` on zero mean and divide by an adjusted standard deviation.

    The divisor is ``max(std(img), 1 / sqrt(N))`` where ``N`` is rows * cols;
    the floor keeps constant images from dividing by zero. A 2-D result is
    replicated into three identical bands so the output is always (rows, cols, bands).
    """
    pixel_count = np.shape(img)[0] * np.shape(img)[1]
    std_floor = 1.0 / np.sqrt(pixel_count)
    divisor = np.maximum(np.std(img), std_floor)
    scaled = (img - np.mean(img)) / divisor
    if np.ndim(scaled) != 2:
        return scaled
    return np.dstack([scaled] * 3)
| ############################################################ | |
| ############################################################ | |
# Load the trained model and its run configuration.
# The JSON config next to the weights file is expected to supply the settings
# this module reads without ever assigning them (NCLASSES, TARGET_SIZE,
# N_DATA_BANDS, MODEL, TESTTIMEAUG, ...).
filepath = './weights/ct_NAIP_8class_768_segformer_v3_fullmodel.h5'
configfile = filepath.replace('_fullmodel.h5', '.json')
with open(configfile, encoding='utf-8') as f:
    config = json.load(f)
# Every top-level config key becomes a module-level global of the same name.
# Assigning through globals() (rather than exec() of a string built from the
# key) means a malformed or malicious key can never be executed as code.
for k, v in config.items():
    if not k.isidentifier():
        raise ValueError(
            f"config key {k!r} in {configfile} is not a valid Python identifier"
        )
    globals()[k] = v
# Class labels are simply the stringified class indices.
id2label = {k: str(k) for k in range(NCLASSES)}
model = segformer(id2label, num_classes=NCLASSES)
model.load_weights(filepath)
| ############################################################ | |
| ############################################################ | |
| # #----------------------------------- | |
def _forward(model, x, model_type):
    """Run ``model`` on one un-batched input and return its raw class scores.

    A batch axis is added before the call. For ``model_type == 'segformer'``
    the ``.logits`` of the output are returned with the batch axis kept; for
    other model types the output is squeezed.
    """
    output = model(tf.expand_dims(x, 0))
    if model_type == 'segformer':
        return output.logits
    return tf.squeeze(output)


def _spatial_axes(model_type):
    """Return the (row, column) axes shared by a model's input and its scores.

    segformer inputs are channels-first (C, H, W) -- get_image transposes them
    -- and its logits are (batch, NCLASSES, h, w), so rows/columns are the last
    two axes of both. Other models take (H, W[, C]) inputs and, after
    squeezing, return (H, W[, NCLASSES]), so rows/columns are axes 0 and 1.
    """
    if model_type == 'segformer':
        return -2, -1
    return 0, 1


def est_label_multiclass(image,Mc,MODEL,TESTTIMEAUG,NCLASSES,TARGET_SIZE):
    """Predict summed per-class scores for ``image`` with every model in ``Mc``.

    Each model's raw scores (logits for segformer) are summed across models;
    the caller divides by ``counter + 1`` to obtain the ensemble mean. With
    test-time augmentation, each model's scores also include its predictions
    on vertically, horizontally and doubly flipped copies of the input, each
    flipped back into the original orientation before being added.

    Args:
        image: preprocessed image; channels-first (C, H, W) for segformer,
            (H, W, C) otherwise.
        Mc: non-empty sequence of callable models.
        MODEL: model family; 'segformer' selects the ``.logits`` output API.
        TESTTIMEAUG: if truthy, apply flip test-time augmentation.
        NCLASSES: unused; kept for interface compatibility.
        TARGET_SIZE: unused; kept for interface compatibility.

    Returns:
        (summed scores as a tensor, index of the last model in ``Mc``)

    Raises:
        ValueError: if ``Mc`` is empty.
    """
    rows, cols = _spatial_axes(MODEL)
    total = None
    for counter, model in enumerate(Mc):
        model_input = image
        try:
            scores = _forward(model, model_input, MODEL)
        except (ValueError, tf.errors.InvalidArgumentError):
            # Presumably a single-band model rejecting a 3-band input: retry
            # with the first band only. Only shape/argument errors trigger this
            # fallback; anything else (e.g. KeyboardInterrupt) propagates.
            model_input = image[:, :, 0]
            scores = _forward(model, model_input, MODEL)
        if TESTTIMEAUG:
            # Flip along the true spatial axes (not the channel/class axes),
            # predict, and flip the prediction back so it aligns with the
            # un-flipped image before summing.
            for axes in ([rows], [cols], [rows, cols]):
                flipped_pred = _forward(
                    model, tf.reverse(model_input, axis=axes), MODEL
                )
                scores = scores + tf.reverse(flipped_pred, axis=axes)
        # Accumulate across models rather than overwriting, so the caller's
        # division by counter + 1 yields a true ensemble average.
        total = scores if total is None else total + scores
    if total is None:
        raise ValueError("Mc must contain at least one model")
    return total, counter
| # #----------------------------------- | |
def seg_file2tensor_3band(bigimage, TARGET_SIZE):
    """Resize an in-memory image to the model input size.

    Despite the name, no file is read: ``bigimage`` is an image array that is
    resized to ``TARGET_SIZE[0] x TARGET_SIZE[1]`` (original value range
    preserved, output clipped to the input's range) and cast to ``tf.uint8``.
    The result is not standardized.

    Args:
        bigimage: image array, presumably (rows, cols[, bands]).
        TARGET_SIZE: (rows, cols) of the resized output.

    Returns:
        (smallimage, w, h, bigimage): the resized uint8 tensor; the sizes of
        the first and second dimensions of ``bigimage`` as scalar tensors; and
        ``bigimage`` itself, unchanged.
    """
    out_rows, out_cols = TARGET_SIZE[0], TARGET_SIZE[1]
    resized = np.array(
        resize(bigimage, (out_rows, out_cols), preserve_range=True, clip=True)
    )
    original_shape = tf.shape(bigimage)
    return (
        tf.cast(resized, tf.uint8),
        original_shape[0],
        original_shape[1],
        bigimage,
    )
| # #----------------------------------- | |
def get_image(f,N_DATA_BANDS,TARGET_SIZE,MODEL):
    """Resize and standardize an image array so it can be fed to the model.

    Args:
        f: image array (not a file name, despite the parameter name); it is
            passed straight to seg_file2tensor_3band.
        N_DATA_BANDS: unused; kept for interface compatibility.
        TARGET_SIZE: (rows, cols) the image is resized to.
        MODEL: model family; for 'segformer' the result is made 3-band and
            transposed to channels-first (C, H, W).

    Returns:
        (image, w, h, bigimage) where ``w``/``h`` are the first/second
        dimension sizes of the original image and ``bigimage`` is the original.
    """
    resized, w, h, bigimage = seg_file2tensor_3band(f, TARGET_SIZE)
    prepared = standardize(resized.numpy()).squeeze()
    if MODEL != 'segformer':
        return prepared, w, h, bigimage
    if np.ndim(prepared) == 2:
        prepared = np.dstack((prepared, prepared, prepared))
    # segformer takes channels-first input: (H, W, C) -> (C, H, W)
    return tf.transpose(prepared, (2, 0, 1)), w, h, bigimage
| # #----------------------------------- | |
| #segmentation | |
def segment(input_img, use_tta, use_otsu, dims=(768, 768)):
    """Segment an image and write greyscale and color label PNGs.

    Args:
        input_img: image array from the Gradio ``Image`` input, indexed as
            (rows, cols, bands); pixels whose first band is 0 are painted
            black in the color output.
        use_tta: if truthy, apply flip test-time augmentation during inference.
        use_otsu: only reported; no Otsu thresholding is applied.
        dims: unused; kept for interface compatibility.

    Returns:
        (color label image array, path of greyscale label PNG,
        path of color label PNG)
    """
    print("Use Otsu threshold" if use_otsu else "No Otsu threshold")
    print("Use TTA" if use_tta else "Do not use TTA")
    # NOTE(review): use_otsu is accepted but not applied (threshold_otsu is
    # imported but unused) -- confirm whether it should be implemented.
    # The module-level model is always built by segformer(), so one model type
    # drives both preprocessing (channels-first transpose) and inference.
    model_type = 'segformer'
    image, w, h, bigimage = get_image(input_img, N_DATA_BANDS, TARGET_SIZE, model_type)
    # Honour the UI checkbox; previously the config's TESTTIMEAUG silently won.
    est_label, counter = est_label_multiclass(
        image, [model], model_type, use_tta, NCLASSES, TARGET_SIZE
    )
    print(est_label.shape)
    est_label /= counter + 1
    # est_label cannot be float16 so convert to float32
    est_label = est_label.numpy().astype('float32')
    # Upsample the (1, NCLASSES, h, w) scores to the model input size, then
    # move the class axis last -> (rows, cols, NCLASSES).
    est_label = resize(
        est_label,
        (1, NCLASSES, TARGET_SIZE[0], TARGET_SIZE[1]),
        preserve_range=True,
        clip=True,
    ).squeeze()
    est_label = np.transpose(est_label, (1, 2, 0))
    # Resize the scores to the original image size, then pick the top class.
    est_label = resize(est_label, (w, h))
    est_label = np.argmax(est_label, -1)
    print(est_label.shape)
    imsave("greyscale_download_me.png", est_label.astype('uint8'))
    class_label_colormap = [
        "#3366CC",
        "#DC3912",
        "#FF9900",
        "#109618",
        "#990099",
        "#0099C6",
        "#DD4477",
        "#66AA00",
        "#B82E2E",
        "#316395",
    ]
    # one color per class
    class_label_colormap = class_label_colormap[:NCLASSES]
    color_label = label_to_colors(
        est_label,
        input_img[:, :, 0] == 0,
        alpha=128,
        colormap=class_label_colormap,
        color_class_offset=0,
        do_alpha=False,
    )
    imsave("color_download_me.png", color_label)
    return color_label, "greyscale_download_me.png", "color_download_me.png"
title = "Mapping sand in high-res. imagery"
description = "This simple model demonstration segments NAIP RGB (visible spectrum) imagery into the following classes:1. water (unbroken water); 2. whitewater (surf, active wave breaking); 3. sediment (natural deposits of sand. gravel, mud, etc), 4. other_bare_natural_terrain, 5. marsh_vegetation, 6. terrestrial_vegetation, 7. agricultural, 8. development. Please note that, ordinarily, ensemble models are used in predictive mode. Here, we are using just one model, i.e. without ensembling. Allows upload of 3-band imagery in jpg format and download of label imagery only one at a time. "
# Every .jpg in ./examples is offered as a clickable example (sorted so the
# order is stable across filesystems).
examples = [[path] for path in sorted(glob('examples/*.jpg'))]
inp = gr.Image()
out1 = gr.Image(type='numpy')
out3 = gr.File()
out4 = gr.File()
# gr.inputs.* is deprecated (removed in Gradio 4); the top-level components
# take ``value=`` instead of ``default=``.
inp2 = gr.Checkbox(value=False, label="Use TTA")
inp3 = gr.Checkbox(value=False, label="Use Otsu")
Segapp = gr.Interface(
    segment,
    [inp, inp2, inp3],
    [out1, out3, out4],
    title=title,
    description=description,
    examples=examples,
    theme="grass",
)
# launch(enable_queue=True) is deprecated; enable the request queue via .queue().
Segapp.queue().launch()