"""Graph analytics operations.""" | |
import enum | |
import os | |
import fsspec | |
from lynxkite.core import ops | |
from collections import deque | |
from . import core | |
import grandcypher | |
import joblib | |
import matplotlib | |
import networkx as nx | |
import pandas as pd | |
import polars as pl | |
import json | |
mem = joblib.Memory(".joblib-cache") | |
op = ops.op_registration(core.ENV) | |


class FileFormat(enum.StrEnum):
    csv = "csv"
    parquet = "parquet"
    json = "json"
    excel = "excel"


def import_file(
    *, file_path: str, table_name: str, file_format: FileFormat, **kwargs
) -> core.Bundle:
    """Reads the contents of a file into a `Bundle`.

    Args:
        file_path: Path to the file to import.
        table_name: Name to use for identifying the table in the bundle.
        file_format: Format of the file. Has to be one of the values in the `FileFormat` enum.

    Returns:
        Bundle: Bundle with a single table with the contents of the file.
    """
    if file_format == "csv":
        names = kwargs.pop("columns", "<from file>")
        names = (
            pd.api.extensions.no_default if names == "<from file>" else names.split(",")
        )
        sep = kwargs.pop("separator", "<auto>")
        sep = pd.api.extensions.no_default if sep == "<auto>" else sep
        df = pd.read_csv(file_path, names=names, sep=sep, **kwargs)
    elif file_format == "json":
        df = pd.read_json(file_path, **kwargs)
    elif file_format == "parquet":
        df = pd.read_parquet(file_path, **kwargs)
    elif file_format == "excel":
        df = pd.read_excel(file_path, **kwargs)
    else:
        raise ValueError(f"Unsupported file format: {file_format}")
    return core.Bundle(dfs={table_name: df})
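
# A minimal usage sketch ("people.csv" is a hypothetical local file, for
# illustration only):
#
#     bundle = import_file(
#         file_path="people.csv", table_name="people", file_format=FileFormat.csv
#     )
#     print(bundle.dfs["people"].head())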


def import_parquet(*, filename: str):
    """Imports a Parquet file."""
    return pd.read_parquet(filename)


def import_csv(
    *, filename: str, columns: str = "<from file>", separator: str = "<auto>"
):
    """Imports a CSV file."""
    return pd.read_csv(
        filename,
        names=pd.api.extensions.no_default
        if columns == "<from file>"
        else columns.split(","),
        sep=pd.api.extensions.no_default if separator == "<auto>" else separator,
    )


def import_graphml(*, filename: str):
    """Imports a GraphML file."""
    files = fsspec.open_files(filename, compression="infer")
    for f in files:
        if ".graphml" in f.path:
            with f as fh:
                return nx.read_graphml(fh)
    raise ValueError(f"No .graphml file found at {filename}")


def import_osm(*, location: str):
    """Imports the drivable street network of a place from OpenStreetMap."""
    import osmnx as ox

    return ox.graph.graph_from_place(location, network_type="drive")
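
# Example (requires the optional `osmnx` dependency and network access; the
# location string is arbitrary):
#
#     g = import_osm(location="Budapest, Hungary")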


def discard_loop_edges(graph: nx.Graph):
    """Returns a copy of the graph with self-loop edges removed."""
    graph = graph.copy()
    graph.remove_edges_from(nx.selfloop_edges(graph))
    return graph


def discard_parallel_edges(graph: nx.Graph):
    """Collapses parallel edges by converting to a simple directed graph."""
    return nx.DiGraph(graph)
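
# A quick sketch of the intended effect (illustrative only): converting a
# MultiDiGraph to a DiGraph keeps at most one edge per (source, target) pair.
#
#     g = nx.MultiDiGraph()
#     g.add_edges_from([(1, 2), (1, 2), (2, 2)])
#     discard_parallel_edges(g).number_of_edges()  # 2: one 1->2 edge, one loop
#     discard_loop_edges(g).number_of_edges()      # 2: the loop at 2 is gone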


def sql(bundle: core.Bundle, *, query: ops.LongStr, save_as: str = "result"):
    """Run a SQL query on the DataFrames in the bundle. Save the results as a new DataFrame."""
    bundle = bundle.copy()
    if os.environ.get("NX_CUGRAPH_AUTOCONFIG", "").strip().lower() == "true":
        with pl.Config() as cfg:
            cfg.set_verbose(True)
            res = (
                pl.SQLContext(bundle.dfs)
                .execute(query)
                .collect(engine="gpu")
                .to_pandas()
            )
            # TODO: Currently `collect()` moves the data from cuDF to Polars. Then we convert
            # it to Pandas, which (hopefully) puts it back into cuDF. Hopefully we will be
            # able to keep it in cuDF.
    else:
        res = pl.SQLContext(bundle.dfs).execute(query).collect().to_pandas()
    bundle.dfs[save_as] = res
    return bundle
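
# A minimal usage sketch (the table and column names are made up for
# illustration; any table in `bundle.dfs` can be referenced by name):
#
#     bundle = sql(bundle, query="SELECT name, age FROM people WHERE age > 30")
#     print(bundle.dfs["result"])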


def cypher(bundle: core.Bundle, *, query: ops.LongStr, save_as: str = "result"):
    """Run a Cypher query on the graph in the bundle. Save the results as a new DataFrame."""
    bundle = bundle.copy()
    graph = bundle.to_nx()
    res = grandcypher.GrandCypher(graph).run(query)
    bundle.dfs[save_as] = pd.DataFrame(res)
    return bundle
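
# A minimal usage sketch (the node property names are hypothetical):
#
#     bundle = cypher(
#         bundle,
#         query="MATCH (a)-[]->(b) WHERE a.age > 30 RETURN a.name, b.name LIMIT 10",
#     )
#     print(bundle.dfs["result"])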


def organize(bundle: list[core.Bundle], *, code: ops.LongStr) -> core.Bundle:
    """Lets you rename/copy/delete DataFrames, and modify relations.

    TODO: Merge this with "Create graph".
    """
    bundle = bundle.copy()
    exec(code, globals(), {"bundle": bundle})
    return bundle
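
# A usage sketch of the `code` parameter (the table names are made up): the
# code runs with the input bound to the name `bundle`, so renaming a table
# could look like
#
#     code = "bundle.dfs['people'] = bundle.dfs.pop('nodes')"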


def sample_graph(graph: nx.Graph, *, nodes: int = 100):
    """Takes a (preferably connected) subgraph."""
    start = next(iter(graph.nodes))
    sample = {start}  # Include the start node, in case it is nobody's neighbor.
    to_expand = deque([start])
    while to_expand and len(sample) < nodes:
        node = to_expand.pop()
        for n in graph.neighbors(node):
            if n not in sample:
                sample.add(n)
                to_expand.append(n)
            if len(sample) == nodes:
                break
    return nx.Graph(graph.subgraph(sample))
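
# A quick sketch: sampling 50 nodes from a larger random graph.
#
#     g = nx.erdos_renyi_graph(1000, 0.01)
#     sub = sample_graph(g, nodes=50)
#     assert sub.number_of_nodes() <= 50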


def _map_color(value):
    """Maps a pandas Series to hex color strings, numerically or categorically."""
    if pd.api.types.is_numeric_dtype(value):
        # `matplotlib.cm.get_cmap()` was removed in Matplotlib 3.9; the colormap
        # registry is the supported replacement.
        cmap = matplotlib.colormaps["viridis"]
        value = (value - value.min()) / (value.max() - value.min())
        rgba = cmap(value.values)
        return [
            "#{:02x}{:02x}{:02x}".format(int(r * 255), int(g * 255), int(b * 255))
            for r, g, b in rgba[:, :3]
        ]
    else:
        cmap = matplotlib.colormaps["Paired"]
        categories = pd.Index(value.unique())
        colors = cmap.colors[: len(categories)]
        return [
            "#{:02x}{:02x}{:02x}".format(int(r * 255), int(g * 255), int(b * 255))
            for r, g, b in [
                colors[min(len(colors) - 1, categories.get_loc(v))] for v in value
            ]
        ]
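
# A quick sketch of both branches (the values are arbitrary):
#
#     _map_color(pd.Series([1.0, 2.0, 3.0]))   # viridis gradient hex codes
#     _map_color(pd.Series(["a", "b", "a"]))   # "Paired" categorical hex codes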


def visualize_graph(
    graph: core.Bundle,
    *,
    color_nodes_by: ops.NodeAttribute = None,
    label_by: ops.NodeAttribute = None,
    color_edges_by: ops.EdgeAttribute = None,
):
    """Builds an ECharts graph visualization from the bundle's node and edge tables."""
    nodes = core.df_for_frontend(graph.dfs["nodes"], 10_000)
    if color_nodes_by:
        nodes["color"] = _map_color(nodes[color_nodes_by])
    for cols in ["x y", "long lat"]:
        x, y = cols.split()
        if (
            x in nodes.columns
            and nodes[x].dtype == "float64"
            and y in nodes.columns
            and nodes[y].dtype == "float64"
        ):
            cx, cy = nodes[x].mean(), nodes[y].mean()
            dx, dy = nodes[x].std(), nodes[y].std()
            # Scale up to avoid float precision issues and because eCharts omits short edges.
            scale_x = 100 / max(dx, dy)
            scale_y = scale_x
            if y == "lat":
                scale_y *= -1
            pos = {
                node_id: ((row[x] - cx) * scale_x, (row[y] - cy) * scale_y)
                for node_id, row in nodes.iterrows()
            }
            curveness = 0  # Street maps are better with straight streets.
            break
    else:
        pos = nx.spring_layout(
            graph.to_nx(), iterations=max(1, int(10000 / len(nodes)))
        )
        curveness = 0.3
    nodes = nodes.to_records()
    edges = core.df_for_frontend(
        graph.dfs["edges"].drop_duplicates(["source", "target"]), 10_000
    )
    if color_edges_by:
        edges["color"] = _map_color(edges[color_edges_by])
    edges = edges.to_records()
    v = {
        "animationDuration": 500,
        "animationEasingUpdate": "quinticInOut",
        "tooltip": {"show": True},
        "series": [
            {
                "type": "graph",
                # Mouse zoom/panning is disabled for now. It interacts badly with ReactFlow.
                # "roam": True,
                "lineStyle": {
                    "color": "gray",
                    "curveness": curveness,
                },
                "emphasis": {
                    "focus": "adjacency",
                    "lineStyle": {
                        "width": 10,
                    },
                },
                "label": {"position": "top", "formatter": "{b}"},
                "data": [
                    {
                        "id": str(n.id),
                        "x": float(pos[n.id][0]),
                        "y": float(pos[n.id][1]),
                        # Adjust node size to cover the same area no matter how many nodes there are.
                        "symbolSize": 50 / len(nodes) ** 0.5,
                        "itemStyle": {"color": n.color} if color_nodes_by else {},
                        "label": {"show": label_by is not None},
                        "name": str(getattr(n, label_by, "")) if label_by else None,
                        "value": str(getattr(n, color_nodes_by, ""))
                        if color_nodes_by
                        else None,
                    }
                    for n in nodes
                ],
                "links": [
                    {
                        "source": str(r.source),
                        "target": str(r.target),
                        "lineStyle": {"color": r.color} if color_edges_by else {},
                        "value": str(getattr(r, color_edges_by, ""))
                        if color_edges_by
                        else None,
                    }
                    for r in edges
                ],
            },
        ],
    }
    return v
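
# The returned dictionary is an ECharts "option" object. A usage sketch
# (assuming a bundle with "nodes" and "edges" tables and a "degree" column):
#
#     option = visualize_graph(bundle, color_nodes_by="degree")
#     # option["series"][0]["data"] holds the positioned, colored nodes.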


def view_tables(bundle: core.Bundle, *, limit: int = 100):
    """Returns a dict representation of the bundle's tables, up to `limit` rows each."""
    return bundle.to_dict(limit=limit)


def create_graph(bundle: core.Bundle, *, relations: str = None) -> core.Bundle:
    """Replaces the relations of the given bundle.

    `relations` is a stringified JSON, instead of a dict, because complex Yjs types
    (arrays, maps) are not currently supported in the UI.

    Args:
        bundle: Bundle to modify.
        relations (str, optional): Set of relations to set for the bundle. The parameter
            should be a JSON object where the keys are relation names and the values are
            a dictionary representation of a `RelationDefinition`.
            Defaults to None.

    Returns:
        Bundle: The input bundle with the new relations set.
    """
    bundle = bundle.copy()
    if not (relations is None or relations.strip() == ""):
        bundle.relations = [
            core.RelationDefinition(**r) for r in json.loads(relations).values()
        ]
    return ops.Result(output=bundle, display=bundle.to_dict(limit=100))
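
# A usage sketch. The relation name is arbitrary, and `{ ... }` stands for the
# keyword arguments of a `RelationDefinition` (defined in `core`):
#
#     result = create_graph(bundle, relations='{"my_relation": { ... }}')
#     bundle = result.output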


# The operations below are stubs: they define the operation's interface, but
# their bodies currently just return None.


def define_model(*, model_workspace: str, save_as: str = "model"):
    """Reads a PyTorch model workspace and returns it as a model in a bundle."""
    return None


def biomedical_foundation_graph(*, filter_nodes: str):
    """Loads the gigantic Lynx-maintained knowledge graph. Includes drugs, diseases, genes, proteins, etc."""
    return None


def train_model(
    bundle: core.Bundle, *, model_name: str, model_mapping: str, epochs: int = 1
):
    """Trains the selected model on the selected dataset. Most training parameters are set in the model definition."""
    return None


def model_inference(
    bundle: core.Bundle,
    *,
    model_name: str,
    model_mapping: str,
    save_output_as: str = "prediction",
):
    """Executes a trained model."""
    return None


def train_test_split(bundle: core.Bundle, *, table_name: str, test_ratio: float = 0.1):
    """Splits a dataframe in the bundle into separate "_train" and "_test" dataframes."""
    return None