File size: 12,348 Bytes
0d917aa
40b16f2
 
0d917aa
40b16f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
import gradio as gr
from astropy.io import fits
import matplotlib.pyplot as plt
import numpy as np
import io
from PIL import Image
import astropy.units as u
from astropy.wcs import WCS
from astropy.coordinates import SkyCoord
from astropy import coordinates as coord
from astropy.wcs.utils import skycoord_to_pixel
from astroquery.simbad import Simbad
import pandas as pd
import matplotlib.patches as patches
# Setting None removes Pillow's pixel-count limit altogether (it does not just
# raise it), so large images are opened without the decompression-bomb check.
Image.MAX_IMAGE_PIXELS = None
plt.style.use('dark_background')

# Module-level state shared between the Gradio callbacks.
global_header = None  # header of the most recently loaded FITS file
global_data = None  # image array of the most recently loaded FITS file
global_dataframe = pd.DataFrame()  # object catalogue; empty until loaded/queried

def show_csv(file):
    """
    Load a catalogue CSV and offer its catalogue types for selection.

    The first CSV column is used as the index, and the table must contain a
    ``TYPE`` column. Only a valid table replaces ``global_dataframe``, so the
    plotting callback never receives a catalogue it cannot filter.

    Args:
        file: A filesystem path (``str`` / ``os.PathLike``) or an upload object
            exposing its path as ``.name``.

    Returns:
        ``(DataFrame, gr.CheckboxGroup)`` with every ``TYPE`` preselected, or
        ``("Error: ...", None)`` if the file cannot be read or lacks ``TYPE``.
    """
    import os

    global global_dataframe

    # Plain paths are used as-is; tempfile-like upload objects carry the path in .name.
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    try:
        df = pd.read_csv(path, index_col=0)
    except Exception as e:
        return f"Error: {str(e)}", None

    if "TYPE" not in df.columns:
        # Keep the current catalogue instead of storing one that cannot be filtered.
        return "Error: CSV does not contain a 'TYPE' column.", None

    global_dataframe = df  # Store the dataframe globally for filtering
    unique_types = df["TYPE"].unique().tolist()
    return df, gr.CheckboxGroup(label="Select Catalogue", choices=unique_types, value=unique_types, interactive=True)
    
# Callback of the "Query Simbad for Galaxies" button.
def query_update_table(max_results=200):
    """
    Query SIMBAD for galaxies that fall inside the currently loaded FITS image.

    The search circle is centred on the header reference point
    (``CRVAL1``/``CRVAL2``). Its radius is the larger ``|CDELT|`` times the
    larger ``NAXIS``, which deliberately over-covers the field. Every returned
    object is projected through the image WCS, and only objects landing inside
    the image are kept. On success the table replaces ``global_dataframe`` so
    the plotting callback can filter it.

    Args:
        max_results: Keep at most this many objects after sorting by pixel
            position (x, then y). Defaults to 200.

    Returns:
        On success, ``(DataFrame, gr.CheckboxGroup)``. The DataFrame columns
        are ``main_id`` first, then SIMBAD's columns, ``Pixel_Position``,
        ``px``, ``py`` and ``TYPE`` (the catalogue prefix of ``main_id``). The
        checkbox group offers every ``TYPE``, preselected.
        On failure, ``("Error: ...", None)``.
    """
    global global_dataframe, global_header, global_data

    if global_header is None or global_data is None:
        return "Error: upload a plate solved .fits file before querying Simbad.", None

    try:
        Simbad.TIMEOUT = 120

        # The loaded data has a third axis; drop it so the WCS maps 2-D pixels.
        wcs = WCS(global_header).dropaxis(2)
        target_coord = SkyCoord(ra=global_header['CRVAL1'], dec=global_header['CRVAL2'],
                                unit=(u.deg, u.deg), frame='icrs')
        radius_deg = (max(abs(global_header['CDELT1']), abs(global_header['CDELT2']))
                      * max(global_header['NAXIS1'], global_header['NAXIS2']))

        # ':+' always writes the declination sign, e.g. "10.68 +41.27".
        custom_query = f"region(CIRCLE, {target_coord.ra.deg} {target_coord.dec.deg:+}, {radius_deg}d)"
        print(f'Query={custom_query}')

        result_table = Simbad.query_criteria(custom_query, otype='galaxy')
        if result_table is None or len(result_table) == 0:
            return "Error: Simbad returned no galaxies for this field.", None

        df = result_table.to_pandas()
        df.insert(0, 'main_id', df.pop('main_id'))  # show the identifier first

        # Project all objects through the WCS in one vectorised call.
        sky = SkyCoord(df['ra'].to_numpy(), df['dec'].to_numpy(), unit=(u.deg, u.deg), frame='icrs')
        x, y = skycoord_to_pixel(sky, wcs)
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        # Non-finite positions cannot be converted to pixel indices; drop them.
        finite = np.isfinite(x) & np.isfinite(y)
        df = df[finite].copy()
        x, y = x[finite], y[finite]
        df['Pixel_Position'] = list(zip(x, y))
        df['px'] = x.astype(int)  # truncates toward zero, like int()
        df['py'] = y.astype(int)

        height, width = global_data.shape[:2]
        inside = (df.px > 0) & (df.px < width) & (df.py > 0) & (df.py < height)
        df = df[inside].reset_index(drop=True)

        # Catalogue prefix: text before the first space, then before any '+'.
        df['TYPE'] = df['main_id'].map(lambda name: name.split(' ')[0].split('+')[0])
        df = df.sort_values(by=['px', 'py'], ascending=[True, True]).reset_index(drop=True)
        df = df.iloc[:max_results]
    except Exception as e:
        return f"Error: {str(e)}", None

    global_dataframe = df  # Store the dataframe globally for filtering
    unique_types = df["TYPE"].unique().tolist()
    return df, gr.CheckboxGroup(label="Select Catalogue", choices=unique_types, value=unique_types, interactive=True)


def load_fits_image(file, type_checkboxes, title, axis_options, num_rows, patch_size, fontsize, alpha, linewidth, scale, patch_color, sort_method):
    """
    Load an uploaded FITS file into the module globals and render it.

    Axis 0 of the primary HDU's data is moved to the last position, so that
    ``data[row, col, plane]`` can go straight to ``imshow``. This assumes a
    3-D cube whose first axis holds the colour planes (TODO confirm for mono
    files). The header is kept for building the WCS.

    Args:
        file: Path of the uploaded FITS file, or ``None`` when the upload is cleared.
        type_checkboxes, title, axis_options, num_rows, patch_size, fontsize,
        alpha, linewidth, scale, patch_color, sort_method: Current display
            options, forwarded unchanged to ``update_images_and_tables``.

    Returns:
        What ``update_images_and_tables`` returns for these options, or
        ``(None, None, None)`` when ``file`` is ``None``.
    """
    global global_header, global_data

    # The upload component reports None when the user clears the file.
    if file is None:
        return None, None, None

    # Close the file deterministically. np.array copies the pixels, so they
    # do not depend on a memory map into the closed file.
    with fits.open(file) as hdu:
        data = np.moveaxis(np.array(hdu[0].data), 0, -1)
        header = hdu[0].header.copy()

    global_data = data
    global_header = header
    return update_images_and_tables(type_checkboxes, title, axis_options, num_rows, patch_size, fontsize, alpha, linewidth, scale, patch_color, sort_method)
 
def _figure_to_pil(fig):
    """Render ``fig`` to an in-memory PNG (dpi 200, tight bbox), close it and return it as a PIL image."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=.1, dpi=200)
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def _select_objects(catalogue, selected_types, half, height, width):
    """Return rows of ``catalogue`` with a selected TYPE whose patch lies strictly inside the image, or None."""
    if not selected_types or catalogue.empty:
        return None
    chosen = catalogue[catalogue["TYPE"].isin(selected_types)]
    fits_inside = ((chosen.px - half > 0) & (chosen.px + half < width)
                   & (chosen.py - half > 0) & (chosen.py + half < height))
    return chosen[fits_inside]


def _sort_objects(objects, sort_method):
    """Order ``objects`` per the "Sorting Method" dropdown and renumber them 0..N-1."""
    sort_keys = {
        # One multi-key sort: sorting by px/py and then re-sorting by TYPE
        # with the default (unstable) algorithm loses the x/y order.
        "by Catalogue": ['TYPE', 'px', 'py'],
        "by x": ['px', 'py'],
        "by y": ['py', 'px'],
    }.get(sort_method)
    if sort_keys is not None:
        objects = objects.sort_values(by=sort_keys)
    # Labels and the cut-out list are indexed by position, so the index must be contiguous.
    return objects.reset_index(drop=True)


def _render_overview(data, header, objects, title, axis_options, patch_size, fontsize, alpha, linewidth, scale, patch_color):
    """Draw the full image with WCS axes and numbered boxes; return (PIL image, list of cut-out arrays)."""
    half = patch_size // 2
    height, width = data.shape[:2]
    # The data has a third axis; drop it so the WCS is 2-D.
    wcs = WCS(header).dropaxis(2)

    # figsize is (width, height) in inches: width = scale, height follows the image aspect.
    fig = plt.figure(figsize=(scale, scale * height / width))
    ax = fig.add_subplot(projection=wcs, label='overlays')
    ax.imshow(data, origin='lower')

    if "with Grid" in axis_options:
        ax.coords.grid(True, color='white', ls='-', alpha=.5)
    if "with Axis Annotation" in axis_options:
        ax.coords[0].set_axislabel('Right Ascension (J2000)', fontsize=fontsize + 2)
        ax.coords[1].set_axislabel('Declination (J2000)', fontsize=fontsize + 2)
    else:
        ax.axis('off')
    ax.set_title(title, fontsize=fontsize + 4)

    cutouts = []
    if objects is not None:
        for i, row in objects.iterrows():
            x, y = int(row.px), int(row.py)
            ax.add_patch(patches.Rectangle((x - half, y - half), patch_size, patch_size, alpha=alpha,
                                           linewidth=linewidth, edgecolor=patch_color, facecolor='none'))
            # 1-based number above the box; the patch grid uses the same numbering.
            ax.text(x, y + half, str(i + 1), ha='center', va='bottom', color=patch_color, fontsize=fontsize)
            cutouts.append(data[y - half:y + half, x - half:x + half])
    fig.tight_layout()
    return _figure_to_pil(fig), cutouts


def _render_patch_grid(objects, cutouts, n_cols, scale, fontsize):
    """Lay the cut-outs out ``n_cols`` per row, titled by ``main_id`` and numbered from 1; return a PIL image."""
    n_cols = max(1, int(n_cols))
    n_rows = int(np.ceil(len(cutouts) / n_cols))
    cell = max(1, scale // 3)
    # squeeze=False keeps axarr 2-D even when there is only one row or column.
    fig, axarr = plt.subplots(n_rows, n_cols, figsize=(n_cols * cell, n_rows * cell), squeeze=False)
    for i, row in objects.iterrows():
        ax = axarr[i // n_cols, i % n_cols]
        # Flip vertically because the overview is drawn with origin='lower'.
        ax.imshow(cutouts[i][::-1])
        ax.set_title(row.main_id, fontsize=fontsize - 2)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.text(2, 2, str(i + 1), ha='left', va='top', fontsize=fontsize + 6)
    # Hide the unused cells at the end of the last row.
    for i in range(len(cutouts), n_rows * n_cols):
        axarr[i // n_cols, i % n_cols].axis('off')
    fig.tight_layout()
    return _figure_to_pil(fig)


def update_images_and_tables(selected_types, title, selected_axis_options, num_rows, patch_size, fontsize, alpha, linewidth, scale, patch_color, sort_method):
    """
    Re-render the object table, the overview image and the patch grid.

    Called whenever a display option or the catalogue selection changes.

    Args:
        selected_types: Catalogue prefixes (``TYPE`` values) to show.
        title: Title of the overview image.
        selected_axis_options: Any of "with Grid" / "with Axis Annotation".
        num_rows: Number of patches per row of the patch grid. Despite the
            name, this is the grid's column count.
        patch_size: Side length in pixels of each cut-out box.
        fontsize: Base font size; titles and labels are offset from it.
        alpha, linewidth, patch_color: Style of the boxes on the overview.
        scale: Width in inches of the overview figure. Patch-grid cells are
            ``max(1, scale // 3)`` inches.
        sort_method: "by Catalogue", "by x" or "by y"; sets the numbering order.

    Returns:
        ``(table, overview, patch_grid)``, one value per Gradio output.
        - ``table`` is None when nothing is selected or no catalogue is loaded.
        - ``patch_grid`` is None when no object is shown.
        - Before any FITS file is loaded, returns ``(None, None, None)``.
        - If rendering fails, returns ``("Error: ...", None, None)``.
    """
    if global_data is None or global_header is None:
        # Nothing to draw until a FITS file has been loaded.
        return None, None, None

    # Slider values may arrive as floats; slicing needs integers.
    patch_size = int(patch_size)
    height, width = global_data.shape[:2]
    try:
        objects = _select_objects(global_dataframe, selected_types, patch_size // 2, height, width)
        if objects is not None:
            objects = _sort_objects(objects, sort_method)
        overview, cutouts = _render_overview(global_data, global_header, objects, title,
                                             selected_axis_options or [], patch_size, fontsize,
                                             alpha, linewidth, scale, patch_color)
        if objects is None or objects.empty:
            return objects, overview, None
        return objects, overview, _render_patch_grid(objects, cutouts, num_rows, scale, fontsize)
    except Exception as e:
        # One value per Gradio output so the interface can still display the message.
        return f"Error: {str(e)}", None, None


# Gradio interface
with gr.Blocks(css=".btn-green {background-color: green; color: white;}") as gui:
    gr.Markdown("# What's in my image?")
    # Display options; a change to any of them re-renders via update_images_and_tables.
    with gr.Row():
        # This value is the number of patches per row of the patch grid (its column count).
        num_rows = gr.Number(label="Patches per Row", value=16, minimum=2, precision=0, interactive=True)
        title = gr.Textbox(label="Image Title", value="Custom Title", interactive=True)
        patch_size = gr.Slider(label="Patch Size", minimum=16, maximum=128, step=8, value=32,
                               interactive=True)
        fontsize = gr.Slider(label="Fontsize", minimum=6, maximum=26, step=1, value=10,
                             interactive=True)
        alpha = gr.Slider(label="Alpha", minimum=0., maximum=1., step=.1, value=1.,
                          interactive=True)
        linewidth = gr.Slider(label="Linewidth", minimum=1, maximum=4, step=1, value=1,
                              interactive=True)
        scale = gr.Slider(label="Scale", minimum=1, maximum=20, step=1, value=8,
                          interactive=True)

        patch_color = gr.ColorPicker(label="Patch Color", value="#FFFFFF", interactive=True)
        sort_method = gr.Dropdown(label="Sorting Method", choices=["by Catalogue", "by x", "by y"], value="by Catalogue", interactive=True)
        axis_options = gr.CheckboxGroup(
            label="Select options",
            choices=["with Grid", "with Axis Annotation"],
            value=["with Grid", "with Axis Annotation"],  # Preselected values
            interactive=True,
        )
    gr.Markdown("Upload a plate solved `.fits` file (32 bit) to display its content.")

    file_input = gr.File(label="Upload .fits File", type="filepath")
    query_button = gr.Button("Query Simbad for Galaxies")
    fits_image = gr.Image(label="Input Image", type="pil")
    type_checkboxes = gr.CheckboxGroup(label="Select Catalogue")
    patches_image = gr.Image(label="Patches Image", type="pil")
    csv_table = gr.DataFrame(label="CSV Table")

    track_options = [type_checkboxes, title, axis_options, num_rows, patch_size, fontsize, alpha, linewidth, scale, patch_color, sort_method]
    render_outputs = [csv_table, fits_image, patches_image]

    file_input.change(load_fits_image,
                      inputs=[file_input] + track_options,
                      outputs=render_outputs)

    # type_checkboxes is part of track_options, so this loop already re-renders
    # when the catalogue selection changes. A second registration would render twice.
    for option in track_options:
        option.change(update_images_and_tables,
                      inputs=track_options,
                      outputs=render_outputs)

    query_button.click(query_update_table, inputs=None, outputs=[csv_table, type_checkboxes])

gui.launch(debug=True)