yonigozlan HF staff commited on
Commit
03cff7d
·
1 Parent(s): a03a4af

change compile options

Browse files
Files changed (1) hide show
  1. app.py +16 -8
app.py CHANGED
@@ -20,19 +20,26 @@ model = AutoModelForObjectDetection.from_pretrained(
20
  disable_custom_kernels=True,
21
  torch_dtype=torch.float16,
22
  ).to(device)
23
- model_compiled = torch.compile(model, mode="reduce-overhead")
 
 
 
24
 
25
  url = "http://images.cocodataset.org/val2017/000000039769.jpg"
26
  image = Image.open(requests.get(url, stream=True).raw)
27
  inputs = processor(images=image, return_tensors="pt").to(device).to(torch.float16)
28
 
29
- print("Compiling model...")
30
- start_time = time.time()
31
- with torch.no_grad():
32
- for _ in range(10):
33
- outputs = model_compiled(**inputs)
34
- _ = outputs[0].cpu()
35
- print(f"Model compiled in {time.time() - start_time:.2f} seconds.")
 
 
 
 
36
 
37
  css = """
38
  .feedback textarea {font-size: 24px !important}
@@ -175,4 +182,5 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
175
  )
176
 
177
  if __name__ == "__main__":
 
178
  demo.launch(show_error=True)
 
20
  disable_custom_kernels=True,
21
  torch_dtype=torch.float16,
22
  ).to(device)
23
+ model_compiled = torch.compile(
24
+ model,
25
+ options={"triton.cudagraphs": True, "max_autotune": True},
26
+ )
27
 
28
  url = "http://images.cocodataset.org/val2017/000000039769.jpg"
29
  image = Image.open(requests.get(url, stream=True).raw)
30
  inputs = processor(images=image, return_tensors="pt").to(device).to(torch.float16)
31
 
32
+
33
+ @spaces.GPU
34
+ def init_compiled_model():
35
+ print("Compiling model...")
36
+ start_time = time.time()
37
+ with torch.no_grad():
38
+ for _ in range(10):
39
+ outputs = model_compiled(**inputs)
40
+ _ = outputs[0].cpu()
41
+ print(f"Model compiled in {time.time() - start_time:.2f} seconds.")
42
+
43
 
44
  css = """
45
  .feedback textarea {font-size: 24px !important}
 
182
  )
183
 
184
  if __name__ == "__main__":
185
+ init_compiled_model()
186
  demo.launch(show_error=True)