bgaspra commited on
Commit
597527d
·
verified ·
1 Parent(s): 81c886c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -7
app.py CHANGED
@@ -12,10 +12,6 @@ from PIL import UnidentifiedImageError, Image
12
  import gradio as gr
13
  import matplotlib.pyplot as plt
14
 
15
- # Ensure TensorFlow uses GPU
16
- print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
17
- assert len(tf.config.list_physical_devices('GPU')) > 0, "No GPU available!"
18
-
19
  # Load the dataset
20
  dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
21
 
@@ -27,8 +23,22 @@ dataset_subset = dataset['train'].shuffle(seed=42).select(range(subset_size))
27
  image_dir = 'civitai_images'
28
  os.makedirs(image_dir, exist_ok=True)
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  # Load the ResNet50 model pretrained on ImageNet
31
- with tf.device('/GPU:0'):
32
  model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
33
 
34
  # Function to extract features
@@ -37,7 +47,7 @@ def extract_features(img_path, model):
37
  img_array = image.img_to_array(img)
38
  img_array = np.expand_dims(img_array, axis=0)
39
  img_array = preprocess_input(img_array)
40
- with tf.device('/GPU:0'):
41
  features = model.predict(img_array)
42
  return features.flatten()
43
 
@@ -136,4 +146,4 @@ interface = gr.Interface(
136
  description="Upload an image and get 5 recommended similar images with model names and distances."
137
  )
138
 
139
- interface.launch()
 
12
  import gradio as gr
13
  import matplotlib.pyplot as plt
14
 
 
 
 
 
15
  # Load the dataset
16
  dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
17
 
 
23
  image_dir = 'civitai_images'
24
  os.makedirs(image_dir, exist_ok=True)
25
 
26
+ # Try to use GPU, fall back to CPU if not available
27
+ try:
28
+ gpus = tf.config.list_physical_devices('GPU')
29
+ if gpus:
30
+ tf.config.experimental.set_memory_growth(gpus[0], True)
31
+ device = '/GPU:0'
32
+ print("Using GPU")
33
+ else:
34
+ raise RuntimeError("No GPU found")
35
+ except RuntimeError as e:
36
+ print(e)
37
+ device = '/CPU:0'
38
+ print("Using CPU")
39
+
40
  # Load the ResNet50 model pretrained on ImageNet
41
+ with tf.device(device):
42
  model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
43
 
44
  # Function to extract features
 
47
  img_array = image.img_to_array(img)
48
  img_array = np.expand_dims(img_array, axis=0)
49
  img_array = preprocess_input(img_array)
50
+ with tf.device(device):
51
  features = model.predict(img_array)
52
  return features.flatten()
53
 
 
146
  description="Upload an image and get 5 recommended similar images with model names and distances."
147
  )
148
 
149
+ interface.launch()