File size: 5,199 Bytes
4b2c8d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import gradio as gr
import auto_schedule
import v_schedule

def greet(name, is_morning, temperature):
    """Build a greeting sentence and convert the temperature to Celsius.

    Args:
        name: Name inserted into the greeting.
        is_morning: Truthy selects "Good morning", anything else "Good evening".
        temperature: Numeric temperature; the conversion treats it as Fahrenheit.

    Returns:
        A ``(greeting, celsius)`` tuple, with ``celsius`` rounded to two
        decimal places.
    """
    if is_morning:
        salutation = "Good morning"
    else:
        salutation = "Good evening"
    message = f"{salutation} {name}. It is {temperature} degrees today"
    above_freezing = temperature - 32
    return message, round(above_freezing * 5 / 9, 2)

def percentage(x):
  """Format the fraction ``x`` as a percentage string with two decimals.

  For example ``0.5`` becomes ``"50.00%"``.
  """
  scaled = x * 100
  return "{:.2f}%".format(scaled)

def get_schedule_time_and_image(result):
  """Return the longest per-stage span of a schedule, plus an image placeholder.

  Args:
    result: Iterable of stages; each stage is an iterable of nodes exposing
      ``type``, ``start_time`` and ``completion_time`` attributes.

  Returns:
    ``(time, None)`` where ``time`` is the largest value, over all stages, of
    (latest completion time - earliest start time), considering only nodes
    whose ``type`` is 'F', 'B' or 'W'. The second element is always ``None``.
  """
  compute_types = {'F', 'B', 'W'}
  # First pass: drop every node that is not an F/B/W pass.
  stages = [
    [node for node in stage if node.type in compute_types]
    for stage in result
  ]
  # Second pass: measure how long each stage is occupied.
  spans = []
  for nodes in stages:
    last_finish = max(node.completion_time for node in nodes)
    first_start = min(node.start_time for node in nodes)
    spans.append(last_finish - first_start)
  return max(spans), None

def calculate(p, m, f, b, w, c, mem):
  """Compute time, bubble rate and acceleration for three pipeline schedules.

  The baseline (shown as "1F1B" in the UI) is computed in closed form; the
  zero bubble (ZB) and zero bubble V (ZBV) schedules are produced by
  ``auto_schedule`` and ``v_schedule`` and measured with
  ``get_schedule_time_and_image``.

  Args:
    p: Number of pipeline stages.
    m: Number of microbatches.
    f: Time of F.
    b: Time of B.
    w: Time of W.
    c: Time of one P2P communication.
    mem: Activation memory limit, in pending F passes on a stage.

  Returns:
    A list of twelve values: ``[time, bubble, acceleration, image]`` for the
    baseline, then for ZB, then for ZBV. Bubble and acceleration are
    percentage strings; every image entry is currently ``None``.

  Raises:
    gr.Error: If an input is missing, ``p`` or ``m`` is below 1, or
      ``f + b + w`` is not positive (these would otherwise fail with a
      ``TypeError`` or ``ZeroDivisionError``).
  """
  # Cleared gr.Number fields arrive as None; report that instead of a TypeError.
  if any(value is None for value in (p, m, f, b, w, c, mem)):
    raise gr.Error("Please fill in every parameter before calculating.")
  if p < 1 or m < 1:
    raise gr.Error("The number of stages and microbatches must both be at least 1.")
  if f + b + w <= 0:
    raise gr.Error("The combined time of F, B and W must be positive.")

  # Baseline: m microbatches of compute plus (p - 1) steps of compute and
  # communication.
  baseline_time = (f+b+w)*m + (f+b+w+c)*(p-1)
  baseline_bubble = percentage(baseline_time/(f+b+w)/m - 1)
  baseline_acceleration = percentage(0)

  # Zero bubble schedule; the scheduler receives twice the user's memory limit.
  zb_result = auto_schedule.auto_schedule(p, m, auto_schedule.GraphConfig(
        cost_f=f,
        cost_b=b,
        cost_w=w,
        cost_comm=c,
        max_mem=mem * 2,
        print_scaling=1000
  ))
  zb_time, _ = get_schedule_time_and_image(zb_result)
  zb_bubble = percentage(zb_time/(f+b+w)/m - 1)
  zb_acceleration = percentage(baseline_time/zb_time - 1)

  # Zero bubble V schedule: the entered costs cover two virtual stages
  # combined, so each pass is given half the cost.
  zbv_graph = v_schedule.PipelineGraph(
                n_stage=p,
                n_micro=m,
                f_cost=f/2,
                b_cost=b/2,
                w_cost=w/2,
                c_cost=c,
                f_mem=2,
                b_mem=-1,
                w_mem=-1,
                max_mem=mem * 4,
  )
  zbv_result = zbv_graph.get_v_schedule()
  zbv_time, _ = get_schedule_time_and_image(zbv_result)
  zbv_bubble = percentage(zbv_time/(f+b+w)/m - 1)
  zbv_acceleration = percentage(baseline_time/zbv_time - 1)

  # No schedule images are rendered yet, so every image slot is None.
  return [
    baseline_time, baseline_bubble, baseline_acceleration, None,
    zb_time, zb_bubble, zb_acceleration, None,
    zbv_time, zbv_bubble, zbv_acceleration, None,
  ]

# Label shared by the three bubble-rate outputs. It states the formula the
# calculation actually uses: time/(F+B+W)/m - 1.
BUBBLE_LABEL = "Bubble Rate. Calculated as (longest stage time/(F+B+W)/m - 1)."

with gr.Blocks() as demo:
  gr.Markdown("Zero bubble pipeline parallel bubble calculator")
  # Input section: pipeline shape on the left, per-pass costs on the right.
  with gr.Row():
    with gr.Column(scale=1):
      with gr.Group():
        gr.Markdown("Basic Parameters")
        with gr.Row():
          p=gr.Number(label="Number of stages (p)", value=4, interactive=True, precision=0)
          m=gr.Number(label="Number of microbatches (m)", value=12, interactive=True, precision=0)
    with gr.Column(scale=2):
      with gr.Group():
        gr.Markdown("Costs. All costs are used as integers. For ZBV schedules, this is the time of two virtual stages on a stage combined.")
        with gr.Row():  
          f=gr.Number(label="Time of F", value=8, interactive=True, precision=0)
          b=gr.Number(label="Time of B", value=8, interactive=True, precision=0)
          w=gr.Number(label="Time of W", value=8, interactive=True, precision=0)
          c=gr.Number(label="Time of one P2P communication", value=1, interactive=True, precision=0)
  with gr.Group():
    gr.Markdown("Activation memory limit.")
    def update_mem(p, s, mem):
      """Return the memory limit implied by the selected preset.

      "custom" keeps the current value; a preset such as "2p" yields 2 * p.
      While the stage count is empty (None) the current value is kept too,
      instead of failing on ``None * int``.
      """
      print("update")
      if s=="custom" or p is None:
        return mem
      # Presets look like "<k>p"; strip the trailing "p" to get the multiplier.
      return p*int(s[:-1])
    memsel=gr.Radio(choices=["1p", "2p", "3p", "custom"], value="1p")
    mem=gr.Number(label="Custom memory limit in terms of pending F on a stage. For ZBV schedules, this is relative to two virtual stages on a stage combined.", value=p.value, interactive=True, precision=0)
    # Recompute the limit whenever the preset or the number of stages changes.
    memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
    p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
    
  button=gr.Button("Calculate")

  # Output section: one group per schedule, each with three text fields and an image slot.
  with gr.Group():
    gr.Markdown("1F1B")
    with gr.Row():
      with gr.Column(scale=1):
        baseline_time=gr.Textbox("", label="Longest Stage Time")
        baseline_bubble=gr.Textbox("", label=BUBBLE_LABEL)
        baseline_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
      with gr.Column(scale=4):
        baseline_image=gr.Image(None, interactive=False, label="Schedule Image")
  
  with gr.Group():
    gr.Markdown("Zero Bubble Schedule")
    with gr.Row():
      with gr.Column(scale=1):
        zb_time=gr.Textbox("", label="Longest Stage Time")
        zb_bubble=gr.Textbox("", label=BUBBLE_LABEL)
        zb_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
      with gr.Column(scale=4):
        zb_image=gr.Image(None, interactive=False, label="Schedule Image")
  with gr.Group():
    gr.Markdown("Zero Bubble V Schedule")
    with gr.Row():
      with gr.Column(scale=1):
        zbv_time=gr.Textbox("", label="Longest Stage Time")
        zbv_bubble=gr.Textbox("", label=BUBBLE_LABEL)
        zbv_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
      with gr.Column(scale=4):
        zbv_image=gr.Image(None, interactive=False, label="Schedule Image")
    # The output order must match the list returned by calculate().
    button.click(calculate, inputs=[p, m, f, b, w, c, mem], outputs=[baseline_time, baseline_bubble, baseline_acceleration, baseline_image, zb_time, zb_bubble, zb_acceleration, zb_image, zbv_time, zbv_bubble, zbv_acceleration, zbv_image])
demo.launch()