aiqcamp committed
Commit ee964c4 · verified · 1 Parent(s): 72660bf

Update app.py

Files changed (1): app.py +518 -448
app.py CHANGED
@@ -1,458 +1,528 @@
  import gradio as gr
- import spaces
- import torch
- from diffusers import AutoencoderKL, TCDScheduler
- from diffusers.models.model_loading_utils import load_state_dict
- from gradio_imageslider import ImageSlider
- from huggingface_hub import hf_hub_download
-
- from controlnet_union import ControlNetModel_Union
- from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
-
- from PIL import Image, ImageDraw
- import numpy as np
-
- config_file = hf_hub_download(
-     "xinsir/controlnet-union-sdxl-1.0",
-     filename="config_promax.json",
- )
-
- config = ControlNetModel_Union.load_config(config_file)
- controlnet_model = ControlNetModel_Union.from_config(config)
- model_file = hf_hub_download(
-     "xinsir/controlnet-union-sdxl-1.0",
-     filename="diffusion_pytorch_model_promax.safetensors",
- )
- state_dict = load_state_dict(model_file)
- model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
-     controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
- )
- model.to(device="cuda", dtype=torch.float16)
-
- vae = AutoencoderKL.from_pretrained(
-     "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
- ).to("cuda")
-
- pipe = StableDiffusionXLFillPipeline.from_pretrained(
-     "SG161222/RealVisXL_V5.0_Lightning",
-     torch_dtype=torch.float16,
-     vae=vae,
-     controlnet=model,
-     variant="fp16",
- ).to("cuda")
-
- pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
-
-
- def can_expand(source_width, source_height, target_width, target_height, alignment):
-     """Checks if the image can be expanded based on the alignment."""
-     if alignment in ("Left", "Right") and source_width >= target_width:
-         return False
-     if alignment in ("Top", "Bottom") and source_height >= target_height:
-         return False
-     return True
-
- def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
-     target_size = (width, height)
-
-     # Calculate the scaling factor to fit the image within the target size
-     scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
-     new_width = int(image.width * scale_factor)
-     new_height = int(image.height * scale_factor)
-
-     # Resize the source image to fit within target size
-     source = image.resize((new_width, new_height), Image.LANCZOS)
-
-     # Apply resize option using percentages
-     if resize_option == "Full":
-         resize_percentage = 100
-     elif resize_option == "50%":
-         resize_percentage = 50
-     elif resize_option == "33%":
-         resize_percentage = 33
-     elif resize_option == "25%":
-         resize_percentage = 25
-     else:  # Custom
-         resize_percentage = custom_resize_percentage
-
-     # Calculate new dimensions based on percentage
-     resize_factor = resize_percentage / 100
-     new_width = int(source.width * resize_factor)
-     new_height = int(source.height * resize_factor)
-
-     # Ensure minimum size of 64 pixels
-     new_width = max(new_width, 64)
-     new_height = max(new_height, 64)
-
-     # Resize the image
-     source = source.resize((new_width, new_height), Image.LANCZOS)
-
-     # Calculate the overlap in pixels based on the percentage
-     overlap_x = int(new_width * (overlap_percentage / 100))
-     overlap_y = int(new_height * (overlap_percentage / 100))
-
-     # Ensure minimum overlap of 1 pixel
-     overlap_x = max(overlap_x, 1)
-     overlap_y = max(overlap_y, 1)
-
-     # Calculate margins based on alignment
-     if alignment == "Middle":
-         margin_x = (target_size[0] - new_width) // 2
-         margin_y = (target_size[1] - new_height) // 2
-     elif alignment == "Left":
-         margin_x = 0
-         margin_y = (target_size[1] - new_height) // 2
-     elif alignment == "Right":
-         margin_x = target_size[0] - new_width
-         margin_y = (target_size[1] - new_height) // 2
-     elif alignment == "Top":
-         margin_x = (target_size[0] - new_width) // 2
-         margin_y = 0
-     elif alignment == "Bottom":
-         margin_x = (target_size[0] - new_width) // 2
-         margin_y = target_size[1] - new_height
-
-     # Adjust margins to eliminate gaps
-     margin_x = max(0, min(margin_x, target_size[0] - new_width))
-     margin_y = max(0, min(margin_y, target_size[1] - new_height))
-
-     # Create a new background image and paste the resized source image
-     background = Image.new('RGB', target_size, (255, 255, 255))
-     background.paste(source, (margin_x, margin_y))
-
-     # Create the mask
-     mask = Image.new('L', target_size, 255)
-     mask_draw = ImageDraw.Draw(mask)
-
-     # Calculate overlap areas
-     white_gaps_patch = 2
-
-     left_overlap = margin_x + overlap_x if overlap_left else margin_x + white_gaps_patch
-     right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width - white_gaps_patch
-     top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
-     bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height - white_gaps_patch
-
-     if alignment == "Left":
-         left_overlap = margin_x + overlap_x if overlap_left else margin_x
-     elif alignment == "Right":
-         right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width
-     elif alignment == "Top":
-         top_overlap = margin_y + overlap_y if overlap_top else margin_y
-     elif alignment == "Bottom":
-         bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height
-
-     # Draw the mask
-     mask_draw.rectangle([
-         (left_overlap, top_overlap),
-         (right_overlap, bottom_overlap)
-     ], fill=0)
-
-     return background, mask
-
- def preview_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
-     background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
-
-     # Create a preview image showing the mask
-     preview = background.copy().convert('RGBA')
-
-     # Create a semi-transparent red overlay
-     red_overlay = Image.new('RGBA', background.size, (255, 0, 0, 64))  # Reduced alpha to 64 (25% opacity)
-
-     # Convert black pixels in the mask to semi-transparent red
-     red_mask = Image.new('RGBA', background.size, (0, 0, 0, 0))
-     red_mask.paste(red_overlay, (0, 0), mask)
-
-     # Overlay the red mask on the background
-     preview = Image.alpha_composite(preview, red_mask)
-
-     return preview
-
- @spaces.GPU(duration=24)
- def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
-     background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
-
-     if not can_expand(background.width, background.height, width, height, alignment):
-         alignment = "Middle"
-
-     cnet_image = background.copy()
-     cnet_image.paste(0, (0, 0), mask)
-
-     final_prompt = f"{prompt_input} , high quality, 4k" if prompt_input else "high quality, 4k"
-
-     # Use with torch.autocast to ensure consistent dtype
-     with torch.autocast(device_type="cuda", dtype=torch.float16):
-         (
-             prompt_embeds,
-             negative_prompt_embeds,
-             pooled_prompt_embeds,
-             negative_pooled_prompt_embeds,
-         ) = pipe.encode_prompt(final_prompt, "cuda", True)
-
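-     # The pipeline call streams intermediate results: each iteration yields the
-     # current partially denoised image, paired with the masked input for the ImageSlider.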
-     for image in pipe(
-         prompt_embeds=prompt_embeds,
-         negative_prompt_embeds=negative_prompt_embeds,
-         pooled_prompt_embeds=pooled_prompt_embeds,
-         negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-         image=cnet_image,
-         num_inference_steps=num_inference_steps
-     ):
-         yield cnet_image, image
-
-     image = image.convert("RGBA")
-     cnet_image.paste(image, (0, 0), mask)
-
-     yield background, cnet_image
-
- def clear_result():
-     """Clears the result ImageSlider."""
-     return gr.update(value=None)
-
- def preload_presets(target_ratio, ui_width, ui_height):
-     """Updates the width and height sliders based on the selected aspect ratio."""
-     if target_ratio == "9:16":
-         changed_width = 720
-         changed_height = 1280
-         return changed_width, changed_height, gr.update()
-     elif target_ratio == "16:9":
-         changed_width = 1280
-         changed_height = 720
-         return changed_width, changed_height, gr.update()
-     elif target_ratio == "1:1":
-         changed_width = 1024
-         changed_height = 1024
-         return changed_width, changed_height, gr.update()
-     elif target_ratio == "Custom":
-         return ui_width, ui_height, gr.update(open=True)
-
- def select_the_right_preset(user_width, user_height):
-     if user_width == 720 and user_height == 1280:
-         return "9:16"
-     elif user_width == 1280 and user_height == 720:
-         return "16:9"
-     elif user_width == 1024 and user_height == 1024:
-         return "1:1"
      else:
-         return "Custom"
-
- def toggle_custom_resize_slider(resize_option):
-     return gr.update(visible=(resize_option == "Custom"))
-
- def update_history(new_image, history):
-     """Updates the history gallery with the new image."""
-     if history is None:
-         history = []
-     history.insert(0, new_image)
-     return history
-
- css = """
- .gradio-container {
-     width: 1200px !important;
- }
- """
-
- # Define the title HTML string
- title = """<h1 align="center">Diffusers Image Outpaint</h1>
- <div align="center">Drop an image you would like to extend, pick your expected ratio and hit Generate.</div>
- <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
-     <p style="display: flex;gap: 6px;">
-         <a href="https://github.com/huggingface/diffusers">
-             Built with Diffusers
-         </a>
-     </p>
- </div>
- """
-
- with gr.Blocks(css=css) as demo:
      with gr.Column():
-         gr.HTML(title)
-
          with gr.Row():
-             with gr.Column():
-                 input_image = gr.Image(
-                     type="pil",
-                     label="Input Image"
-                 )
-
-                 with gr.Row():
-                     with gr.Column(scale=2):
-                         prompt_input = gr.Textbox(label="Prompt (Optional)")
-                     with gr.Column(scale=1):
-                         run_button = gr.Button("Generate")
-
-                 with gr.Row():
-                     target_ratio = gr.Radio(
-                         label="Expected Ratio",
-                         choices=["9:16", "16:9", "1:1", "Custom"],
-                         value="9:16",
-                         scale=2
-                     )
-
-                     alignment_dropdown = gr.Dropdown(
-                         choices=["Middle", "Left", "Right", "Top", "Bottom"],
-                         value="Middle",
-                         label="Alignment"
-                     )

-                 with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
                      with gr.Column():
                          with gr.Row():
-                             width_slider = gr.Slider(
-                                 label="Target Width",
-                                 minimum=720,
-                                 maximum=1536,
-                                 step=8,
-                                 value=720,  # Set a default value
-                             )
-                             height_slider = gr.Slider(
-                                 label="Target Height",
-                                 minimum=720,
-                                 maximum=1536,
-                                 step=8,
-                                 value=1280,  # Set a default value
-                             )
-
-                         num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=12, step=1, value=8)
-                         with gr.Group():
-                             overlap_percentage = gr.Slider(
-                                 label="Mask overlap (%)",
-                                 minimum=1,
-                                 maximum=50,
-                                 value=10,
-                                 step=1
-                             )
-                             with gr.Row():
-                                 overlap_top = gr.Checkbox(label="Overlap Top", value=True)
-                                 overlap_right = gr.Checkbox(label="Overlap Right", value=True)
-                             with gr.Row():
-                                 overlap_left = gr.Checkbox(label="Overlap Left", value=True)
-                                 overlap_bottom = gr.Checkbox(label="Overlap Bottom", value=True)
                          with gr.Row():
-                             resize_option = gr.Radio(
-                                 label="Resize input image",
-                                 choices=["Full", "50%", "33%", "25%", "Custom"],
-                                 value="Full"
-                             )
-                             custom_resize_percentage = gr.Slider(
-                                 label="Custom resize (%)",
-                                 minimum=1,
-                                 maximum=100,
-                                 step=1,
-                                 value=50,
-                                 visible=False
-                             )
-
-                 with gr.Column():
-                     preview_button = gr.Button("Preview alignment and mask")
-
-                 gr.Examples(
-                     examples=[
-                         ["./examples/example_2.jpg", 1440, 810, "Left"],
-                         ["./examples/example_3.jpg", 1024, 1024, "Top"],
-                         ["./examples/example_3.jpg", 1024, 1024, "Bottom"],
-                     ],
-                     inputs=[input_image, width_slider, height_slider, alignment_dropdown],
-                 )
-
-             with gr.Column():
-                 result = ImageSlider(
-                     interactive=False,
-                     label="Generated Image",
-                 )
-                 use_as_input_button = gr.Button("Use as Input Image", visible=False)
-
-                 history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
-                 preview_image = gr.Image(label="Preview")
-
-     def use_output_as_input(output_image):
-         """Sets the generated output as the new input image."""
-         return gr.update(value=output_image[1])
-
-     use_as_input_button.click(
-         fn=use_output_as_input,
-         inputs=[result],
-         outputs=[input_image]
-     )
-
-     target_ratio.change(
-         fn=preload_presets,
-         inputs=[target_ratio, width_slider, height_slider],
-         outputs=[width_slider, height_slider, settings_panel],
-         queue=False
-     )
-
-     width_slider.change(
-         fn=select_the_right_preset,
-         inputs=[width_slider, height_slider],
-         outputs=[target_ratio],
-         queue=False
-     )
-
-     height_slider.change(
-         fn=select_the_right_preset,
-         inputs=[width_slider, height_slider],
-         outputs=[target_ratio],
-         queue=False
-     )
-
-     resize_option.change(
-         fn=toggle_custom_resize_slider,
-         inputs=[resize_option],
-         outputs=[custom_resize_percentage],
-         queue=False
-     )
-
-     run_button.click(  # Clear the result
-         fn=clear_result,
-         inputs=None,
-         outputs=result,
-     ).then(  # Generate the new image
-         fn=infer,
-         inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
-                 resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
-                 overlap_left, overlap_right, overlap_top, overlap_bottom],
-         outputs=result,
-     ).then(  # Update the history gallery
-         fn=lambda x, history: update_history(x[1], history),
-         inputs=[result, history_gallery],
-         outputs=history_gallery,
-     ).then(  # Show the "Use as Input Image" button
-         fn=lambda: gr.update(visible=True),
-         inputs=None,
-         outputs=use_as_input_button,
-     )
-
-     prompt_input.submit(  # Clear the result
-         fn=clear_result,
-         inputs=None,
-         outputs=result,
-     ).then(  # Generate the new image
-         fn=infer,
-         inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
-                 resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
-                 overlap_left, overlap_right, overlap_top, overlap_bottom],
-         outputs=result,
-     ).then(  # Update the history gallery
-         fn=lambda x, history: update_history(x[1], history),
-         inputs=[result, history_gallery],
-         outputs=history_gallery,
-     ).then(  # Show the "Use as Input Image" button
-         fn=lambda: gr.update(visible=True),
-         inputs=None,
-         outputs=use_as_input_button,
-     )

-     preview_button.click(
-         fn=preview_image_and_mask,
-         inputs=[input_image, width_slider, height_slider, overlap_percentage, resize_option, custom_resize_percentage, alignment_dropdown,
-                 overlap_left, overlap_right, overlap_top, overlap_bottom],
-         outputs=preview_image,
-         queue=False
-     )

- demo.queue(max_size=12).launch(share=False)
  import gradio as gr
+ import datetime
+ import json
+ import random
+ import requests
+ from constants import *
+
+ def process(query_type, index_desc, **kwargs):
+     timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
+     index = INDEX_BY_DESC[index_desc]
+     data = {
+         'source': 'hf' if not DEBUG else 'hf-dev',
+         'timestamp': timestamp,
+         'query_type': query_type,
+         'index': index,
+     }
+     data.update(kwargs)
+     print(json.dumps(data))
+     if API_URL is None:
+         raise ValueError('API_URL envvar is not set!')
+     try:
+         response = requests.post(API_URL, json=data, timeout=10)
+     except requests.exceptions.Timeout:
+         raise ValueError('Web request timed out. Please try again later.')
+     except requests.exceptions.RequestException as e:
+         raise ValueError(f'Web request error: {e}')
+     if response.status_code == 200:
+         result = response.json()
      else:
+         raise ValueError(f'HTTP error {response.status_code}: {response.json()}')
+     if DEBUG:
+         print(result)
+     return result
+
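+ # A simple query yields a flat list of token ids; a CNF query yields token ids
+ # nested per clause and per term. The formatter below renders both shapes.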
+ def format_tokenization_info(result):
+     if not ('token_ids' in result and 'tokens' in result):
+         return ''
+     token_ids = result['token_ids']
+     tokens = result['tokens']
+     if type(token_ids) == list and all([type(token_id) == int for token_id in token_ids]):
+         output = '[' + " ".join(['"' + token.replace('Ġ', ' ') + '"' for token in tokens]) + '] ' + str(token_ids)
+     else:
+         ttt = []
+         for token_idss, tokenss in zip(token_ids, tokens):
+             tt = []
+             for token_ids, tokens in zip(token_idss, tokenss):
+                 t = '[' + " ".join(['"' + token.replace('Ġ', ' ') + '"' for token in tokens]) + '] ' + str(token_ids)
+                 tt.append(t)
+             tt = '\n'.join(tt)
+             ttt.append(tt)
+         output = '\n\n'.join(ttt)
+     return output
+
+ def format_doc_metadata(doc):
+     formatted = f'Document #{doc["doc_ix"]}\n'
+     if doc['doc_len'] == doc['disp_len']:
+         formatted += f'Length: {doc["doc_len"]} tokens\n'
+     else:
+         formatted += f'Length: {doc["doc_len"]} tokens ({doc["disp_len"]} tokens displayed)\n'
+     metadata = doc['metadata'].strip("\n")
+     formatted += f'Metadata: {metadata}'
+     return formatted
+
+ def count(index_desc, query, max_clause_freq, max_diff_tokens):
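+     # A query is treated as CNF if it contains the literal ' AND ' / ' OR ' operators.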
+     if ' AND ' in query or ' OR ' in query:  # CNF query
+         result = process('count', index_desc, query=query, max_clause_freq=max_clause_freq, max_diff_tokens=max_diff_tokens)
+     else:  # simple query
+         result = process('count', index_desc, query=query)
+     latency = '' if 'latency' not in result else f'{result["latency"]:.3f}'
+     tokenization_info = format_tokenization_info(result)
+     if 'error' in result:
+         count = result['error']
+     else:
+         count = f'{result["count"]:,}'
+     return latency, tokenization_info, count
+
+ def prob(index_desc, query):
+     result = process('prob', index_desc, query=query)
+     latency = '' if 'latency' not in result else f'{result["latency"]:.3f}'
+     tokenization_info = format_tokenization_info(result)
+     if 'error' in result:
+         prob = result['error']
+     elif result['prompt_cnt'] == 0:
+         prob = '(n-1)-gram is not found in the corpus'
+     else:
+         prob = f'{result["prob"]:.4f} ({result["cont_cnt"]:,} / {result["prompt_cnt"]:,})'
+     return latency, tokenization_info, prob
+
+ def ntd(index_desc, query, max_support):
+     result = process('ntd', index_desc, query=query, max_support=max_support)
+     latency = '' if 'latency' not in result else f'{result["latency"]:.3f}'
+     tokenization_info = format_tokenization_info(result)
+     if 'error' in result:
+         ntd = result['error']
+     else:
+         result_by_token_id = result['result_by_token_id']
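+         # Build a {label: probability} dict; gr.Label renders it as a ranked distribution.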
+         ntd = {}
+         for token_id, r in result_by_token_id.items():
+             ntd[f'{r["token"]} ({r["cont_cnt"]} / {result["prompt_cnt"]})'] = r['prob']
+         if ntd == {}:
+             ntd = '(n-1)-gram is not found in the corpus'
+     return latency, tokenization_info, ntd
+
+ def infgram_prob(index_desc, query):
+     result = process('infgram_prob', index_desc, query=query)
+     latency = '' if 'latency' not in result else f'{result["latency"]:.3f}'
+     tokenization_info = format_tokenization_info(result)
+     if 'error' in result:
+         longest_suffix = ''
+         prob = result['error']
+     else:
+         longest_suffix = result['longest_suffix']
+         prob = f'{result["prob"]:.4f} ({result["cont_cnt"]:,} / {result["prompt_cnt"]:,})'
+     return latency, tokenization_info, longest_suffix, prob
+
+ def infgram_ntd(index_desc, query, max_support):
+     result = process('infgram_ntd', index_desc, query=query, max_support=max_support)
+     latency = '' if 'latency' not in result else f'{result["latency"]:.3f}'
+     tokenization_info = format_tokenization_info(result)
+     if 'error' in result:
+         longest_suffix = ''
+         ntd = result['error']
+     else:
+         longest_suffix = result['longest_suffix']
+         result_by_token_id = result['result_by_token_id']
+         ntd = {}
+         for token_id, r in result_by_token_id.items():
+             ntd[f'{r["token"]} ({r["cont_cnt"]} / {result["prompt_cnt"]})'] = r['prob']
+     return latency, tokenization_info, longest_suffix, ntd
+
+ def search_docs(index_desc, query, maxnum, max_disp_len, max_clause_freq, max_diff_tokens):
+     if ' AND ' in query or ' OR ' in query:  # CNF query
+         result = process('search_docs', index_desc, query=query, maxnum=maxnum, max_disp_len=max_disp_len, max_clause_freq=max_clause_freq, max_diff_tokens=max_diff_tokens)
+     else:  # simple query
+         result = process('search_docs', index_desc, query=query, maxnum=maxnum, max_disp_len=max_disp_len)
+     latency = '' if 'latency' not in result else f'{result["latency"]:.3f}'
+     tokenization_info = format_tokenization_info(result)
+     if 'error' in result:
+         message = result['error']
+         metadatas = ['' for _ in range(MAXNUM)]
+         docs = [[] for _ in range(MAXNUM)]
+     else:
+         message = result['message']
+         metadatas = [format_doc_metadata(doc) for doc in result['documents']]
+         docs = [doc['spans'] for doc in result['documents']]
+     metadatas = metadatas[:maxnum]
+     docs = docs[:maxnum]
+     while len(metadatas) < MAXNUM:
+         metadatas.append('')
+     while len(docs) < MAXNUM:
+         docs.append([])
+     return tuple([latency, tokenization_info, message] + metadatas + docs)
+
+ def search_docs_new(index_desc, query, max_disp_len, max_clause_freq, max_diff_tokens, state):
+     if ' AND ' in query or ' OR ' in query:  # CNF query
+         find_result = process('find_cnf', index_desc, query=query, max_clause_freq=max_clause_freq, max_diff_tokens=max_diff_tokens)
+         find_result['type'] = 'cnf'
+     else:  # simple query
+         find_result = process('find', index_desc, query=query)
+         find_result['type'] = 'simple'
+
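+     # Keep the full find result in per-session state so the index slider can page
+     # through retrievable documents later without re-running the search.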
+     state = find_result
+
+     latency = '' if 'latency' not in find_result else f'{find_result["latency"]:.3f}'
+     tokenization_info = format_tokenization_info(find_result)
+     if 'error' in find_result:
+         message = find_result['error']
+         idx = gr.Number(minimum=0, maximum=0, step=1, value=0, interactive=False)
+         metadata = ''
+         doc = []
+         return latency, tokenization_info, message, idx, metadata, doc, state
+
+     if ' AND ' in query or ' OR ' in query:  # CNF query
+         ptrs_by_shard = find_result['ptrs_by_shard']
+         cnt_retrievable = sum([len(ptrs) for ptrs in ptrs_by_shard])
+         if find_result["approx"]:
+             message = f'Approximately {find_result["cnt"]} occurrences found, of which {cnt_retrievable} are retrievable'
+         else:
+             message = f'{find_result["cnt"]} occurrences found'
+     else:  # simple query
+         message = f'{find_result["cnt"]} occurrences found'
+         cnt_retrievable = find_result['cnt']
+     if cnt_retrievable == 0:
+         idx = gr.Number(minimum=0, maximum=0, step=1, value=0, interactive=False)
+         metadata = ''
+         doc = []
+         return latency, tokenization_info, message, idx, metadata, doc, state
+     idx = random.randint(0, cnt_retrievable - 1)
+     metadata, doc = get_another_doc(index_desc, idx, max_disp_len, state)
+     idx = gr.Number(minimum=0, maximum=cnt_retrievable - 1, step=1, value=idx, interactive=True)
+     return latency, tokenization_info, message, idx, metadata, doc, state
+
+ def clear_search_docs_new(state):
+     state = None
+     idx = gr.Number(minimum=0, maximum=0, step=1, value=0, interactive=False)
+     return idx, state
+
+ def get_another_doc(index_desc, idx, max_disp_len, state):
+     find_result = state
+     if find_result is None or not (type(idx) == int and 0 <= idx and idx < find_result['cnt']):
+         metadata = ''
+         doc = []
+         return metadata, doc
+     if find_result['type'] == 'cnf':
+         ptrs_by_shard = find_result['ptrs_by_shard']
+         cnt_by_shard = [len(ptrs) for ptrs in ptrs_by_shard]
+         s = 0
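+         # Map the global document index onto (shard s, local offset idx) by walking
+         # per-shard counts until the remaining index falls within shard s.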
+         while idx >= cnt_by_shard[s]:
+             idx -= cnt_by_shard[s]
+             s += 1
+         ptr = ptrs_by_shard[s][idx]
+         result = process('get_doc_by_ptr', index_desc, s=s, ptr=ptr, max_disp_len=max_disp_len, query_ids=find_result['token_ids'])
+     else:  # simple query
+         segment_by_shard = find_result['segment_by_shard']
+         cnt_by_shard = [end - start for (start, end) in segment_by_shard]
+         s = 0
+         while idx >= cnt_by_shard[s]:
+             idx -= cnt_by_shard[s]
+             s += 1
+         rank = segment_by_shard[s][0] + idx
+         result = process('get_doc_by_rank', index_desc, s=s, rank=rank, max_disp_len=max_disp_len, query_ids=find_result['token_ids'])
+     if 'error' in result:
+         metadata = result['error']
+         doc = []
+         return metadata, doc
+     metadata = format_doc_metadata(result)
+     doc = result['spans']
+     return metadata, doc
+
+ with gr.Blocks() as demo:
      with gr.Column():
+         gr.HTML(
+             '''<h1 style="text-align: center;">Infini-gram: An Efficient Search Engine over the Massive Pretraining Datasets of Language Models</h1>
+             <p style='font-size: 16px;'>This engine does exact-match search over several open pretraining datasets of language models. Please first select the corpus and the type of query, then enter your query and submit.</p>
+             <p style='font-size: 16px;'>The engine is developed by <a href="https://liujch1998.github.io">Jiacheng Liu</a> and documented in our paper: <a href="https://huggingface.co/papers/2401.17377">Infini-gram: Scaling Unbounded n-gram Language Models to a Trillion Tokens</a>. Feel free to check out our <a href="https://infini-gram.io">Project Homepage</a>.</p>
+             <p style='font-size: 16px;'><b>API Endpoint:</b> If you'd like to issue batch queries to infini-gram, you may invoke our API endpoint. Please refer to the <a href="https://infini-gram.io/api_doc">API documentation</a>.</p>
+             <p style='font-size: 16px;'><b>Note:</b> The query is <b>case-sensitive</b>. Your query will be tokenized with the Llama-2 tokenizer (unless otherwise specified).</p>
+             '''
+         )
          with gr.Row():
+             with gr.Column(scale=1, min_width=240):
+                 index_desc = gr.Radio(choices=INDEX_DESCS, label='Corpus', value=INDEX_DESCS[0])

+             with gr.Column(scale=7):
+                 with gr.Tab('1. Count an n-gram'):
                      with gr.Column():
+                         gr.HTML('<h2>1. Count an n-gram</h2>')
+                         with gr.Accordion(label='Click to view instructions', open=False):
+                             gr.HTML(f'''<p style="font-size: 16px;">This counts the number of times an n-gram appears in the corpus. If you submit an empty input, it will return the total number of tokens in the corpus. You can also make more complex queries by connecting multiple n-gram terms with the AND/OR operators, in the <a href="https://en.wikipedia.org/wiki/Conjunctive_normal_form">CNF format</a>.</p>
+                             <br>
+                             <p style="font-size: 16px;">Example queries:</p>
+                             <ul style="font-size: 16px;">
+                                 <li><b>natural language processing</b> (the output is the number of occurrences of "natural language processing")</li>
+                                 <li><b>natural language processing AND deep learning</b> (the output is the number of co-occurrences of "natural language processing" and "deep learning")</li>
+                                 <li><b>natural language processing OR artificial intelligence AND deep learning OR machine learning</b> (the output is the number of co-occurrences of [one of "natural language processing" / "artificial intelligence"] and [one of "deep learning" / "machine learning"])</li>
+                             </ul>
+                             <br>
+                             <p style="font-size: 16px;">Notes on CNF queries:</p>
+                             <ul style="font-size: 16px;">
+                                 <li>A CNF query may contain up to {MAX_CLAUSES_PER_CNF} clauses, and each clause may contain up to {MAX_TERMS_PER_CLAUSE} n-gram terms.</li>
+                                 <li>When you write a query in CNF, note that <b>OR has higher precedence than AND</b> (which is contrary to conventions in boolean algebra).</li>
+                                 <li>In AND queries, we can only examine co-occurrences where adjacent clauses are separated by no more than {max_diff_tokens} tokens. This value can be adjusted within range [1, {MAX_DIFF_TOKENS}] in "Advanced options".</li>
+                                 <li>In AND queries, if a clause has more than {max_clause_freq} matches, we will estimate the count by examining a random subset of {max_clause_freq} occurrences of that clause. This value can be adjusted within range [1, {MAX_CLAUSE_FREQ}] in "Advanced options".</li>
+                                 <li>The above subsampling mechanism might cause a zero count on co-occurrences of some simple n-grams (e.g., <b>birds AND oil</b>).</li>
+                             </ul>
+                             ''')
                          with gr.Row():
+                             with gr.Column(scale=1):
+                                 count_query = gr.Textbox(placeholder='Enter a string (an n-gram) here', label='Query', interactive=True)
+                                 with gr.Accordion(label='Advanced options', open=False):
+                                     with gr.Row():
+                                         count_max_clause_freq = gr.Slider(minimum=1, maximum=MAX_CLAUSE_FREQ, value=max_clause_freq, step=1, label='max_clause_freq')
+                                         count_max_diff_tokens = gr.Slider(minimum=1, maximum=MAX_DIFF_TOKENS, value=max_diff_tokens, step=1, label='max_diff_tokens')
+                                 with gr.Row():
+                                     count_clear = gr.ClearButton(value='Clear', variant='secondary', visible=True)
+                                     count_submit = gr.Button(value='Submit', variant='primary', visible=True)
+                                 count_latency = gr.Textbox(label='Latency (milliseconds)', interactive=False, lines=1)
+                                 count_tokenized = gr.Textbox(label='Tokenized', lines=1, interactive=False)
+                             with gr.Column(scale=1):
+                                 count_count = gr.Label(label='Count', num_top_classes=0)
+                         count_clear.add([count_query, count_latency, count_tokenized, count_count])
+                         count_submit.click(count, inputs=[index_desc, count_query, count_max_clause_freq, count_max_diff_tokens], outputs=[count_latency, count_tokenized, count_count], api_name=False)
+
+                 with gr.Tab('2. Prob of the last token'):
+                     with gr.Column():
+                         gr.HTML('<h2>2. Compute the probability of the last token in an n-gram</h2>')
+                         with gr.Accordion(label='Click to view instructions', open=False):
+                             gr.HTML(f'''<p style="font-size: 16px;">This computes the n-gram probability of the last token conditioned on the previous tokens (i.e. the (n-1)-gram).</p>
+                             <br>
+                             <p style="font-size: 16px;">Example query: <b>natural language processing</b> (the output is P(processing | natural language), obtained by counting occurrences of the 3-gram "natural language processing" and the 2-gram "natural language", and taking the ratio of the two)</p>
+                             <br>
+                             <p style="font-size: 16px;">Notes:</p>
+                             <ul style="font-size: 16px;">
+                                 <li>The (n-1)-gram needs to exist in the corpus. If the (n-1)-gram is not found in the corpus, an error message will appear.</li>
+                             </ul>
+                             ''')
                          with gr.Row():
+                             with gr.Column(scale=1):
+                                 prob_query = gr.Textbox(placeholder='Enter a string (an n-gram) here', label='Query', interactive=True)
+                                 with gr.Row():
+                                     prob_clear = gr.ClearButton(value='Clear', variant='secondary', visible=True)
+                                     prob_submit = gr.Button(value='Submit', variant='primary', visible=True)
+                                 prob_latency = gr.Textbox(label='Latency (milliseconds)', interactive=False, lines=1)
+                                 prob_tokenized = gr.Textbox(label='Tokenized', lines=1, interactive=False)
+                             with gr.Column(scale=1):
+                                 prob_probability = gr.Label(label='Probability', num_top_classes=0)
+                         prob_clear.add([prob_query, prob_latency, prob_tokenized, prob_probability])
+                         prob_submit.click(prob, inputs=[index_desc, prob_query], outputs=[prob_latency, prob_tokenized, prob_probability], api_name=False)
+
+                 with gr.Tab('3. Next-token distribution'):
+                     with gr.Column():
+                         gr.HTML('<h2>3. Compute the next-token distribution of an (n-1)-gram</h2>')
+                         with gr.Accordion(label='Click to view instructions', open=False):
+                             gr.HTML(f'''<p style="font-size: 16px;">This is an extension of Query Type 2: it interprets your input as the (n-1)-gram and gives you the full next-token distribution.</p>
+                             <br>
+                             <p style="font-size: 16px;">Example query: <b>natural language</b> (the output is P(* | natural language), for the top-10 tokens *)</p>
+                             <br>
+                             <p style="font-size: 16px;">Notes:</p>
+                             <ul style="font-size: 16px;">
+                                 <li>The (n-1)-gram needs to exist in the corpus. If the (n-1)-gram is not found in the corpus, an error message will appear.</li>
+                                 <li>If the (n-1)-gram appears more than {max_support} times in the corpus, the result will be approximate: we will estimate the distribution by examining a subset of {max_support} occurrences of the (n-1)-gram. This value can be adjusted within range [1, {MAX_SUPPORT}] in "Advanced options".</li>
+                             </ul>
+                             ''')
+
+                         with gr.Row():
+                             with gr.Column(scale=1):
+                                 ntd_query = gr.Textbox(placeholder='Enter a string (an (n-1)-gram) here', label='Query', interactive=True)
+                                 with gr.Accordion(label='Advanced options', open=False):
+                                     ntd_max_support = gr.Slider(minimum=1, maximum=MAX_SUPPORT, value=MAX_SUPPORT, step=1, label='max_support')
+                                 with gr.Row():
+                                     ntd_clear = gr.ClearButton(value='Clear', variant='secondary', visible=True)
+                                     ntd_submit = gr.Button(value='Submit', variant='primary', visible=True)
+                                 ntd_latency = gr.Textbox(label='Latency (milliseconds)', interactive=False, lines=1)
+                                 ntd_tokenized = gr.Textbox(label='Tokenized', lines=1, interactive=False)
+                             with gr.Column(scale=1):
+                                 ntd_distribution = gr.Label(label='Distribution', num_top_classes=10)
+                         ntd_clear.add([ntd_query, ntd_latency, ntd_tokenized, ntd_distribution])
+                         ntd_submit.click(ntd, inputs=[index_desc, ntd_query, ntd_max_support], outputs=[ntd_latency, ntd_tokenized, ntd_distribution], api_name=False)
+
+                 with gr.Tab('4. ∞-gram prob'):
+                     with gr.Column():
+                         gr.HTML('<h2>4. Compute the ∞-gram probability of the last token</h2>')
+                         with gr.Accordion(label='Click to view instructions', open=False):
+                             gr.HTML(f'''<p style="font-size: 16px;">This computes the ∞-gram probability of the last token conditioned on the previous tokens. Compared to Query Type 2 (which uses your entire input for n-gram modeling), here we take the longest suffix that we can find in the corpus.</p>
+                             <br>
+                             <p style="font-size: 16px;">Example query: <b>I love natural language processing</b> (if "natural language" appears in the corpus but "love natural language" doesn't, the output is P(processing | natural language); in this case the effective n = 3)</p>
+                             <br>
+                             <p style="font-size: 16px;">Notes:</p>
+                             <ul style="font-size: 16px;">
+                                 <li>The effective n may be as small as 1, i.e. the longest found suffix may be empty, in which case the result reduces to the unigram probability of the last token.</li>
+                             </ul>
+                             ''')
+                         with gr.Row():
+                             with gr.Column(scale=1):
+                                 infgram_prob_query = gr.Textbox(placeholder='Enter a string here', label='Query', interactive=True)
+                                 with gr.Row():
+                                     infgram_prob_clear = gr.ClearButton(value='Clear', variant='secondary', visible=True)
+                                     infgram_prob_submit = gr.Button(value='Submit', variant='primary', visible=True)
+                                 infgram_prob_latency = gr.Textbox(label='Latency (milliseconds)', interactive=False, lines=1)
+                                 infgram_prob_tokenized = gr.Textbox(label='Tokenized', lines=1, interactive=False)
+                                 infgram_prob_longest_suffix = gr.Textbox(label='Longest Found Suffix', interactive=False)
+                             with gr.Column(scale=1):
+                                 infgram_prob_probability = gr.Label(label='Probability', num_top_classes=0)
+                         infgram_prob_clear.add([infgram_prob_query, infgram_prob_latency, infgram_prob_tokenized, infgram_prob_longest_suffix, infgram_prob_probability])
+                         infgram_prob_submit.click(infgram_prob, inputs=[index_desc, infgram_prob_query], outputs=[infgram_prob_latency, infgram_prob_tokenized, infgram_prob_longest_suffix, infgram_prob_probability], api_name=False)
+
+                 with gr.Tab('5. ∞-gram next-token distribution'):
+                     with gr.Column():
+                         gr.HTML('<h2>5. Compute the ∞-gram next-token distribution</h2>')
+                         with gr.Accordion(label='Click to view instructions', open=False):
+                             gr.HTML(f'''<p style="font-size: 16px;">This is similar to Query Type 3, but with the ∞-gram instead of the n-gram.</p>
+                             <br>
+                             <p style="font-size: 16px;">Example query: <b>I love natural language</b> (if "natural language" appears in the corpus but "love natural language" doesn't, the output is P(* | natural language), for the top-10 tokens *)</p>
+                             ''')
+                         with gr.Row():
+                             with gr.Column(scale=1):
+                                 infgram_ntd_query = gr.Textbox(placeholder='Enter a string here', label='Query', interactive=True)
+                                 with gr.Accordion(label='Advanced options', open=False):
+                                     infgram_ntd_max_support = gr.Slider(minimum=1, maximum=MAX_SUPPORT, value=MAX_SUPPORT, step=1, label='max_support')
+                                 with gr.Row():
+                                     infgram_ntd_clear = gr.ClearButton(value='Clear', variant='secondary', visible=True)
+                                     infgram_ntd_submit = gr.Button(value='Submit', variant='primary', visible=True)
+                                 infgram_ntd_latency = gr.Textbox(label='Latency (milliseconds)', interactive=False, lines=1)
+                                 infgram_ntd_tokenized = gr.Textbox(label='Tokenized', lines=1, interactive=False)
+                                 infgram_ntd_longest_suffix = gr.Textbox(label='Longest Found Suffix', interactive=False)
+                             with gr.Column(scale=1):
+                                 infgram_ntd_distribution = gr.Label(label='Distribution', num_top_classes=10)
+                         infgram_ntd_clear.add([infgram_ntd_query, infgram_ntd_latency, infgram_ntd_tokenized, infgram_ntd_longest_suffix, infgram_ntd_distribution])
+                         infgram_ntd_submit.click(infgram_ntd, inputs=[index_desc, infgram_ntd_query, infgram_ntd_max_support], outputs=[infgram_ntd_latency, infgram_ntd_tokenized, infgram_ntd_longest_suffix, infgram_ntd_distribution], api_name=False)
+
+                 with gr.Tab('6. Search documents', visible=False):
+                     with gr.Column():
+                         gr.HTML(f'''<h2>6. Search for documents containing n-gram(s)</h2>''')
+                         with gr.Accordion(label='Click to view instructions', open=False):
+                             gr.HTML(f'''<p style="font-size: 16px;">This displays a few random documents in the corpus that satisfy your query. You can simply enter an n-gram, in which case the displayed documents contain your n-gram. You can also connect multiple n-gram terms with the AND/OR operators, in the <a href="https://en.wikipedia.org/wiki/Conjunctive_normal_form">CNF format</a>, in which case the displayed documents contain n-grams satisfying this logical constraint.</p>
+                             <br>
+                             <p style="font-size: 16px;">Example queries:</p>
+                             <ul style="font-size: 16px;">
+                                 <li><b>natural language processing</b> (the displayed document would contain "natural language processing")</li>
+                                 <li><b>natural language processing AND deep learning</b> (the displayed document would contain both "natural language processing" and "deep learning")</li>
+                                 <li><b>natural language processing OR artificial intelligence AND deep learning OR machine learning</b> (the displayed document would contain at least one of "natural language processing" / "artificial intelligence", and also at least one of "deep learning" / "machine learning")</li>
+                             </ul>
+                             <br>
+                             <p style="font-size: 16px;">If you want another batch of random documents, simply hit the Submit button again :)</p>
+                             <br>
+                             <p style="font-size: 16px;">Notes on CNF queries:</p>
+                             <ul style="font-size: 16px;">
+                                 <li>A CNF query may contain up to {MAX_CLAUSES_PER_CNF} clauses, and each clause may contain up to {MAX_TERMS_PER_CLAUSE} n-gram terms.</li>
+                                 <li>When you write a query in CNF, note that <b>OR has higher precedence than AND</b> (which is contrary to conventions in boolean algebra).</li>
+                                 <li>In AND queries, we can only examine co-occurrences where adjacent clauses are separated by no more than {max_diff_tokens} tokens. This value can be adjusted within range [1, {MAX_DIFF_TOKENS}] in "Advanced options".</li>
+                                 <li>In AND queries, if a clause has more than {max_clause_freq} matches, we will estimate the count by examining a random subset of {max_clause_freq} documents out of all documents containing that clause. This value can be adjusted within range [1, {MAX_CLAUSE_FREQ}] in "Advanced options".</li>
+                                 <li>The above subsampling mechanism might cause a zero count on co-occurrences of some simple n-grams (e.g., <b>birds AND oil</b>).</li>
+                             </ul>
+                             <br>
+                             <p style="font-size: 16px;">❗️WARNING: The corpus may contain problematic content such as PII, toxicity, hate speech, and NSFW text. This tool is merely presenting selected text from the corpus, without any post-hoc safety filtering. It is NOT creating new text. This is a research prototype through which we can expose and examine existing problems with massive text corpora. Please use with caution. Don't be evil :)</p>
+                             ''')
+                         with gr.Row():
+                             with gr.Column(scale=1):
+                                 search_docs_query = gr.Textbox(placeholder='Enter a query here', label='Query', interactive=True)
+                                 search_docs_maxnum = gr.Slider(minimum=1, maximum=MAXNUM, value=maxnum, step=1, label='Number of documents to display')
+                                 search_docs_max_disp_len = gr.Slider(minimum=1, maximum=MAX_DISP_LEN, value=max_disp_len, step=1, label='Number of tokens to display')
+                                 with gr.Accordion(label='Advanced options', open=False):
+                                     with gr.Row():
+                                         search_docs_max_clause_freq = gr.Slider(minimum=1, maximum=MAX_CLAUSE_FREQ, value=max_clause_freq, step=1, label='max_clause_freq')
+                                         search_docs_max_diff_tokens = gr.Slider(minimum=1, maximum=MAX_DIFF_TOKENS, value=max_diff_tokens, step=1, label='max_diff_tokens')
+                                 with gr.Row():
+                                     search_docs_clear = gr.ClearButton(value='Clear', variant='secondary', visible=True)
+                                     search_docs_submit = gr.Button(value='Submit', variant='primary', visible=True)
+                                 search_docs_latency = gr.Textbox(label='Latency (milliseconds)', interactive=False, lines=1)
+                                 search_docs_tokenized = gr.Textbox(label='Tokenized', lines=1, interactive=False)
+                             with gr.Column(scale=2):
+                                 search_docs_message = gr.Label(label='Message', num_top_classes=0)
+                                 search_docs_metadatas = []
+                                 search_docs_outputs = []
+                                 for i in range(MAXNUM):
+                                     with gr.Tab(label=str(i+1)):
+                                         search_docs_metadatas.append(gr.Textbox(label='Metadata', lines=3, interactive=False))
+                                         search_docs_outputs.append(gr.HighlightedText(label='Document', show_legend=False, color_map={"-": "red", "0": "green", "1": "cyan", "2": "blue", "3": "magenta"}))
+                         search_docs_clear.add([search_docs_query, search_docs_latency, search_docs_tokenized, search_docs_message] + search_docs_metadatas + search_docs_outputs)
+                         search_docs_submit.click(search_docs, inputs=[index_desc, search_docs_query, search_docs_maxnum, search_docs_max_disp_len, search_docs_max_clause_freq, search_docs_max_diff_tokens], outputs=[search_docs_latency, search_docs_tokenized, search_docs_message] + search_docs_metadatas + search_docs_outputs, api_name=False)
+
+                 with gr.Tab('6. Search documents'):
+                     with gr.Column():
+                         gr.HTML(f'''<h2>6. Search for documents containing n-gram(s)</h2>''')
+                         with gr.Accordion(label='Click to view instructions', open=False):
+                             gr.HTML(f'''<p style="font-size: 16px;">This displays the documents in the corpus that satisfy your query. You can simply enter an n-gram, in which case the displayed document contains your n-gram. You can also connect multiple n-gram terms with the AND/OR operators, in the <a href="https://en.wikipedia.org/wiki/Conjunctive_normal_form">CNF format</a>, in which case the displayed document contains n-grams satisfying this logical constraint.</p>
+                             <br>
+                             <p style="font-size: 16px;">Example queries:</p>
+                             <ul style="font-size: 16px;">
+                                 <li><b>natural language processing</b> (the displayed document would contain "natural language processing")</li>
+                                 <li><b>natural language processing AND deep learning</b> (the displayed document would contain both "natural language processing" and "deep learning")</li>
+                                 <li><b>natural language processing OR artificial intelligence AND deep learning OR machine learning</b> (the displayed document would contain at least one of "natural language processing" / "artificial intelligence", and also at least one of "deep learning" / "machine learning")</li>
+                             </ul>
+                             <br>
+                             <p style="font-size: 16px;">Notes on CNF queries:</p>
+                             <ul style="font-size: 16px;">
+                                 <li>A CNF query may contain up to {MAX_CLAUSES_PER_CNF} clauses, and each clause may contain up to {MAX_TERMS_PER_CLAUSE} n-gram terms.</li>
+                                 <li>When you write a query in CNF, note that <b>OR has higher precedence than AND</b> (which is contrary to conventions in boolean algebra).</li>
+                                 <li>In AND queries, we can only examine co-occurrences where adjacent clauses are separated by no more than {max_diff_tokens} tokens. This value can be adjusted within range [1, {MAX_DIFF_TOKENS}] in "Advanced options".</li>
+                                 <li>In AND queries, if a clause has more than {max_clause_freq} matches, we will estimate the count by examining a random subset of {max_clause_freq} occurrences of that clause. This value can be adjusted within range [1, {MAX_CLAUSE_FREQ}] in "Advanced options".</li>
+                                 <li>The above subsampling mechanism might cause a zero count on co-occurrences of some simple n-grams (e.g., <b>birds AND oil</b>).</li>
+                             </ul>
+                             <br>
+                             <p style="font-size: 16px;">❗️WARNING: The corpus may contain problematic content such as PII, toxicity, hate speech, and NSFW text. This tool is merely presenting selected text from the corpus, without any post-hoc safety filtering. It is NOT creating new text. This is a research prototype through which we can expose and examine existing problems with massive text corpora. Please use with caution. Don't be evil :)</p>
+                             ''')
+                         with gr.Row():
+                             with gr.Column(scale=1):
+                                 search_docs_new_query = gr.Textbox(placeholder='Enter a query here', label='Query', interactive=True)
+                                 search_docs_new_max_disp_len = gr.Slider(minimum=1, maximum=MAX_DISP_LEN, value=max_disp_len, step=1, label='Number of tokens to display')
+                                 with gr.Accordion(label='Advanced options', open=False):
+                                     with gr.Row():
+                                         search_docs_new_max_clause_freq = gr.Slider(minimum=1, maximum=MAX_CLAUSE_FREQ, value=max_clause_freq, step=1, label='max_clause_freq')
+                                         search_docs_new_max_diff_tokens = gr.Slider(minimum=1, maximum=MAX_DIFF_TOKENS, value=max_diff_tokens, step=1, label='max_diff_tokens')
+                                 with gr.Row():
+                                     search_docs_new_clear = gr.ClearButton(value='Clear', variant='secondary', visible=True)
+                                     search_docs_new_submit = gr.Button(value='Submit', variant='primary', visible=True)
+                                 search_docs_new_latency = gr.Textbox(label='Latency (milliseconds)', interactive=False, lines=1)
+                                 search_docs_new_tokenized = gr.Textbox(label='Tokenized', lines=1, interactive=False)
+                             with gr.Column(scale=2):
+                                 search_docs_new_message = gr.Label(label='Message', num_top_classes=0)
+                                 search_docs_new_idx = gr.Slider(label='', minimum=0, maximum=0, step=1, value=0, interactive=False)
+                                 search_docs_new_metadata = gr.Textbox(label='Metadata', lines=3, max_lines=3, interactive=False)
+                                 search_docs_new_output = gr.HighlightedText(label='Document', show_legend=False, color_map={"-": "red", "0": "green", "1": "cyan", "2": "blue", "3": "magenta"})
+                         search_docs_state = gr.State(value=None)
+                         search_docs_new_clear.add([search_docs_new_query, search_docs_new_latency, search_docs_new_tokenized, search_docs_new_message, search_docs_new_idx, search_docs_new_metadata, search_docs_new_output])
+                         search_docs_new_clear.click(
+                             clear_search_docs_new,
+                             inputs=[search_docs_state],
+                             outputs=[search_docs_new_idx, search_docs_state]
+                         )
+                         search_docs_new_submit.click(
+                             search_docs_new,
+                             inputs=[index_desc, search_docs_new_query, search_docs_new_max_disp_len,
+                                     search_docs_new_max_clause_freq, search_docs_new_max_diff_tokens,
+                                     search_docs_state],
+                             outputs=[search_docs_new_latency, search_docs_new_tokenized,
+                                      search_docs_new_message, search_docs_new_idx,
+                                      search_docs_new_metadata, search_docs_new_output,
+                                      search_docs_state]
+                         )
+                         search_docs_new_idx.input(
+                             get_another_doc,
+                             inputs=[index_desc, search_docs_new_idx, search_docs_new_max_disp_len,
+                                     search_docs_state],
+                             outputs=[search_docs_new_metadata, search_docs_new_output]
+                         )

+     with gr.Row():
+         gr.Markdown('''
+         If you find this tool useful, please kindly cite our paper:
+         ```bibtex
+         @article{Liu2024InfiniGram,
+             title={Infini-gram: Scaling Unbounded n-gram Language Models to a Trillion Tokens},
+             author={Liu, Jiacheng and Min, Sewon and Zettlemoyer, Luke and Choi, Yejin and Hajishirzi, Hannaneh},
+             journal={arXiv preprint arXiv:2401.17377},
+             year={2024}
+         }
+         ```
+         ''')
+
+ demo.queue(
+     default_concurrency_limit=DEFAULT_CONCURRENCY_LIMIT,
+     max_size=MAX_SIZE,
+     api_open=False,
+ ).launch(
+     max_threads=MAX_THREADS,
+     debug=DEBUG,
+     show_api=False,
+ )
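The new app is a thin client over the infini-gram API endpoint: every callback builds a JSON payload in `process` and POSTs it to `API_URL`. As a rough, hedged sketch of what goes over the wire, a standalone count query might look like the following; the endpoint URL and index name are illustrative assumptions, so consult the linked API documentation for the authoritative values.

```python
import requests

# Minimal sketch of the JSON protocol used by process() above.
# API_URL and the index name are assumptions for illustration; see
# https://infini-gram.io/api_doc for the current endpoint and index list.
API_URL = 'https://api.infini-gram.io/'

payload = {
    'index': 'v4_rpj_llama_s4',            # assumed index name; INDEX_BY_DESC maps UI labels to these
    'query_type': 'count',                 # same field the Gradio callbacks set
    'query': 'natural language processing',
}
response = requests.post(API_URL, json=payload, timeout=10)
response.raise_for_status()
result = response.json()
# On success the response carries the fields the UI reads, e.g. 'count',
# 'latency', 'token_ids', and 'tokens'; failures come back under an 'error' key.
print(result.get('count', result.get('error')))
```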