lm-similarity / src /test.py
Joschka Strueber
[Fix] plotly heatmap
3b16cfa
raw
history blame
436 Bytes
import plotly.graph_objects as go
import numpy as np
# Smoke test for the model-similarity heatmap: build a random symmetric
# similarity matrix over a few dummy model names, render it with plotly,
# and save the figure as a standalone HTML file.

models = ["model1", "model2", "model3"]
size = len(models)

# Fixed seed so repeated runs draw the same matrix; otherwise visual
# comparisons between plotting changes are meaningless.
SEED = 0
rng = np.random.default_rng(SEED)

sim = rng.random((size, size))
# Averaging with the transpose makes the matrix symmetric, like a real
# pairwise similarity matrix.
sim = (sim + sim.T) / 2
# Every model is perfectly similar to itself.
np.fill_diagonal(sim, 1.0)
sim = np.round(sim, 2)

fig = go.Figure(
    data=go.Heatmap(
        z=sim,
        x=models,
        y=models,
        colorscale="Viridis",
        # Pin the color range to [0, 1] so a given color always means the
        # same similarity, independent of which random values were drawn.
        zmin=0,
        zmax=1,
        # Print each cell's value; without this the rounding above is only
        # visible on hover.
        texttemplate="%{z:.2f}",
    )
)
fig.update_layout(
    title="Test Heatmap",
    xaxis_title="Models",
    yaxis_title="Models",
    width=800,
    height=800,
)
# go.Heatmap draws the first y category at the bottom; reverse the axis so
# row 0 of the matrix is at the top and the plot reads like the matrix.
fig.update_yaxes(autorange="reversed")

# Save first so the HTML artifact does not depend on fig.show() succeeding
# (e.g. when no browser/renderer is available).
fig.write_html("heatmap.html")
fig.show()