siddharth060104 committed (verified)
Commit 5ac9d87
Parent: f9c0e64

Update app.py

Files changed (1):
  1. app.py +265 -265
app.py CHANGED
@@ -1,265 +1,265 @@
from ultralytics import YOLO
import cv2
from stockfish import Stockfish
import os
import numpy as np
import streamlit as st

# Constants
FEN_MAPPING = {
    "black-pawn": "p", "black-rook": "r", "black-knight": "n", "black-bishop": "b", "black-queen": "q", "black-king": "k",
    "white-pawn": "P", "white-rook": "R", "white-knight": "N", "white-bishop": "B", "white-queen": "Q", "white-king": "K"
}
GRID_BORDER = 10             # Border size in pixels
GRID_SIZE = 204              # Effective grid size (10px to 214px)
BLOCK_SIZE = GRID_SIZE // 8  # Each block is ~25px
X_LABELS = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']  # Labels for the x-axis (a to h)
Y_LABELS = [8, 7, 6, 5, 4, 3, 2, 1]                  # Reversed labels for the y-axis (8 to 1)

# Functions
def get_grid_coordinate(pixel_x, pixel_y):
    """
    Determine the grid coordinate of a pixel, taking the 10px border into account,
    for a grid where the bottom-left square is (a, 1) and the top-left is (h, 8).
    """
    # Shift pixel_x and pixel_y by the border (the grid starts at pixel 10)
    adjusted_x = pixel_x - GRID_BORDER
    adjusted_y = pixel_y - GRID_BORDER

    # Check bounds
    if adjusted_x < 0 or adjusted_y < 0 or adjusted_x >= GRID_SIZE or adjusted_y >= GRID_SIZE:
        return "Pixel outside grid bounds"

    # Determine the column index (0-7) and row index (0-7)
    x_index = adjusted_x // BLOCK_SIZE
    y_index = adjusted_y // BLOCK_SIZE

    if x_index >= len(X_LABELS) or y_index >= len(Y_LABELS):
        return "Pixel outside grid bounds"

    # Convert indices to grid coordinates: row index 0 maps to rank 1, row index 7 to rank 8
    x_label = X_LABELS[x_index]
    y_label = y_index + 1

    return f"{x_label}{y_label}"

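# Worked example (illustrative, not part of the original file):
#   get_grid_coordinate(36, 190) -> adjusted (26, 180) -> x_index 1, y_index 7 -> "b8"
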
def predict_next_move(fen, stockfish):
    """
    Predict the next move using Stockfish.
    """
    if stockfish.is_fen_valid(fen):
        stockfish.set_fen_position(fen)
    else:
        return "Invalid FEN notation!"

    best_move = stockfish.get_best_move()
    if not best_move:
        return "No valid move found (checkmate/stalemate)."

    ans = transform_string(best_move)
    return f"The predicted next move is: {ans}"

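# Note (added for clarity): Stockfish's get_best_move() returns UCI strings such as "e2e4".
# transform_string() below mirrors each square across the board (a<->h, 1<->8),
# e.g. "e2e4" becomes "d7d5".
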
def process_image(image_path):
    # Ensure the output directory exists
    if not os.path.exists('output'):
        os.makedirs('output')

    # Load the segmentation model
    segmentation_model = YOLO("segmentation.pt")

    # Run inference to get segmentation results
    results = segmentation_model.predict(
        source=image_path,
        conf=0.8  # Confidence threshold
    )

    # Keep the first mask and bounding box whose confidence is at least 0.8
    segmentation_mask = None
    bbox = None

    for result in results:
        if len(result.boxes) > 0 and result.boxes.conf[0] >= 0.8:
            segmentation_mask = result.masks.data.cpu().numpy().astype(np.uint8)[0]
            bbox = result.boxes.xyxy[0].cpu().numpy()  # Bounding box coordinates (x1, y1, x2, y2)
            break

    if segmentation_mask is None or bbox is None:
        print("No segmentation mask with confidence above 0.8 found.")
        return None

    # Load the image
    image = cv2.imread(image_path)

    # Resize the segmentation mask to match the input image dimensions (kept for reference; not used below)
    segmentation_mask_resized = cv2.resize(segmentation_mask, (image.shape[1], image.shape[0]))

    # Crop the segmented region based on the bounding box
    x1, y1, x2, y2 = bbox
    cropped_segment = image[int(y1):int(y2), int(x1):int(x2)]

    # Save the cropped segmented image
    cropped_image_path = 'output/cropped_segment.jpg'
    cv2.imwrite(cropped_image_path, cropped_segment)
    print(f"Cropped segmented image saved to {cropped_image_path}")

    st.image(cropped_segment, caption="Detected chessboard", use_column_width=True)

    # Return the cropped image
    return cropped_segment

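# Note (added for clarity): process_image() returns the cropped board as a BGR NumPy array
# (cv2.imread ordering); main() resizes it to 224x224 before running piece detection.
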
def transform_string(input_str):
    # Remove extra spaces and convert to lowercase
    input_str = input_str.strip().lower()

    # Check that the input is a four-character move (file, rank, file, rank)
    if len(input_str) != 4 or not input_str[0].isalpha() or not input_str[1].isdigit() or \
       not input_str[2].isalpha() or not input_str[3].isdigit():
        return "Invalid input"

    # Mappings that mirror files (a<->h) and ranks (1<->8)
    letter_mapping = {
        'a': 'h', 'b': 'g', 'c': 'f', 'd': 'e',
        'e': 'd', 'f': 'c', 'g': 'b', 'h': 'a'
    }
    number_mapping = {
        '1': '8', '2': '7', '3': '6', '4': '5',
        '5': '4', '6': '3', '7': '2', '8': '1'
    }

    # Transform the string character by character
    result = ""
    for i, char in enumerate(input_str):
        if i % 2 == 0:  # Letters (files)
            result += letter_mapping.get(char, "Invalid")
        else:           # Numbers (ranks)
            result += number_mapping.get(char, "Invalid")

    # Check for invalid transformations
    if "Invalid" in result:
        return "Invalid input"

    return result


# Streamlit app
def main():
    st.title("Chessboard Position Detection and Move Prediction")

    # The user uploads an image or captures it from their camera
    image_file = st.camera_input("Capture a chessboard image") or st.file_uploader("Upload a chessboard image", type=["jpg", "jpeg", "png"])

    if image_file is not None:
        # Save the image to a temporary file
        temp_dir = "temp_images"
        os.makedirs(temp_dir, exist_ok=True)
        temp_file_path = os.path.join(temp_dir, "uploaded_image.jpg")
        with open(temp_file_path, "wb") as f:
            f.write(image_file.getbuffer())

        # Process the image using its file path
        processed_image = process_image(temp_file_path)

        if processed_image is not None:
            # Resize the image to 224x224
            processed_image = cv2.resize(processed_image, (224, 224))
            height, width, _ = processed_image.shape

            # Initialize the YOLO piece-detection model
            model = YOLO("standard.pt")  # Replace with your trained model weights file

            # Run detection
            results = model.predict(source=processed_image, save=False, save_txt=False, conf=0.6)

            # Initialize the board for FEN (each empty square represented by "8")
            board = [["8"] * 8 for _ in range(8)]

            # Extract predictions and map them onto the FEN board
            for result in results[0].boxes:
                x1, y1, x2, y2 = result.xyxy[0].tolist()
                class_id = int(result.cls[0])
                class_name = model.names[class_id]

                # Convert class_name to FEN notation
                fen_piece = FEN_MAPPING.get(class_name)
                if not fen_piece:
                    continue

                # Calculate the center of the bounding box
                center_x = (x1 + x2) / 2
                center_y = (y1 + y2) / 2

                # Convert to integer pixel coordinates
                pixel_x = int(center_x)
                pixel_y = int(height - center_y)  # Flip the Y-axis for the grid coordinate system

                # Get the grid coordinate
                grid_position = get_grid_coordinate(pixel_x, pixel_y)

                if grid_position != "Pixel outside grid bounds":
                    file = ord(grid_position[0]) - ord('a')  # Column index (0-7)
                    rank = int(grid_position[1]) - 1         # Row index (0-7)

                    # Place the piece on the board (flip the rank index, since FEN lists rank 8 first)
                    board[7 - rank][file] = fen_piece

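            # At this point board[0] corresponds to rank 8 and board[7] to rank 1,
            # matching the top-to-bottom order that FEN expects.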
            # Generate the FEN string
            fen_rows = []
            for row in board:
                fen_row = ""
                empty_count = 0
                for cell in row:
                    if cell == "8":
                        empty_count += 1
                    else:
                        if empty_count > 0:
                            fen_row += str(empty_count)
                            empty_count = 0
                        fen_row += cell
                if empty_count > 0:
                    fen_row += str(empty_count)
                fen_rows.append(fen_row)

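            # Example (illustrative): a row ["8", "8", "8", "P", "8", "8", "r", "8"]
            # is compressed to "3P2r1" by the loop above.
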
            position_fen = "/".join(fen_rows)

            # Ask the user which side moves next
            move_side = st.selectbox("Select the side to move:", ["w (White)", "b (Black)"])
            move_side = "w" if move_side.startswith("w") else "b"

            # Complete the FEN string (no castling rights or en passant square; clocks start at 0 and 1)
            fen_notation = f"{position_fen} {move_side} - - 0 1"

            st.subheader("Generated FEN Notation:")
            st.code(fen_notation)

            # Initialize the Stockfish engine
            stockfish = Stockfish(
-                path=r"stockfish-windows-x86-64-avx2.exe",  # Replace with your Stockfish path
+                path="stockfish-windows-x86-64-avx2.exe",  # Replace with your Stockfish path
                depth=15,
                parameters={"Threads": 2, "Minimum Thinking Time": 30}
            )

            # Predict the next move
            next_move = predict_next_move(fen_notation, stockfish)
            st.subheader("Stockfish Recommended Move:")
            st.write(next_move)

        else:
            st.error("Failed to process the image. Please try again.")


if __name__ == "__main__":
    main()
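
The substantive change in this commit is the hard-coded Stockfish path (the raw-string prefix is dropped from the Windows executable name). As an illustrative sketch only, not part of the committed file, the engine path could instead be resolved by checking an environment variable and the system PATH before falling back to the bundled executable; the helper name and the STOCKFISH_PATH variable are assumptions:

import os
import shutil
from stockfish import Stockfish

def make_stockfish(depth=15):
    # Hypothetical helper, not in app.py: pick the first engine path that is available.
    path = (
        os.environ.get("STOCKFISH_PATH")        # optional override via environment variable
        or shutil.which("stockfish")            # an engine installed on the system PATH
        or "stockfish-windows-x86-64-avx2.exe"  # fallback used by app.py
    )
    return Stockfish(path=path, depth=depth,
                     parameters={"Threads": 2, "Minimum Thinking Time": 30})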