etemkocaaslan commited on
Commit
cea72b9
·
verified ·
1 Parent(s): d8b8521

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -38
app.py CHANGED
@@ -6,40 +6,22 @@ from typing import Union
6
 
7
  class Preprocessor:
8
  def __init__(self):
9
- """
10
- Initialize the preprocessing transformations.
11
- """
12
  self.transform = transforms.Compose([
13
  transforms.ToTensor(),
14
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
15
  ])
16
 
17
  def __call__(self, image: Image.Image) -> torch.Tensor:
18
- """
19
- Apply preprocessing to the input image.
20
-
21
- :param image: Input image to be preprocessed.
22
- :return: Preprocessed image as a tensor.
23
- """
24
  return self.transform(image)
25
 
26
  class SegmentationModel:
27
  def __init__(self):
28
- """
29
- Initialize and load the DeepLabV3 ResNet101 model.
30
- """
31
  self.model = models.segmentation.deeplabv3_resnet101(pretrained=True)
32
  self.model.eval()
33
  if torch.cuda.is_available():
34
  self.model.to('cuda')
35
 
36
  def predict(self, input_batch: torch.Tensor) -> torch.Tensor:
37
- """
38
- Perform inference using the model on the input batch.
39
-
40
- :param input_batch: Batch of preprocessed images.
41
- :return: Model output tensor.
42
- """
43
  with torch.no_grad():
44
  if torch.cuda.is_available():
45
  input_batch = input_batch.to('cuda')
@@ -48,40 +30,22 @@ class SegmentationModel:
48
 
49
  class OutputColorizer:
50
  def __init__(self):
51
- """
52
- Initialize the color palette for segmentations.
53
- """
54
  palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
55
- colors : torch.Tensor = torch.as_tensor([i for i in range(21)])[:, None] * palette
56
  self.colors = (colors % 255).numpy().astype("uint8")
57
 
58
  def colorize(self, output: torch.Tensor) -> Image.Image:
59
- """
60
- Apply colorization to the segmentation output.
61
-
62
- :param output: Segmentation output tensor.
63
- :return: Colorized segmentation image.
64
- """
65
  colorized_output = Image.fromarray(output.byte().cpu().numpy(), mode='P')
66
  colorized_output.putpalette(self.colors.ravel())
67
  return colorized_output
68
 
69
  class Segmenter:
70
  def __init__(self):
71
- """
72
- Initialize the Segmenter with Preprocessor, SegmentationModel, and OutputColorizer.
73
- """
74
  self.preprocessor = Preprocessor()
75
  self.model = SegmentationModel()
76
  self.colorizer = OutputColorizer()
77
 
78
  def segment(self, image: Union[Image.Image, torch.Tensor]) -> Image.Image:
79
- """
80
- Perform the complete segmentation process on the input image.
81
-
82
- :param image: Input image to be segmented.
83
- :return: Colorized segmentation image.
84
- """
85
  input_image: Image.Image = image.convert("RGB")
86
  input_tensor: torch.Tensor = self.preprocessor(input_image)
87
  input_batch: torch.Tensor = input_tensor.unsqueeze(0)
@@ -99,4 +63,5 @@ interface = gr.Interface(
99
  description="Upload an image to perform semantic segmentation using Deeplabv3 ResNet101."
100
  )
101
 
102
- interface.launch()
 
 
6
 
7
  class Preprocessor:
8
  def __init__(self):
 
 
 
9
  self.transform = transforms.Compose([
10
  transforms.ToTensor(),
11
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
12
  ])
13
 
14
  def __call__(self, image: Image.Image) -> torch.Tensor:
 
 
 
 
 
 
15
  return self.transform(image)
16
 
17
  class SegmentationModel:
18
  def __init__(self):
 
 
 
19
  self.model = models.segmentation.deeplabv3_resnet101(pretrained=True)
20
  self.model.eval()
21
  if torch.cuda.is_available():
22
  self.model.to('cuda')
23
 
24
  def predict(self, input_batch: torch.Tensor) -> torch.Tensor:
 
 
 
 
 
 
25
  with torch.no_grad():
26
  if torch.cuda.is_available():
27
  input_batch = input_batch.to('cuda')
 
30
 
31
  class OutputColorizer:
32
  def __init__(self):
 
 
 
33
  palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
34
+ colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
35
  self.colors = (colors % 255).numpy().astype("uint8")
36
 
37
  def colorize(self, output: torch.Tensor) -> Image.Image:
 
 
 
 
 
 
38
  colorized_output = Image.fromarray(output.byte().cpu().numpy(), mode='P')
39
  colorized_output.putpalette(self.colors.ravel())
40
  return colorized_output
41
 
42
  class Segmenter:
43
  def __init__(self):
 
 
 
44
  self.preprocessor = Preprocessor()
45
  self.model = SegmentationModel()
46
  self.colorizer = OutputColorizer()
47
 
48
  def segment(self, image: Union[Image.Image, torch.Tensor]) -> Image.Image:
 
 
 
 
 
 
49
  input_image: Image.Image = image.convert("RGB")
50
  input_tensor: torch.Tensor = self.preprocessor(input_image)
51
  input_batch: torch.Tensor = input_tensor.unsqueeze(0)
 
63
  description="Upload an image to perform semantic segmentation using Deeplabv3 ResNet101."
64
  )
65
 
66
+ if __name__ == "__main__":
67
+ interface.launch()