akhaliq HF staff commited on
Commit
0f48bb9
·
1 Parent(s): a351d24

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from PIL import Image
4
+ import numpy as np
5
+ import tensorflow as tf
6
+ import tensorflow_hub as hub
7
+ import matplotlib.pyplot as plt
8
+ os.environ["TFHUB_DOWNLOAD_PROGRESS"] = "True"
9
+
10
+ os.system("wget https://user-images.githubusercontent.com/12981474/40157448-eff91f06-5953-11e8-9a37-f6b5693fa03f.png -O original.png")
11
+
12
+ # Declaring Constants
13
+ IMAGE_PATH = "original.png"
14
+ SAVED_MODEL_PATH = "https://tfhub.dev/captain-pool/esrgan-tf2/1"
15
+
16
+ def preprocess_image(image_path):
17
+ """ Loads image from path and preprocesses to make it model ready
18
+ Args:
19
+ image_path: Path to the image file
20
+ """
21
+ hr_image = tf.image.decode_image(tf.io.read_file(image_path))
22
+ # If PNG, remove the alpha channel. The model only supports
23
+ # images with 3 color channels.
24
+ if hr_image.shape[-1] == 4:
25
+ hr_image = hr_image[...,:-1]
26
+ hr_size = (tf.convert_to_tensor(hr_image.shape[:-1]) // 4) * 4
27
+ hr_image = tf.image.crop_to_bounding_box(hr_image, 0, 0, hr_size[0], hr_size[1])
28
+ hr_image = tf.cast(hr_image, tf.float32)
29
+ return tf.expand_dims(hr_image, 0)
30
+
31
+
32
+ def plot_image(image):
33
+ """
34
+ Plots images from image tensors.
35
+ Args:
36
+ image: 3D image tensor. [height, width, channels].
37
+ title: Title to display in the plot.
38
+ """
39
+ image = np.asarray(image)
40
+ image = tf.clip_by_value(image, 0, 255)
41
+ image = Image.fromarray(tf.cast(image, tf.uint8).numpy())
42
+ return image
43
+
44
+ model = hub.load(SAVED_MODEL_PATH)
45
+ def inference(img)
46
+ hr_image = preprocess_image(img)
47
+ start = time.time()
48
+ fake_image = model(hr_image)
49
+ fake_image = tf.squeeze(fake_image)
50
+ print("Time Taken: %f" % (time.time() - start))
51
+ pil_image = plot_image(tf.squeeze(fake_image))
52
+ return pil_image
53
+
54
+ gr.Interface(inference,gr.inputs.Image(type="filepath"),"image").launch()
55
+