Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	
		artelabsuper
		
	commited on
		
		
					Commit 
							
							·
						
						99ee6d2
	
1
								Parent(s):
							
							7d56262
								
cached models improve speed
Browse files
    	
        app.py
    CHANGED
    
    | @@ -8,38 +8,41 @@ from matplotlib import colors | |
| 8 |  | 
| 9 | 
             
            if not hasattr(st, 'paths'):
         | 
| 10 | 
             
                st.paths = None
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 11 |  | 
| 12 | 
             
            # Load Model
         | 
| 13 | 
             
            # @title Load pretrained weights
         | 
| 14 |  | 
| 15 | 
            -
             | 
| 16 | 
            -
            best_model_annual_file_name = "best_model_annual.pth"
         | 
| 17 | 
            -
             | 
| 18 | 
            -
            first_input_batch = torch.zeros(71, 9, 5, 48, 48)
         | 
| 19 | 
            -
            # first_input_batch = first_input_batch.view(-1, *first_input_batch.shape[2:])
         | 
| 20 | 
            -
            daily_model = FPN(opt, first_input_batch, opt.win_size)
         | 
| 21 | 
            -
            annual_model = SimpleNN(opt)
         | 
| 22 | 
            -
             | 
| 23 | 
            -
            if torch.cuda.is_available():
         | 
| 24 | 
            -
                daily_model = torch.nn.DataParallel(daily_model).cuda()
         | 
| 25 | 
            -
                annual_model = torch.nn.DataParallel(annual_model).cuda()
         | 
| 26 | 
            -
                daily_model = torch.nn.DataParallel(daily_model).cuda()
         | 
| 27 | 
            -
                annual_model = torch.nn.DataParallel(annual_model).cuda()
         | 
| 28 | 
            -
            else:
         | 
| 29 | 
            -
                daily_model = torch.nn.DataParallel(daily_model).cpu()
         | 
| 30 | 
            -
                annual_model = torch.nn.DataParallel(annual_model).cpu()
         | 
| 31 | 
            -
                daily_model = torch.nn.DataParallel(daily_model).cpu()
         | 
| 32 | 
            -
                annual_model = torch.nn.DataParallel(annual_model).cpu()
         | 
| 33 | 
            -
             | 
| 34 | 
            -
            print('trying to resume previous saved models...')
         | 
| 35 | 
            -
            state = resume(
         | 
| 36 | 
            -
                os.path.join(opt.resume_path, best_model_daily_file_name),
         | 
| 37 | 
            -
                model=daily_model, optimizer=None)
         | 
| 38 | 
            -
            state = resume(
         | 
| 39 | 
            -
                os.path.join(opt.resume_path, best_model_annual_file_name),
         | 
| 40 | 
            -
                model=annual_model, optimizer=None)
         | 
| 41 | 
            -
            daily_model = daily_model.eval()
         | 
| 42 | 
            -
            annual_model = annual_model.eval()
         | 
| 43 |  | 
| 44 | 
             
            st.title('Lombardia Sentinel 2 daily Crop Mapping')
         | 
| 45 | 
             
            st.markdown('Using a daily FPN and giving a zip that contains 30 tiff with 7 channels, correctly named you can reach prediction of crop mapping og the area.')
         | 
| @@ -85,14 +88,14 @@ if sample_path is not None: | |
| 85 | 
             
                        if torch.cuda.is_available():
         | 
| 86 | 
             
                            x_dailies = x_dailies.cuda()
         | 
| 87 |  | 
| 88 | 
            -
                        feat_daily, outs_daily = daily_model.forward(x_dailies)
         | 
| 89 | 
             
                        # return to original size of batch and year
         | 
| 90 | 
             
                        outs_daily = outs_daily.view(
         | 
| 91 | 
             
                            opt.batch_size, opt.sample_duration, *outs_daily.shape[1:])
         | 
| 92 | 
             
                        feat_daily = feat_daily.view(
         | 
| 93 | 
             
                            opt.batch_size, opt.sample_duration, *feat_daily.shape[1:])
         | 
| 94 |  | 
| 95 | 
            -
                        _, out_annual = annual_model.forward(feat_daily)
         | 
| 96 | 
             
                        pred_annual = torch.argmax(out_annual, dim=1).squeeze(1)
         | 
| 97 | 
             
                        pred_annual = pred_annual.cpu().numpy()
         | 
| 98 | 
             
                        # Remapping the labels
         | 
| @@ -158,7 +161,7 @@ if st.paths is not None: | |
| 158 | 
             
                                           st.paths, index=st.paths.index('patch-pred-nn.tif'))
         | 
| 159 |  | 
| 160 | 
             
                file_path = os.path.join(folder, file_picker)
         | 
| 161 | 
            -
                print(file_path)
         | 
| 162 | 
             
                target, profile = read(file_path)
         | 
| 163 | 
             
                target = np.squeeze(target)
         | 
| 164 | 
             
                target = [classes_color_map[p] for p in target]
         | 
| @@ -169,7 +172,7 @@ if st.paths is not None: | |
| 169 |  | 
| 170 | 
             
                markdown_legend = ''
         | 
| 171 | 
             
                for c, l in zip(classes_color_map, labels_map):
         | 
| 172 | 
            -
                    print(colors.to_hex(c))
         | 
| 173 | 
             
                    markdown_legend += f'<div style="color:gray;background-color: {colors.to_hex(c)};">{l}</div><br>'
         | 
| 174 |  | 
| 175 | 
             
                col1, col2 = st.columns(2)
         | 
|  | |
| 8 |  | 
| 9 | 
             
            if not hasattr(st, 'paths'):
         | 
| 10 | 
             
                st.paths = None
         | 
| 11 | 
            +
            if not hasattr(st, 'daily_model'):
         | 
| 12 | 
            +
                best_model_daily_file_name = "best_model_daily.pth"
         | 
| 13 | 
            +
                best_model_annual_file_name = "best_model_annual.pth"
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                first_input_batch = torch.zeros(71, 9, 5, 48, 48)
         | 
| 16 | 
            +
                # first_input_batch = first_input_batch.view(-1, *first_input_batch.shape[2:])
         | 
| 17 | 
            +
                st.daily_model = FPN(opt, first_input_batch, opt.win_size)
         | 
| 18 | 
            +
                st.annual_model = SimpleNN(opt)
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                if torch.cuda.is_available():
         | 
| 21 | 
            +
                    st.daily_model = torch.nn.DataParallel(st.daily_model).cuda()
         | 
| 22 | 
            +
                    st.annual_model = torch.nn.DataParallel(st.annual_model).cuda()
         | 
| 23 | 
            +
                    st.daily_model = torch.nn.DataParallel(st.daily_model).cuda()
         | 
| 24 | 
            +
                    st.annual_model = torch.nn.DataParallel(st.annual_model).cuda()
         | 
| 25 | 
            +
                else:
         | 
| 26 | 
            +
                    st.daily_model = torch.nn.DataParallel(st.daily_model).cpu()
         | 
| 27 | 
            +
                    st.annual_model = torch.nn.DataParallel(st.annual_model).cpu()
         | 
| 28 | 
            +
                    st.daily_model = torch.nn.DataParallel(st.daily_model).cpu()
         | 
| 29 | 
            +
                    st.annual_model = torch.nn.DataParallel(st.annual_model).cpu()
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                print('trying to resume previous saved models...')
         | 
| 32 | 
            +
                state = resume(
         | 
| 33 | 
            +
                    os.path.join(opt.resume_path, best_model_daily_file_name),
         | 
| 34 | 
            +
                    model=st.daily_model, optimizer=None)
         | 
| 35 | 
            +
                state = resume(
         | 
| 36 | 
            +
                    os.path.join(opt.resume_path, best_model_annual_file_name),
         | 
| 37 | 
            +
                    model=st.annual_model, optimizer=None)
         | 
| 38 | 
            +
                st.daily_model = st.daily_model.eval()
         | 
| 39 | 
            +
                st.annual_model = st.annual_model.eval()
         | 
| 40 | 
            +
             | 
| 41 |  | 
| 42 | 
             
            # Load Model
         | 
| 43 | 
             
            # @title Load pretrained weights
         | 
| 44 |  | 
| 45 | 
            +
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 46 |  | 
| 47 | 
             
            st.title('Lombardia Sentinel 2 daily Crop Mapping')
         | 
| 48 | 
             
            st.markdown('Using a daily FPN and giving a zip that contains 30 tiff with 7 channels, correctly named you can reach prediction of crop mapping og the area.')
         | 
|  | |
| 88 | 
             
                        if torch.cuda.is_available():
         | 
| 89 | 
             
                            x_dailies = x_dailies.cuda()
         | 
| 90 |  | 
| 91 | 
            +
                        feat_daily, outs_daily = st.daily_model.forward(x_dailies)
         | 
| 92 | 
             
                        # return to original size of batch and year
         | 
| 93 | 
             
                        outs_daily = outs_daily.view(
         | 
| 94 | 
             
                            opt.batch_size, opt.sample_duration, *outs_daily.shape[1:])
         | 
| 95 | 
             
                        feat_daily = feat_daily.view(
         | 
| 96 | 
             
                            opt.batch_size, opt.sample_duration, *feat_daily.shape[1:])
         | 
| 97 |  | 
| 98 | 
            +
                        _, out_annual = st.annual_model.forward(feat_daily)
         | 
| 99 | 
             
                        pred_annual = torch.argmax(out_annual, dim=1).squeeze(1)
         | 
| 100 | 
             
                        pred_annual = pred_annual.cpu().numpy()
         | 
| 101 | 
             
                        # Remapping the labels
         | 
|  | |
| 161 | 
             
                                           st.paths, index=st.paths.index('patch-pred-nn.tif'))
         | 
| 162 |  | 
| 163 | 
             
                file_path = os.path.join(folder, file_picker)
         | 
| 164 | 
            +
                # print(file_path)
         | 
| 165 | 
             
                target, profile = read(file_path)
         | 
| 166 | 
             
                target = np.squeeze(target)
         | 
| 167 | 
             
                target = [classes_color_map[p] for p in target]
         | 
|  | |
| 172 |  | 
| 173 | 
             
                markdown_legend = ''
         | 
| 174 | 
             
                for c, l in zip(classes_color_map, labels_map):
         | 
| 175 | 
            +
                    # print(colors.to_hex(c))
         | 
| 176 | 
             
                    markdown_legend += f'<div style="color:gray;background-color: {colors.to_hex(c)};">{l}</div><br>'
         | 
| 177 |  | 
| 178 | 
             
                col1, col2 = st.columns(2)
         | 
