File size: 4,746 Bytes
331412c 039cd66 331412c 9c7dc56 331412c 039cd66 e64e782 9c7dc56 e64e782 a109a1e e64e782 9c7dc56 039cd66 e64e782 039cd66 e64e782 9c7dc56 331412c 9c7dc56 331412c 9c7dc56 039cd66 9c7dc56 a109a1e 039cd66 331412c bbd4f95 331412c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
import gradio as gr
import torch
# Markdown template for the runnable example shown at the top of the output.
# Placeholders: {n1}/{n2} are element counts, {dim1}/{dim2} the shapes and
# {out_shape} the resulting shape (or "error"); generate_example fills them in
# with str.format. NOTE: the last line is illustrative, not runnable Python.
EXAMPLE_MD = """
```python
import torch
t1 = torch.arange({n1}).view({dim1})
t2 = torch.arange({n2}).view({dim2})
(t1 @ t2).shape = {out_shape}
```
"""


def generate_example(dim1: list, dim2: list):
    """Render a markdown example of ``t1 @ t2`` for two tensor shapes.

    The output shape is obtained from PyTorch itself, but using tensors on
    the ``"meta"`` device. Meta tensors carry only shape/dtype metadata and
    allocate no element storage. Huge user-supplied shapes therefore cannot
    exhaust memory, while the reported shape is still exactly what
    ``torch.matmul`` produces.

    Args:
        dim1: Shape of the first operand, e.g. ``[2, 3]``.
        dim2: Shape of the second operand, e.g. ``[3, 4]``.

    Returns:
        ``(dim1, dim2, code)``: the two shapes, returned unchanged, and the
        markdown rendered from ``EXAMPLE_MD``. In that markdown,
        ``out_shape`` is the string ``"error"`` when PyTorch rejects the
        shapes.
    """
    # Element counts, only used to render the `torch.arange(n)` lines.
    n1 = 1
    for size in dim1:
        n1 *= size
    n2 = 1
    for size in dim2:
        n2 *= size

    try:
        # Creation is inside the try too: invalid sizes (e.g. negative) raise
        # RuntimeError here and are reported as "error" instead of crashing.
        t1 = torch.empty(dim1, device="meta")
        t2 = torch.empty(dim2, device="meta")
        out_shape = list((t1 @ t2).shape)
    except RuntimeError:
        out_shape = "error"

    code = EXAMPLE_MD.format(
        n1=str(n1), dim1=str(dim1), n2=str(n2), dim2=str(dim2), out_shape=str(out_shape)
    )
    return dim1, dim2, code
def sanitize_dimention(dim):
    """Parse a user-typed tensor shape into a list of ints.

    Accepts forms such as ``"2,3"``, ``"[2, 3]"``, ``"2 3"`` or ``"5"``.
    Square brackets are dropped, and commas and/or whitespace separate the
    sizes.

    Args:
        dim: Raw text from a Gradio textbox; ``None`` is rejected.

    Returns:
        The shape, e.g. ``[2, 3]``.

    Raises:
        gr.Error: If the text is missing or empty, contains a non-integer
            token, or contains a size smaller than 1. ``gr.Error`` must be
            *raised*; merely constructing it has no effect, and Gradio only
            shows a raised one to the user.
    """
    if dim is None:
        raise gr.Error("one of the dimensions is empty, please fill it")

    cleaned = dim.replace("[", "").replace("]", "").replace(",", " ")
    try:
        out = [int(token) for token in cleaned.split()]
    except ValueError:
        raise gr.Error(
            f"Could not read {dim!r} as a shape; use integers separated by commas, e.g. 2,3,4"
        ) from None

    if not out:
        raise gr.Error("one of the dimensions is empty, please fill it")
    if 0 in out:
        raise gr.Error(
            "Found the number 0 in one of the dimensions which is not allowed, consider using 1 instead"
        )
    if any(size < 0 for size in out):
        raise gr.Error("Found a negative number in one of the dimensions which is not allowed")
    return out
def create_row(dim, is_dim=None, checks=None):
    """Render one tensor shape as a markdown table row.

    Args:
        dim: Sizes to place in the cells, one per column.
        is_dim: ``1`` when ``dim`` is the first matmul operand, ``2`` for the
            second, ``None`` to render every cell as plain text.
        checks: Per-column markers (``"V"`` means OK). They are read whenever
            a cell is highlighted, so they are required when ``is_dim`` is
            1 or 2.

    Highlighting (only when ``is_dim`` is 1 or 2):
        * first operand: every column but the last is green/red, and the last
          column is blue/yellow;
        * second operand: the last column is green/red, and the
          second-to-last column is blue/yellow.
        Green/blue is used when ``checks[i] == "V"``, red/yellow otherwise.

    Returns:
        A single markdown row terminated by a newline.
    """
    last = len(dim) - 1
    cells = ["| "]
    for idx, size in enumerate(dim):
        if is_dim == 1:
            outer, inner = idx != last, idx == last
        elif is_dim == 2:
            outer, inner = idx == last, idx == last - 1
        else:
            outer = inner = False

        if outer:
            tint = "green" if checks[idx] == "V" else "red"
        elif inner:
            tint = "blue" if checks[idx] == "V" else "yellow"
        else:
            cells.append(f"{size} | ")
            continue
        cells.append(f"<strong style='color: {tint}'> {size} </strong>| ")
    cells.append("\n")
    return "".join(cells)
def create_header(n_dim, checks=None):
    """Return the header row and separator row of a markdown table.

    Args:
        n_dim: Number of columns; sets how many ``|---`` separators are
            emitted.
        checks: Optional labels for the header cells. When omitted, each cell
            holds the empty HTML comment ``<!-- -->`` so the header cells
            show no text.

    Returns:
        Two markdown lines, each ending in a newline.
    """
    labels = checks if checks is not None else ["<!-- -->"] * n_dim
    header_line = "| " + "".join(label + " | " for label in labels)
    separator_line = "|---" * n_dim + "|"
    return header_line + "\n" + separator_line + "\n"
def generate_table(dim1, dim2, checks=None):
    """Build a markdown table with one row per operand shape.

    Args:
        dim1: Shape of the first operand; its length sets the column count.
        dim2: Shape of the second operand.
        checks: Optional per-column markers from ``check_validity``. When
            given (and non-empty), they become the header labels and drive
            the cell colouring in ``create_row``.

    Returns:
        The header, separator and two data rows as one markdown string.
    """
    table = create_header(len(dim1), checks)
    if checks:
        return table + create_row(dim1, 1, checks) + create_row(dim2, 2, checks)
    return table + create_row(dim1) + create_row(dim2)
def alignment_and_fill_with_ones(dim1, dim2):
    """Left-pad the shorter shape with 1s so both shapes have the same rank.

    Args:
        dim1: Shape of the first operand.
        dim2: Shape of the second operand.

    Returns:
        ``(dim1, dim2)``. Only the shorter shape is replaced, by a new padded
        list; a shape that needed no padding is returned as the same object.
    """
    gap = len(dim1) - len(dim2)
    if gap > 0:
        dim2 = [1] * gap + list(dim2)
    elif gap < 0:
        dim1 = [1] * -gap + list(dim1)
    return dim1, dim2
def check_validity(dim1, dim2):
    """Mark each aligned dimension as compatible (``"V"``) or not (``"X"``).

    All positions except the last two must match exactly. The last two
    positions share one verdict: whether ``dim1[-1] == dim2[-2]``.

    Args:
        dim1: First operand shape, aligned to the rank of ``dim2``.
        dim2: Second operand shape.

    Returns:
        One marker per dimension of ``dim1``. If ``dim1`` has fewer than two
        dimensions, every marker is ``"WIP"``.
    """
    if len(dim1) < 2:
        return ["WIP"] * len(dim1)
    marks = ["V" if size == dim2[pos] else "X" for pos, size in enumerate(dim1[:-2])]
    inner_ok = dim1[-1] == dim2[-2]
    marks += ["V", "V"] if inner_ok else ["X", "X"]
    return marks
def substitute_ones_with_concat(dim1, dim2):
    """Broadcast size-1 batch dimensions between two equal-rank shapes.

    At every position except the last two (the matrix dimensions), a ``1``
    in one shape is replaced by the other shape's size at that position.

    Args:
        dim1: First operand shape; modified in place.
        dim2: Second operand shape; modified in place.

    Returns:
        ``(dim1, dim2)``, the same list objects that were passed in.
    """
    for pos, left in enumerate(dim1[:-2]):
        right = dim2[pos]
        dim1[pos] = right if left == 1 else left
        dim2[pos] = dim1[pos] if right == 1 else right
    return dim1, dim2
def predict(dim1, dim2):
    """Explain, step by step, which shape ``torch.matmul`` produces.

    Args:
        dim1: Text of the first operand's shape, e.g. ``"9,2,1,3,3"``.
        dim2: Text of the second operand's shape, e.g. ``"5,3,7"``.

    Returns:
        Markdown containing:
        * the runnable example from ``generate_example``;
        * three tables: alignment, broadcasting of size-1 batch dimensions,
          and a colour-coded validity check;
        * the resulting output shape, or a note that the shapes are
          incompatible.

    Raises:
        Whatever ``sanitize_dimention`` raises for input it cannot parse.
    """
    dim1 = sanitize_dimention(dim1)
    dim2 = sanitize_dimention(dim2)
    dim1, dim2, code = generate_example(dim1, dim2)

    # torch.matmul promotes 1-D operands. A 1-D first operand gets a 1
    # prepended and a 1-D second operand gets a 1 appended. Each added
    # dimension is removed from the result again, so remember which ones.
    notes = []
    drop_row = len(dim1) == 1
    drop_col = len(dim2) == 1
    if drop_row:
        dim1 = [1] + dim1
        notes.append("* dim1 is 1-D: a 1 is prepended to it and removed again from the output\n")
    if drop_col:
        dim2 = dim2 + [1]
        notes.append("* dim2 is 1-D: a 1 is appended to it and removed again from the output\n")

    # Table 1: left-pad the shorter shape with ones so both have equal rank.
    dim1, dim2 = alignment_and_fill_with_ones(dim1, dim2)
    table1 = generate_table(dim1, dim2)
    # Table 2: broadcast size-1 batch dimensions (mutates dim1/dim2 in place).
    dim1, dim2 = substitute_ones_with_concat(dim1, dim2)
    table2 = generate_table(dim1, dim2)
    # Table 3: colour-coded validity check.
    checks = check_validity(dim1, dim2)
    table3 = generate_table(dim1, dim2, checks)

    out = code
    out += "\n# Step1 (alignment and pre_append with ones)\n"
    if notes:
        # The blank line keeps the table from being parsed as part of the list.
        out += "".join(notes) + "\n"
    out += table1
    out += "\n# Step2 (substitute columns that have 1 with concat)\nexcept for last 2 dimensions\n" + table2
    out += "\n# Step3 (check if matrix multiplication is valid)\n"
    out += "* last dimension of dim1 should equal before last dimension of dim2 (blue or yellow colors)\n"
    out += "* all the other dimensions should be equal to one another (green or red colors)\n\n" + table3

    if "X" in checks:
        out += "\n# Final dimension\n The shapes are not compatible, so `t1 @ t2` raises an error"
    else:
        # Broadcast batch dims, then (rows of dim1, cols of dim2), minus any
        # dimension that was only added to promote a 1-D operand.
        out_shape = dim1[:-2]
        if not drop_row:
            out_shape.append(dim1[-2])
        if not drop_col:
            out_shape.append(dim2[-1])
        out += f"\n# Final dimension\n `output.shape = {out_shape}`"
    return out
# Web UI: two free-text inputs (the shapes of t1 and t2), rendered by
# `predict` into a single markdown output.
demo = gr.Interface(
    predict,
    inputs=["text", "text"],
    outputs=["markdown"],
    examples=[["9,2,1,3,3", "5,3,7"], ["1,2,3", "5,2,7"]],
)

# Start the (blocking) server only when run as a script, so that importing
# this module just defines `demo` without launching it.
if __name__ == "__main__":
    demo.launch(debug=True)
|