Spaces:
Runtime error
Runtime error
| """ | |
| Visualisation utils. | |
| """ | |
| import chess | |
| import chess.svg | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
# "RdYlBu_r" colormap resampled to 1000 lookup entries; render_heatmap uses it
# both for the per-square fill colours and for the colorbar.
COLOR_MAP = matplotlib.colormaps["RdYlBu_r"].resampled(1000)
# Alpha channel substituted into every square fill colour in render_heatmap.
ALPHA = 1.0
def render_heatmap(
    board,
    heatmap,
    square=None,
    vmin=None,
    vmax=None,
    arrows=None,
    normalise="none",
):
    """
    Render a heatmap on the board.

    Args:
        board: Board object forwarded unchanged to ``chess.svg.board``.
        heatmap: Per-square values, indexed by square index ``0..63``. Must
            support ``abs()``, ``.min()``, ``.max()`` and division by a
            scalar (presumably a 1-D torch tensor or NumPy array -- confirm
            against callers).
        square: Optional square name (e.g. ``"e4"``). When it parses, the
            resulting square index is passed as ``check`` to
            ``chess.svg.board``; names that fail to parse are ignored.
        vmin: Lower bound of the colour scale. Defaults to ``heatmap.min()``.
            Overridden with ``-1`` when ``normalise == "abs"``.
        vmax: Upper bound of the colour scale. Defaults to ``heatmap.max()``.
            Overridden with ``1`` when ``normalise == "abs"``.
        arrows: Optional arrows forwarded to ``chess.svg.board``; ``None`` is
            treated as an empty list.
        normalise: ``"abs"`` divides the heatmap by its largest absolute
            value (skipped when that value is zero) and fixes the colour
            range to ``[-1, 1]``. Any other value leaves the heatmap as is.

    Returns:
        A ``(svg, fig)`` tuple: the value returned by ``chess.svg.board`` and
        the matplotlib figure holding a horizontal colorbar for the same
        normalisation. The figure is closed in pyplot before being returned,
        so pyplot no longer tracks it.
    """
    if normalise == "abs":
        # Builtin abs() works for any object implementing __abs__, so both
        # tensors and NumPy arrays are accepted here.
        a_max = abs(heatmap).max()
        if a_max != 0:
            heatmap = heatmap / a_max
        vmin = -1
        vmax = 1
    if vmin is None:
        vmin = heatmap.min()
    if vmax is None:
        vmax = heatmap.max()
    norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=False)

    # Map each of the 64 squares to an "#rrggbbaa" fill colour, replacing the
    # colormap's own alpha with the module-level ALPHA.
    color_dict = {
        square_index: matplotlib.colors.to_hex(
            (*COLOR_MAP(norm(heatmap[square_index]))[:3], ALPHA),
            keep_alpha=True,
        )
        for square_index in range(64)
    }

    # Build the colorbar on an explicit figure handle rather than relying on
    # pyplot's implicit "current figure".
    fig = plt.figure(figsize=(6, 0.6))
    ax = fig.gca()
    ax.axis("off")
    fig.colorbar(
        matplotlib.cm.ScalarMappable(norm=norm, cmap=COLOR_MAP),
        ax=ax,
        orientation="horizontal",
        fraction=1.0,
    )

    check = None
    if square is not None:
        try:
            check = chess.parse_square(square)
        except ValueError:
            # Unparsable square name: render without a highlighted square.
            check = None

    if arrows is None:
        arrows = []

    # Close exactly the figure created above so pyplot stops tracking it; the
    # Figure object itself is still returned to the caller.
    plt.close(fig)
    return (
        chess.svg.board(
            board,
            check=check,
            fill=color_dict,
            size=350,
            arrows=arrows,
        ),
        fig,
    )