pawipa commited on
Commit
40b16f2
·
1 Parent(s): 741ae27

first diry commit for checking dependencies.

Browse files
Files changed (2) hide show
  1. app.py +275 -152
  2. requirements.txt +3 -5
app.py CHANGED
@@ -1,154 +1,277 @@
1
  import gradio as gr
 
 
2
  import numpy as np
3
- import random
4
-
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
- import torch
8
-
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
-
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
-
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
-
20
- MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
-
23
-
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
- progress=gr.Progress(track_tqdm=True),
35
- ):
36
- if randomize_seed:
37
- seed = random.randint(0, MAX_SEED)
38
-
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
- image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
- generator=generator,
49
- ).images[0]
50
-
51
- return image, seed
52
-
53
-
54
- examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
- ]
59
-
60
- css = """
61
- #col-container {
62
- margin: 0 auto;
63
- max-width: 640px;
64
- }
65
- """
66
-
67
- with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
-
71
- with gr.Row():
72
- prompt = gr.Text(
73
- label="Prompt",
74
- show_label=False,
75
- max_lines=1,
76
- placeholder="Enter your prompt",
77
- container=False,
78
- )
79
-
80
- run_button = gr.Button("Run", scale=0, variant="primary")
81
-
82
- result = gr.Image(label="Result", show_label=False)
83
-
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
90
- )
91
-
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
98
- )
99
-
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
-
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
- )
110
-
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
- )
118
-
119
- with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
- )
127
-
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
- step=1,
133
- value=2, # Replace with defaults that work for your model
134
- )
135
-
136
- gr.Examples(examples=examples, inputs=[prompt])
137
- gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
- fn=infer,
140
- inputs=[
141
- prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
149
- ],
150
- outputs=[result, seed],
151
- )
152
-
153
- if __name__ == "__main__":
154
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from astropy.io import fits
3
+ import matplotlib.pyplot as plt
4
  import numpy as np
5
+ import io
6
+ from PIL import Image
7
+ import astropy.units as u
8
+ from astropy.wcs import WCS
9
+ from astropy.coordinates import SkyCoord
10
+ from astropy import coordinates as coord
11
+ from astropy.wcs.utils import skycoord_to_pixel
12
+ from astroquery.simbad import Simbad
13
+ import pandas as pd
14
+ import matplotlib.patches as patches
15
+ # Increase the limit (set to a value larger than the pixel count of your image)
16
+ Image.MAX_IMAGE_PIXELS = None
17
+ plt.style.use('dark_background')
18
+
19
+ # Initialize globals
20
+ global_dataframe = pd.DataFrame()
21
+ global_data = None
22
+ global_header = None
23
+
24
+ def show_csv(file):
25
+ """
26
+ Displays the uploaded CSV file as a table.
27
+ """
28
+ global global_dataframe
29
+
30
+ try:
31
+ # Read the CSV file into a pandas DataFrame
32
+ df = pd.read_csv(file.name, index_col=0)
33
+ global_dataframe = df # Store the dataframe globally for filtering
34
+ # Extract unique types from the "type" column
35
+ if "TYPE" in df.columns:
36
+ unique_types = df["TYPE"].unique().tolist()
37
+ return df, gr.CheckboxGroup(label="Select Catalogue", choices=unique_types, value=unique_types, interactive=True)
38
+ else:
39
+ return "Error: CSV does not contain a 'type' column.", None
40
+ except Exception as e:
41
+ return f"Error: {str(e)}", None
42
+
43
+ # Define a function to be called when the button is clicked
44
+ def query_update_table():
45
+ """
46
+ Displays the uploaded CSV file as a table.
47
+ """
48
+ global global_dataframe, global_header, global_data
49
+
50
+ try:
51
+ # Read the CSV file into a pandas DataFrame
52
+ #df = pd.read_csv('dataframe.csv', index_col=0)
53
+
54
+ Simbad.TIMEOUT = 120
55
+
56
+ # Define the specific coordinates
57
+ wcs = WCS(global_header).dropaxis(2)
58
+ center_ra = global_header['CRVAL1']
59
+ center_dec = global_header['CRVAL2']
60
+ target_coord = SkyCoord(ra=center_ra, dec=center_dec, unit=(u.deg, u.deg), frame='icrs')
61
+ print(center_ra, center_dec)
62
+ # define the search radius
63
+ radius_deg = max([abs(global_header['CDELT1']),abs(global_header['CDELT2'])])*max([global_header['NAXIS1'],global_header['NAXIS2']])
64
+ radius_deg *= 1
65
+ # Set up the query criteria
66
+ if target_coord.dec.deg > 0:
67
+ custom_query = f"region(CIRCLE, {target_coord.ra.deg} +{target_coord.dec.deg}, {radius_deg}d)"
68
+ else:
69
+ custom_query = f"region(CIRCLE, {target_coord.ra.deg} {target_coord.dec.deg}, {radius_deg}d)"
70
+
71
+ print(f'Query={custom_query}')
72
+
73
+ result_table = Simbad.query_criteria(custom_query, otype='galaxy')
74
+
75
+ print("received feedback from simbad!!!")
76
+ print(result_table)
77
+ df = result_table.to_pandas().set_index('main_id')
78
+ print(df.columns)
79
+ df['Pixel_Position'] = [skycoord_to_pixel(SkyCoord(v[0],v[1], unit=(u.deg, u.deg), frame='icrs'), wcs) for v in df[['ra','dec']].values]
80
+ print(df['Pixel_Position'])
81
+ df['px'] = df['Pixel_Position'].apply(lambda x: int(x[0]))
82
+ df['py'] = df['Pixel_Position'].apply(lambda x: int(x[1]))
83
+ mask = (df.px>0)&(df.px< global_data.shape[1])&(df.py>0)&(df.py<global_data.shape[0])
84
+ print(df)
85
+ df = df[mask]
86
+ df = df.reset_index()
87
+ df['TYPE'] = df['main_id'].apply(lambda x: x.split(' ')[0].split('+')[0])
88
+ df = df.sort_values(by=['px', 'py'], ascending=[True, True]).reset_index(drop=True)
89
+ print(df)
90
+ df = df.iloc[:200]
91
+ global_dataframe = df # Store the dataframe globally for filtering
92
+
93
+ # Extract unique types from the "type" column
94
+ if "TYPE" in df.columns:
95
+ unique_types = df["TYPE"].unique().tolist()
96
+ return df, gr.CheckboxGroup(label="Select Catalogue", choices=unique_types, value=unique_types, interactive=True)
97
+ else:
98
+ return "Error: CSV does not contain a 'type' column.", None
99
+ except Exception as e:
100
+ return f"Error: {str(e)}", None
101
+
102
+
103
+ def load_fits_image(file, type_checkboxes, title, axis_options, num_rows, patch_size, fontsize, alpha, linewidth, scale, patch_color, sort_method):
104
+ """
105
+ Displays the data from the uploaded FITS file.
106
+ """
107
+ global global_header, global_data
108
+
109
+ # Open the FITS file
110
+ hdu = fits.open(file)
111
+ data = hdu[0].data # Access the primary HDU data
112
+ data = np.swapaxes(np.swapaxes(data,0,2),0,1)#.astype(np.float)
113
+ #data = (data*255).astype(np.uint8) # Access the primary HDU data
114
+ global_data = data
115
+
116
+ # get fits header
117
+ header = hdu[0].header
118
+ global_header = header
119
+ #selected_types, title, selected_axis_options, num_rows, patch_size, patch_color, sort_method
120
+ return update_images_and_tables(type_checkboxes, title, axis_options, num_rows, patch_size, fontsize, alpha, linewidth, scale, patch_color, sort_method)
121
+
122
+ def update_images_and_tables(selected_types, title, selected_axis_options, num_rows, patch_size, fontsize, alpha, linewidth, scale, patch_color, sort_method):
123
+ global global_dataframe, global_header, global_data
124
+
125
+ if selected_types and not global_dataframe.empty:
126
+ # Filter the dataframe based on the selected types
127
+ filtered_df = global_dataframe[global_dataframe["TYPE"].isin(selected_types)]
128
+ mask = (filtered_df.px-patch_size//2 > 0)&(filtered_df.px+patch_size//2 < global_data.shape[1])&(filtered_df.py-patch_size//2 > 0)&(filtered_df.py+patch_size//2 < global_data.shape[0])
129
+ filtered_df = filtered_df[mask]
130
+ else:
131
+ filtered_df = None
132
+
133
+ if not filtered_df is None:
134
+ # Sort the dataframe based on the sorting method
135
+ if sort_method == "by Catalogue":
136
+ filtered_df = filtered_df.sort_values(by=['px', 'py'], ascending=[True, True])
137
+ filtered_df = filtered_df.sort_values(by='TYPE', ascending=True).reset_index(drop=True)
138
+ elif sort_method == "by x":
139
+ filtered_df = filtered_df.sort_values(by=['px', 'py'], ascending=[True, True]).reset_index(drop=True)
140
+ elif sort_method == "by y":
141
+ filtered_df = filtered_df.sort_values(by=['py', 'px'], ascending=[True, True]).reset_index(drop=True)
142
+
143
+ try:
144
+
145
+ wcs = WCS(global_header).dropaxis(2)
146
+
147
+ ratio = global_data.shape[0]/global_data.shape[1]
148
+ # Plot WCS
149
+ fig = plt.figure(figsize=(ratio*scale,scale))
150
+ ax = fig.add_subplot(projection=wcs, label='overlays')
151
+ ax.imshow(global_data, origin='lower')
152
+
153
+ #if not filtered_df is None:
154
+ # filtered_df.plot.scatter(x='px', y='py', ax=ax, s=15, c=patch_color)
155
+
156
+ if "with Grid" in selected_axis_options:
157
+ ax.coords.grid(True, color='white', ls='-', alpha=.5)
158
+ if "with Axis Annotation" in selected_axis_options:
159
+ ax.coords[0].set_axislabel('Right Ascension (J2000)', fontsize=fontsize+2)
160
+ ax.coords[1].set_axislabel('Declination (J2000)', fontsize=fontsize+2)
161
+ else:
162
+ ax.axis('off')
163
+ plt.title(title, fontsize=fontsize+4)
164
+
165
+ if not filtered_df is None:
166
+ all_patches = []
167
+ for i,row in filtered_df.iterrows():
168
+ rect = patches.Rectangle((row.px-patch_size//2, row.py-patch_size//2), patch_size, patch_size, alpha=alpha, linewidth=linewidth, edgecolor=patch_color, facecolor='none')
169
+ ax.add_patch(rect)
170
+ ax.text(row.px,row.py+patch_size//2,str(i+1),
171
+ ha='center',va='bottom',color=patch_color,fontsize=fontsize)
172
+ patch = global_data[row.py-patch_size//2:row.py+patch_size//2,row.px-patch_size//2:row.px+patch_size//2]
173
+ all_patches.append(patch)
174
+ plt.tight_layout()
175
+ # Convert the plot to an image
176
+ buf = io.BytesIO()
177
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=.1, dpi=200)
178
+ plt.close(fig)
179
+ buf.seek(0)
180
+ # Convert buffer to PIL Image
181
+ image = Image.open(buf)
182
+
183
+ if not filtered_df is None:
184
+
185
+ m = num_rows
186
+ n = int(np.ceil(len(filtered_df)/m))
187
+
188
+ second_scale=max([1,scale//3])
189
+ fig, axarr = plt.subplots(n,m,figsize=(m*second_scale,n*second_scale))
190
+ for i, row in filtered_df.iterrows():
191
+ ax = axarr[i//m,i%m]
192
+ ax.imshow(all_patches[i][::-1])
193
+ ax.set_title(row.main_id, fontsize=fontsize-2)
194
+ ax.set_xticks([])
195
+ ax.set_yticks([])
196
+ ax.text(2,2,str(i+1)[:30],ha='left',va='top',fontsize=fontsize+6)
197
+ for i in np.arange(len(all_patches),m*n):
198
+ ax = axarr[i//m,i%m]
199
+ ax.axis('off')
200
+ plt.tight_layout()
201
+
202
+ # Convert the plot to an image
203
+ second_buf = io.BytesIO()
204
+ plt.savefig(second_buf, format='png', bbox_inches='tight', pad_inches=.1, dpi=200)
205
+ plt.close(fig)
206
+ second_buf.seek(0)
207
+ # Convert buffer to PIL Image
208
+ patches_image = Image.open(second_buf)
209
+
210
+ return filtered_df, image, patches_image
211
+ else:
212
+ return filtered_df, image, None
213
+
214
+ except Exception as e:
215
+ return f"Error: {str(e)}"
216
+
217
+
218
+ # Gradio interface
219
+ with gr.Blocks(css=".btn-green {background-color: green; color: white;}") as gui:
220
+ gr.Markdown("# What's in my image?")
221
+ # Options Area
222
+ with gr.Row() as options_gui:
223
+ num_rows = gr.Number(label="Number of Rows", value=16, minimum=2, precision=0, interactive=True)
224
+ title = gr.Textbox(label="Image Title", value="Custom Title", interactive=True)
225
+ patch_size = gr.Slider(label="Patch Size", minimum=16, maximum=128, step=8, value=32,
226
+ interactive=True)
227
+ fontsize = gr.Slider(label="Fontsize", minimum=6, maximum=26, step=1, value=10,
228
+ interactive=True)
229
+ alpha = gr.Slider(label="Alpha", minimum=0., maximum=1., step=.1, value=1.,
230
+ interactive=True)
231
+ linewidth = gr.Slider(label="Linewidth", minimum=1, maximum=4, step=1, value=1,
232
+ interactive=True)
233
+ scale = gr.Slider(label="Scale", minimum=1, maximum=20, step=1, value=8,
234
+ interactive=True)
235
+
236
+ patch_color = gr.ColorPicker(label="Patch Color", value="#FFFFFF", interactive=True)
237
+ sort_method = gr.Dropdown(label="Sorting Method", choices=["by Catalogue", "by x", "by y"], value="by Catalogue", interactive=True)
238
+ axis_options = gr.CheckboxGroup(
239
+ label="Select options",
240
+ choices=["with Grid", "with Axis Annotation"],
241
+ value=["with Grid", "with Axis Annotation"], # Preselected values
242
+ interactive=True # Makes it interactive
243
+ )
244
+ gr.Markdown("Upload a plate solved `.fits` file (32 bit) to display its content.")
245
+
246
+ file_input = gr.File(label="Upload .fits File", type="filepath")
247
+ #file_input_csv = gr.File(label="Upload .csv File")
248
+ greet_button = gr.Button("Query Simbad for Galaxies") # Create the button
249
+ fits_image = gr.Image(label="Input Image", type="pil")
250
+ type_checkboxes = gr.CheckboxGroup(label="Select Catalogue")
251
+ patches_image = gr.Image(label="Patches Image", type="pil")
252
+ csv_table = gr.DataFrame(label="CSV Table")
253
+
254
+ track_options = [type_checkboxes, title, axis_options, num_rows, patch_size, fontsize, alpha, linewidth, scale, patch_color, sort_method]
255
+
256
+ file_input.change(load_fits_image,
257
+ inputs=[file_input] + track_options,
258
+ outputs=[csv_table,fits_image,patches_image])
259
+
260
+ for option_i in track_options:
261
+ option_i.change(update_images_and_tables,
262
+ inputs=track_options,
263
+ outputs=[csv_table,fits_image,patches_image])
264
+
265
+ # Display CSV table
266
+ #file_input_csv.change(show_csv,
267
+ # inputs=file_input_csv,
268
+ # outputs=[csv_table, type_checkboxes])
269
+
270
+ greet_button.click(query_update_table, inputs=None, outputs=[csv_table, type_checkboxes])
271
+
272
+ # Update the selected checkboxes change
273
+ type_checkboxes.change(update_images_and_tables,
274
+ inputs=track_options,
275
+ outputs=[csv_table,fits_image,patches_image])
276
+
277
+ gui.launch(debug=True)
requirements.txt CHANGED
@@ -1,6 +1,4 @@
1
- accelerate
2
- diffusers
3
  invisible_watermark
4
- torch
5
- transformers
6
- xformers
 
 
 
1
  invisible_watermark
2
+ astropy
3
+ astroquery
4
+ matplotlib