kerzel commited on
Commit
5a1b5f2
·
1 Parent(s): bf773e4

full code again

Browse files
Files changed (2) hide show
  1. app.py +73 -23
  2. postBuild +12 -0
app.py CHANGED
@@ -1,46 +1,96 @@
1
  import gradio as gr
2
  import numpy as np
3
  import pandas as pd
4
- from PIL import Image, ImageDraw
5
 
6
- IMAGE_PATH = "output_image.png"
7
- CSV_PATH = "damage_list.csv"
 
 
 
 
8
 
9
- def dummy_damage_classification(img, threshold):
10
- if img is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  return None, None, None
12
 
13
- # Convert to numpy array
14
- image_np = np.array(img)
 
 
 
 
 
 
 
 
 
 
15
 
16
- # Draw a red square on the image
17
- image_pil = Image.fromarray(image_np).convert("RGB")
18
- draw = ImageDraw.Draw(image_pil)
19
- draw.rectangle([10, 10, 50, 50], outline="red", width=3)
20
- image_pil.save(IMAGE_PATH)
21
 
22
- # Create dummy CSV file
23
- df = pd.DataFrame({"x": [20], "y": [30], "damage_type": ["Dummy"]})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  df.to_csv(CSV_PATH, index=False)
25
 
26
- # Return PIL image for display, and file paths (strings) for download buttons
27
- return image_pil, IMAGE_PATH, CSV_PATH
28
 
29
  with gr.Blocks() as app:
30
- gr.Markdown("# Minimal Damage Classifier Demo")
31
 
32
  image_input = gr.Image(label="Upload SEM Image")
33
- threshold_input = gr.Number(value=20, label="Threshold")
 
 
34
 
35
  output_image = gr.Image(label="Classified Image")
36
  download_image_btn = gr.DownloadButton(label="Download Image")
37
  download_csv_btn = gr.DownloadButton(label="Download CSV")
38
 
39
- btn = gr.Button("Run Classification")
40
- btn.click(
41
- dummy_damage_classification,
42
- inputs=[image_input, threshold_input],
43
- outputs=[output_image, download_image_btn, download_csv_btn]
44
  )
45
 
46
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import numpy as np
3
  import pandas as pd
4
+ from PIL import Image
5
 
6
+ # Your helper imports and tensorflow models assumed to be loaded here:
7
+ import clustering
8
+ import utils
9
+ from tensorflow import keras
10
+ import logging
11
+ logging.getLogger().setLevel(logging.INFO)
12
 
13
+ # Paths to save outputs
14
+ IMAGE_PATH = "classified_damage_sites.png"
15
+ CSV_PATH = "classified_damage_sites.csv"
16
+
17
+ # Load models once (adjust filenames as needed)
18
+ model1 = keras.models.load_model('rwthmaterials_dp800_network1_inclusion.h5')
19
+ model2 = keras.models.load_model('rwthmaterials_dp800_network2_damage.h5')
20
+
21
+ damage_classes = {3: "Martensite", 2: "Interface", 0: "Notch", 1: "Shadowing"}
22
+ model1_windowsize = [250, 250]
23
+ model2_windowsize = [100, 100]
24
+
25
+ def damage_classification(SEM_image, image_threshold, model1_threshold, model2_threshold):
26
+ if SEM_image is None:
27
+ logging.error("No image provided")
28
  return None, None, None
29
 
30
+ damage_sites = {}
31
+
32
+ # Step 1: Clustering to find damage centroids
33
+ all_centroids = clustering.get_centroids(
34
+ SEM_image,
35
+ image_threshold=image_threshold,
36
+ fill_holes=True,
37
+ filter_close_centroids=True,
38
+ )
39
+
40
+ for c in all_centroids:
41
+ damage_sites[(c[0], c[1])] = "Not Classified"
42
 
43
+ # Step 2: Model 1 to identify inclusions
44
+ images_model1 = utils.prepare_classifier_input(SEM_image, all_centroids, window_size=model1_windowsize)
45
+ y1_pred = model1.predict(np.asarray(images_model1, dtype=float))
46
+ inclusions = np.where(y1_pred[:, 0] > model1_threshold)[0]
 
47
 
48
+ for idx in inclusions:
49
+ coord = all_centroids[idx]
50
+ damage_sites[(coord[0], coord[1])] = "Inclusion"
51
+
52
+ # Step 3: Model 2 to classify remaining damage types
53
+ centroids_model2 = [list(k) for k, v in damage_sites.items() if v == "Not Classified"]
54
+ if centroids_model2:
55
+ images_model2 = utils.prepare_classifier_input(SEM_image, centroids_model2, window_size=model2_windowsize)
56
+ y2_pred = model2.predict(np.asarray(images_model2, dtype=float))
57
+ damage_index = np.asarray(y2_pred > model2_threshold).nonzero()
58
+
59
+ for i in range(len(damage_index[0])):
60
+ sample_idx = damage_index[0][i]
61
+ class_idx = damage_index[1][i]
62
+ label = damage_classes.get(class_idx, "Unknown")
63
+ coord = centroids_model2[sample_idx]
64
+ damage_sites[(coord[0], coord[1])] = label
65
+
66
+ # Step 4: Draw boxes on image and save output image
67
+ image_with_boxes = utils.show_boxes(SEM_image, damage_sites, save_image=True, image_path=IMAGE_PATH)
68
+
69
+ # Step 5: Export CSV file
70
+ data = [[x, y, label] for (x, y), label in damage_sites.items()]
71
+ df = pd.DataFrame(data, columns=["x", "y", "damage_type"])
72
  df.to_csv(CSV_PATH, index=False)
73
 
74
+ return image_with_boxes, IMAGE_PATH, CSV_PATH
75
+
76
 
77
  with gr.Blocks() as app:
78
+ gr.Markdown("# Damage Classification in Dual Phase Steels")
79
 
80
  image_input = gr.Image(label="Upload SEM Image")
81
+ cluster_threshold_input = gr.Number(value=20, label="Cluster Threshold")
82
+ model1_threshold_input = gr.Number(value=0.7, label="Model 1 Threshold")
83
+ model2_threshold_input = gr.Number(value=0.5, label="Model 2 Threshold")
84
 
85
  output_image = gr.Image(label="Classified Image")
86
  download_image_btn = gr.DownloadButton(label="Download Image")
87
  download_csv_btn = gr.DownloadButton(label="Download CSV")
88
 
89
+ classify_btn = gr.Button("Run Classification")
90
+ classify_btn.click(
91
+ damage_classification,
92
+ inputs=[image_input, cluster_threshold_input, model1_threshold_input, model2_threshold_input],
93
+ outputs=[output_image, download_image_btn, download_csv_btn],
94
  )
95
 
96
  if __name__ == "__main__":
postBuild ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # postBuild script to run after pip install from requirements.txt
3
+
4
+ echo "Running postBuild script..."
5
+
6
+ # Upgrade or install a specific version of gradio (example: 4.44.1)
7
+ pip install --upgrade gradio==4.44.1
8
+
9
+ # (Optional) install or upgrade other packages if needed
10
+ # pip install --upgrade some-package==version
11
+
12
+ echo "postBuild complete."