# NOTE: three header lines ("Spaces:" / "Runtime error" x2) were paste/extraction
# artifacts from a table-formatted rendering of this file, not part of the code.
| from typing import Callable, Dict, List, Optional, Tuple | |
| import networkx as nx | |
| import numpy as np | |
| import torch | |
def generate_rand_int_excluding(rng: np.random.RandomState, max: int, exclude: int) -> int:
    """Draw a random integer in [0, max), rejecting one forbidden value.

    Uses rejection sampling so the accepted values stay uniformly
    distributed over the remaining candidates. The parameter is named
    ``max`` (shadowing the builtin) to keep the public signature unchanged.

    Args:
        rng: Numpy random state used for sampling
        max: Exclusive upper bound of the draw
        exclude: The single value that must never be returned

    Returns:
        Random integer in [0, max) that is not equal to `exclude`.
    """
    # Draw, then redraw for as long as we hit the forbidden value. This
    # consumes exactly the same RNG stream as a draw-and-check loop, so
    # seeded runs remain reproducible.
    candidate = rng.randint(max)
    while candidate == exclude:
        candidate = rng.randint(max)
    return candidate
def generate_random_walks(  # noqa: max-complexity
    n_nodes: int = 21,
    max_length: int = 10,
    n_walks: int = 1000,
    p_edge: float = 0.1,
    seed: int = 1002,
    gpt2_tokenizer: bool = False,
) -> Tuple[Callable[[List[str]], Dict[str, List[float]]], List[str], List[str], torch.Tensor,]:
    """Generate random walks over a random directed graph.

    Builds a random directed graph whose node 0 is an absorbing "goal"
    state, samples ``n_walks`` random walks over it, and returns everything
    needed to score a model on the shortest-path-to-goal task.

    Args:
        n_nodes: Number of nodes. This should not be more than 26, as we use
            single letters to represent each node.
        max_length: Maximum number of steps in each random walk
        n_walks: Number of random walks (samples) to create
        p_edge: Probability that any source node connects to any other
            destination node
        seed: Random seed
        gpt2_tokenizer: True if GPT2's tokenizer is being used

    Returns:
        Tuple of (metric function, evaluation prompts, sample walk strings,
        logit mask derived from the adjacency matrix).
    """
    # Initialise a random state with the seed
    rng = np.random.RandomState(seed)
    # Create the adjacency matrix
    # https://en.wikipedia.org/wiki/Adjacency_matrix
    # This is a 2d matrix, where the rows represent the source nodes and the
    # columns represent the destination nodes. If a cell (i,j) is True, then
    # there is a directional edge from the source node (i) to the destination
    # node (j). If it is false there is no connection.
    while True:
        # Create the adjacency matrix, where each node is connected to each
        # other node, with probability p_edge
        adjacency_matrix: np.ndarray = rng.rand(n_nodes, n_nodes) > (1 - p_edge)
        # Nodes can't be connected to themselves, so the diagonal values must
        # all be False
        np.fill_diagonal(adjacency_matrix, 0)
        # Every source node (row) must have at least one outgoing edge,
        # otherwise the `rng.choice` over its successors during the walk
        # below would fail on an empty candidate set. `.sum(1)` computes the
        # out-degree of each row; if any row is all-False we regenerate the
        # whole matrix from scratch (the enclosing while loop).
        if np.all(adjacency_matrix.sum(1)):
            break
    # Set the goal node as 0
    goal: int = 0
    # The goal node is the terminal state, so we make sure that it doesn't
    # have a directional edge going to any other nodes (i.e. it can only be
    # connected to from previous nodes). We also set the connection to itself as
    # True, making it absorbing.
    adjacency_matrix[goal, :] = 0
    adjacency_matrix[goal, goal] = 1
    # Create dicts for converting nodes into characters and vice versa
    # Nodes are converted into characters as these (when split by the delimiter) are
    # guaranteed to be tokenized as individual tokens.
    char_to_node: Dict[str, int] = {chr(ix + ord("a")): ix for ix in range(n_nodes)}
    node_to_char: Dict[int, str] = {ix: chr(ix + ord("a")) for ix in range(n_nodes)}
    # Initialise a list of sample walks
    sample_walks: List[str] = []
    # String delimiter (to force the tokenizer to keep all nodes as separate
    # tokens)
    delimiter: str = "|" if gpt2_tokenizer else ""
    # Create n_walks samples
    for _ in range(n_walks):
        # Create a random starting node (that isn't already at the goal state)
        node: int = generate_rand_int_excluding(rng, n_nodes, goal)
        # Initialise the list of nodes that we visit
        walk_nodes: List[int] = [node]
        # Do a series of steps, until we hit the maximum number of steps or the
        # goal state (whichever comes first)
        for _step in range(max_length - 1):
            # From the current node, get all the nodes we can move to. Pick one
            # of these at random, and add it to the list of visited nodes
            node = rng.choice(np.nonzero(adjacency_matrix[node])[0])
            walk_nodes.append(node)
            # If we're at the goal state, stop
            if node == goal:
                break
        # Convert the nodes visited to letters (not integers)
        walk: List[str] = [node_to_char[ix] for ix in walk_nodes]
        # Concatenate into a journey, with each node letter separated by the
        # delimiter.
        sample_walks.append(delimiter.join(walk))
    # Initialise list of shortest lengths for each node (to the goal node).
    # Each entry is a node count (len of the node path, not the edge count).
    shortest_lengths: List[int] = []
    # Create a directional graph from the adjacency matrix
    directional_graph = nx.from_numpy_array(adjacency_matrix, create_using=nx.DiGraph)
    # For each node (except for the goal node), find the shortest path.
    # NOTE(review): this relies on `set(range(n_nodes)) - {goal}` iterating in
    # ascending order so that index i of shortest_lengths corresponds to node
    # i + 1 (as assumed by the `sample[0] - 1` lookup in metric_fn below).
    # That holds in CPython for small ints but is an implementation detail.
    for start in set(range(n_nodes)) - {goal}:
        try:
            # Find the shortest path (truncated to at most max_length nodes)
            shortest_path = nx.shortest_path(directional_graph, start, goal)[:max_length]
            shortest_lengths.append(len(shortest_path))
        except Exception:
            # If there is no path, use the maximum length instead
            shortest_lengths.append(max_length)

    def metric_fn(
        samples: List[str],
    ) -> Dict[str, List[float]]:
        """Metric Function

        Scores each sampled walk string for validity and length, and
        compares it to the precomputed shortest path from its start node.

        Args:
            samples: Batch of samples

        Returns:
            Dict of metrics, each with a key of the metric name and value as a
            list of metric values for each batch item.
        """
        # Length to set if the path is invalid
        invalid_path_length: int = 100
        # Initialise batch lengths & reference lengths (the optimal length
        # starting from each batch item's specific start node)
        lengths: List[float] = []
        sample_optimal_lengths: List[int] = []
        for sample_str in samples:
            # Remove GPT2 specific tokenizer delimiter
            if gpt2_tokenizer:
                sample_str = sample_str.replace("|", "")
            # Convert the sample into a list of nodes (default to an unused
            # integer if the character is not a known node)
            sample: List[int] = [char_to_node.get(c, 1000) for c in sample_str]
            # Initialise the specific sample length
            length: Optional[float] = None
            for node in range(len(sample)):
                # If an invalid path is taken (unknown node, or a move along
                # a non-existent edge), set the length to the invalid score.
                # Note `or` binds looser than `and`, so the edge check only
                # applies from the second position onwards.
                if sample[node] >= n_nodes or node > 0 and not adjacency_matrix[sample[node - 1], sample[node]]:
                    length = invalid_path_length
                    break
                # Otherwise, on reaching the goal node record the number of
                # nodes visited so far (including the goal)
                elif sample[node] == 0:
                    length = node + 1
                    break
            # Catch the case where the goal was never reached and no invalid
            # move was made (e.g. walk ran out of steps)
            if length is None:
                length = invalid_path_length
            # Store the batch item length & optimal length starting from the
            # start node (see the NOTE above about this index mapping)
            lengths.append(float(length))
            sample_optimal_lengths.append(shortest_lengths[sample[0] - 1])
        # Clamp invalid paths to max_length before scoring
        lengths_tensor = torch.tensor(lengths, dtype=torch.float)
        bound_lengths: torch.Tensor = torch.where(
            lengths_tensor.eq(invalid_path_length), max_length, lengths_tensor
        ).abs()
        optimal_lengths = torch.as_tensor(sample_optimal_lengths)
        # Optimality scores, in [0, 1], as compared to the shortest path.
        # NOTE(review): if an optimal length equals max_length (unreachable
        # start node), the denominator is zero and this divides by zero —
        # confirm start nodes are always connected to the goal.
        optimality = (max_length - bound_lengths) / (max_length - optimal_lengths)
        return {
            "lengths": lengths,
            "optimality": optimality.tolist(),
        }

    # Mask of allowed next-token transitions, straight from the adjacency matrix
    logit_mask = torch.tensor(adjacency_matrix)
    # Set the evaluation prompts as a list of unique random walk samples, using
    # just the start point (first character) from each sample.
    eval_prompts = list(sorted(set(w[0] for w in sample_walks)))
    eval_prompts = [prompt + delimiter for prompt in eval_prompts]
    return (metric_fn, eval_prompts, sample_walks, logit_mask)