fffiloni commited on
Commit
7e3803b
·
1 Parent(s): 30681b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -35
app.py CHANGED
@@ -59,10 +59,10 @@ import tempfile
59
  from pathlib import Path
60
  from urllib.request import urlretrieve
61
 
62
-
63
- video_url = "https://download.pytorch.org/tutorial/pexelscom_pavel_danilyuk_basketball_hd.mp4"
64
- video_path = Path(tempfile.mkdtemp()) / "basketball.mp4"
65
- _ = urlretrieve(video_url, video_path)
66
 
67
  #########################
68
  # :func:`~torchvision.io.read_video` returns the video frames, audio frames and
@@ -73,11 +73,11 @@ _ = urlretrieve(video_url, video_path)
73
  # namely frames (100, 101) and (150, 151). Each of these pairs corresponds to a
74
  # single model input.
75
 
76
- from torchvision.io import read_video
77
- frames, _, _ = read_video(str(video_path), output_format="TCHW")
78
 
79
- img1= [frames[100]
80
- img2 = [frames[101]
81
 
82
  #########################
83
  # The RAFT model accepts RGB images. We first get the frames from
@@ -86,21 +86,21 @@ img2 = [frames[101]
86
  # weights in order to preprocess the input and rescale its values to the
87
  # required ``[-1, 1]`` interval.
88
 
89
- from torchvision.models.optical_flow import Raft_Large_Weights
90
 
91
- weights = Raft_Large_Weights.DEFAULT
92
- transforms = weights.transforms()
93
 
94
 
95
- def preprocess(img, img2):
96
- img1 = F.resize(img1, size=[520, 960])
97
- img2 = F.resize(img2, size=[520, 960])
98
- return transforms(img1, img2)
99
 
100
 
101
- img1, img2 = preprocess(img1, img2)
102
 
103
- print(f"shape = {img1.shape}, dtype = {img1.dtype}")
104
 
105
 
106
  ####################################
@@ -112,17 +112,17 @@ print(f"shape = {img1.shape}, dtype = {img1.dtype}")
112
  # We also provide the :func:`~torchvision.models.optical_flow.raft_small` model
113
  # builder, which is smaller and faster to run, sacrificing a bit of accuracy.
114
 
115
- from torchvision.models.optical_flow import raft_large
116
 
117
  # If you can, run this example on a GPU, it will be a lot faster.
118
- device = "cuda" if torch.cuda.is_available() else "cpu"
119
 
120
- model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)
121
- model = model.eval()
122
 
123
- list_of_flows = model(img1.to(device), img2.to(device))
124
- print(f"type = {type(list_of_flows)}")
125
- print(f"length = {len(list_of_flows)} = number of iterations of the model")
126
 
127
  ####################################
128
  # The RAFT model outputs lists of predicted flows where each entry is a
@@ -137,10 +137,10 @@ print(f"length = {len(list_of_flows)} = number of iterations of the model")
137
  # vertical displacement of each pixel from the first image to the second image.
138
  # Note that the predicted flows are in "pixel" unit, they are not normalized
139
  # w.r.t. the dimensions of the images.
140
- predicted_flows = list_of_flows[-1]
141
- print(f"dtype = {predicted_flows.dtype}")
142
- print(f"shape = {predicted_flows.shape} = (N, 2, H, W)")
143
- print(f"min = {predicted_flows.min()}, max = {predicted_flows.max()}")
144
 
145
 
146
  ####################################
@@ -155,15 +155,13 @@ print(f"min = {predicted_flows.min()}, max = {predicted_flows.max()}")
155
  # of the ball in the first image (going to the left) and in the second image
156
  # (going up).
157
 
158
- from torchvision.utils import flow_to_image
159
 
160
- flow_imgs = flow_to_image(predicted_flows)
161
 
162
- # The images have been mapped into [-1, 1] but for plotting we want them in [0, 1]
163
- img1 = [(img1 + 1) / 2 for img1 in img1]
164
 
165
- grid = [[img1, flow_img] for (img1, flow_img) in zip(img1, flow_imgs)]
166
- plot(grid)
167
 
168
  ####################################
169
  # Bonus: Creating GIFs of predicted flows
@@ -208,4 +206,6 @@ def write_flo(filename, flow):
208
  w.tofile(f)
209
  h.tofile(f)
210
  flow.tofile(f)
211
- f.close()
 
 
 
59
  from pathlib import Path
60
  from urllib.request import urlretrieve
61
 
62
+ def infer():
63
+ video_url = "https://download.pytorch.org/tutorial/pexelscom_pavel_danilyuk_basketball_hd.mp4"
64
+ video_path = Path(tempfile.mkdtemp()) / "basketball.mp4"
65
+ _ = urlretrieve(video_url, video_path)
66
 
67
  #########################
68
  # :func:`~torchvision.io.read_video` returns the video frames, audio frames and
 
73
  # namely frames (100, 101) and (150, 151). Each of these pairs corresponds to a
74
  # single model input.
75
 
76
+ from torchvision.io import read_video
77
+ frames, _, _ = read_video(str(video_path), output_format="TCHW")
78
 
79
+ img1 = frames[100]
80
+ img2 = frames[101]
81
 
82
  #########################
83
  # The RAFT model accepts RGB images. We first get the frames from
 
86
  # weights in order to preprocess the input and rescale its values to the
87
  # required ``[-1, 1]`` interval.
88
 
89
+ from torchvision.models.optical_flow import Raft_Large_Weights
90
 
91
+ weights = Raft_Large_Weights.DEFAULT
92
+ transforms = weights.transforms()
93
 
94
 
95
+ def preprocess(img1, img2):
96
+ img1 = F.resize(img1, size=[520, 960])
97
+ img2 = F.resize(img2, size=[520, 960])
98
+ return transforms(img1, img2)
99
 
100
 
101
+ img1, img2 = preprocess(img1, img2)
102
 
103
+ print(f"shape = {img1.shape}, dtype = {img1.dtype}")
104
 
105
 
106
  ####################################
 
112
  # We also provide the :func:`~torchvision.models.optical_flow.raft_small` model
113
  # builder, which is smaller and faster to run, sacrificing a bit of accuracy.
114
 
115
+ from torchvision.models.optical_flow import raft_large
116
 
117
  # If you can, run this example on a GPU, it will be a lot faster.
118
+ device = "cuda" if torch.cuda.is_available() else "cpu"
119
 
120
+ model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)
121
+ model = model.eval()
122
 
123
+ list_of_flows = model(img1.to(device), img2.to(device))
124
+ print(f"type = {type(list_of_flows)}")
125
+ print(f"length = {len(list_of_flows)} = number of iterations of the model")
126
 
127
  ####################################
128
  # The RAFT model outputs lists of predicted flows where each entry is a
 
137
  # vertical displacement of each pixel from the first image to the second image.
138
  # Note that the predicted flows are in "pixel" unit, they are not normalized
139
  # w.r.t. the dimensions of the images.
140
+ predicted_flows = list_of_flows[-1]
141
+ print(f"dtype = {predicted_flows.dtype}")
142
+ print(f"shape = {predicted_flows.shape} = (N, 2, H, W)")
143
+ print(f"min = {predicted_flows.min()}, max = {predicted_flows.max()}")
144
 
145
 
146
  ####################################
 
155
  # of the ball in the first image (going to the left) and in the second image
156
  # (going up).
157
 
158
+ from torchvision.utils import flow_to_image
159
 
160
+ flow_imgs = flow_to_image(predicted_flows)
161
 
162
+ print(flow_imgs)
 
163
 
164
+ return "done"
 
165
 
166
  ####################################
167
  # Bonus: Creating GIFs of predicted flows
 
206
  w.tofile(f)
207
  h.tofile(f)
208
  flow.tofile(f)
209
+ f.close()
210
+
211
+ gr.Interface(fn=infer, inputs=[], outputs=gr.Textbox()).launch()