diff --git a/docs/conf.py b/docs/conf.py index 4546272..c926e13 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -83,7 +83,7 @@ # These names will also be recognized as section headers in Google-style docstrings, # in addition to the default ones like "Args", "Returns", etc. -napoleon_custom_sections = ["Overview", "Graph Attributes", "List of Available Datasets", "Splits", "Usage Notes"] +napoleon_custom_sections = ["Overview", "Helpers", "Graph Attributes", "List of Available Datasets", "Splits", "Usage Notes"] # -- MathJax configuration --------------------------------------------------- diff --git a/graphbench/__init__.py b/graphbench/__init__.py index d1990c6..79a94b1 100644 --- a/graphbench/__init__.py +++ b/graphbench/__init__.py @@ -1,6 +1,7 @@ from ._loader import Loader from ._evaluator import Evaluator #from ._optimize import Optimizer +from . import helpers -__all__ = ["datasets", "Loader", "Evaluator"] +__all__ = ["datasets", "helpers", "Loader", "Evaluator"] diff --git a/graphbench/datasets/_combinatorial_optimization.py b/graphbench/datasets/_combinatorial_optimization.py index 510d1ea..2ba5e14 100644 --- a/graphbench/datasets/_combinatorial_optimization.py +++ b/graphbench/datasets/_combinatorial_optimization.py @@ -94,6 +94,124 @@ class CODataset(GraphDataset): Please refer to the `GraphBench paper `__ for the exact parameters used for graph generation. + Helpers: + The following helper functions are available under ``graphbench.helpers``. + They are optional but reduce boilerplate for unsupervised CO training and + evaluation. Each loss refers to its matching decoder and metric. + + **Losses** + + - :func:`graphbench.helpers.mis_loss` - Unsupervised loss function to + train a model for MIS. + + At test time, use :func:`graphbench.helpers.mis_decoder` to convert the + model's soft output to a discrete solution, and + :func:`graphbench.helpers.mis_size` to evaluate the model's performance. + + Parameters: + - ``x`` (Tensor): Soft model output of shape ``[num_nodes]``. + - ``batch`` (Batch): PyG batch with the input graphs. + - ``beta`` (float, optional): Edge penalty weight, default to 1.0. + + - :func:`graphbench.helpers.max_cut_loss` - Unsupervised loss function to + train a model for max-cut. + + At test time, use :func:`graphbench.helpers.max_cut_size` to evaluate the + model's performance. + + Parameters: + - ``x`` (Tensor): Soft model output of shape ``[num_nodes]``. + - ``batch`` (Batch): PyG batch with the input graphs. + + - :func:`graphbench.helpers.graph_coloring_loss` - Unsupervised loss + function to train a model for graph coloring. + + At test time, use :func:`graphbench.helpers.graph_coloring_decoder` to + convert the model's soft output to a discrete solution, and + :func:`graphbench.helpers.num_colors_used` to evaluate the model's + performance. + + Parameters: + - ``x`` (Tensor): Soft model output of shape ``[num_nodes, num_colors]``. + - ``batch`` (Batch): PyG batch with the input graphs. + + **Decoders** + + - :func:`graphbench.helpers.mis_decoder` - Converts the model's soft + prediction to a discrete solution to the MIS problem. + + This can be used at test time for models trained with + :func:`graphbench.helpers.mis_loss`. + + Parameters: + - ``x`` (Tensor): Soft model output of shape ``[num_nodes]``. + - ``batch`` (Batch): PyG batch with the input graphs. + - ``dec_length`` (int, optional): Number of decoding steps, default to 300. + - ``num_seeds`` (int, optional): Number of decoding restarts, default to 1. + + - :func:`graphbench.helpers.graph_coloring_decoder` - Converts the model's + soft prediction to a discrete solution to the graph coloring problem. + + This can be used at test time for models trained with + :func:`graphbench.helpers.graph_coloring_loss`. + + Parameters: + - ``x`` (Tensor): Soft model output of shape ``[num_nodes, num_colors]``. + - ``batch`` (Batch): PyG batch with the input graphs. + - ``num_seeds`` (int, optional): Number of decoding restarts, default to 1. + + **Metrics** + + - :func:`graphbench.helpers.mis_size` - Computes MIS size from a decoded + solution. + + Parameters: + - ``x`` (Tensor): Soft model output of shape ``[num_nodes]``. + - ``batch`` (Batch): PyG batch with the input graphs. + - ``dec_length`` (int, optional): Number of decoding steps, default to 300. + - ``num_seeds`` (int, optional): Number of decoding restarts, default to 1. + + - :func:`graphbench.helpers.max_cut_size` - Computes max-cut size from a + thresholded cut assignment. + + Parameters: + - ``x`` (Tensor): Soft model output of shape ``[num_nodes]``. + - ``batch`` (Batch): PyG batch with the input graphs. + + - :func:`graphbench.helpers.num_colors_used` - Computes the number of + colors used by a decoded coloring. Uses + :func:`graphbench.helpers.graph_coloring_decoder` internally. + + Parameters: + - ``x`` (Tensor): Soft model output of shape ``[num_nodes, num_colors]``. + - ``batch`` (Batch): PyG batch with the input graphs. + - ``num_seeds`` (int, optional): Number of decoding restarts, default to 1. + + **Validators** + + - :func:`graphbench.helpers.validate_mis_solution` - Checks whether a + the given solution is a valid independent set for the provided graph. + + Parameters: + - ``graph`` (Data): The problem graph. + - ``solution`` (Tensor): The independent set of shape ``[num_nodes]``, as a binary vector where a 1 indicates that the node is in the set. + + - :func:`graphbench.helpers.validate_max_cut_solution` - Always returns + ``True`` because any partition defines a valid cut. + + Parameters: + - ``graph`` (Data): The problem graph. + - ``solution`` (Tensor): Binary node indicators. + + - :func:`graphbench.helpers.validate_chrom_solution` - Checks whether the + given solution is a valid graph coloring for the provided graph. + + Parameters: + - ``graph`` (Data): The problem graph. + - ``solution`` (Tensor): The graph coloring of shape ``[num_nodes]``, as a vector where each entry indicates the color assigned to the corresponding + node. + + Splits: All datasets use a 70% / 15% / 15% split for training, validation, and testing. diff --git a/graphbench/helpers/__init__.py b/graphbench/helpers/__init__.py new file mode 100644 index 0000000..35a4afb --- /dev/null +++ b/graphbench/helpers/__init__.py @@ -0,0 +1,50 @@ +from graphbench._helpers import ( + download_and_unpack, + get_logger, + split_dataset, + SourceSpec, + VectorizedCircuitSimulator, +) +from .decoders import ( + graph_coloring_decoder, + mis_decoder, + mis_size, + max_cut_size, + num_colors_used, + UNSUPERVISED_CO_METRICS, + UNSUPERVISED_CO_METRIC_NAMES, +) +from .unsupervised_losses import ( + graph_coloring_loss, + max_cut_loss, + mis_loss, + UNSUPERVISED_CO_LOSSES, +) +from .validate_solution import ( + validate_chrom_solution, + validate_max_cut_solution, + validate_mis_solution, +) + + +__all__ = [ + "download_and_unpack", + "get_logger", + "split_dataset", + "SourceSpec", + "VectorizedCircuitSimulator", + "graph_coloring_decoder", + "mis_decoder", + "mis_size", + "max_cut_size", + "num_colors_used", + "UNSUPERVISED_CO_METRICS", + "UNSUPERVISED_CO_METRIC_NAMES", + "graph_coloring_loss", + "max_cut_loss", + "mis_loss", + "UNSUPERVISED_CO_LOSSES", + "validate_chrom_solution", + "validate_max_cut_solution", + "validate_mis_solution", +] \ No newline at end of file diff --git a/graphbench/helpers/decoders.py b/graphbench/helpers/decoders.py new file mode 100644 index 0000000..29827f5 --- /dev/null +++ b/graphbench/helpers/decoders.py @@ -0,0 +1,131 @@ +# Source for mis and max cut: https://github.com/WenkelF/copt/blob/main/utils/metrics.py + +import torch +from torch import Tensor +from torch_geometric.data import Batch +from torch_geometric.utils import unbatch, unbatch_edge_index, remove_self_loops + + +def mis_size(x: Tensor, batch: Batch, dec_length: int = 300, num_seeds: int = 1) -> Tensor: + batch = mis_decoder(x, batch, dec_length, num_seeds) + + data_list = batch.to_data_list() + + size_list = [data.is_size for data in data_list] + + return Tensor(size_list).mean() + + +def mis_decoder(x: Tensor, batch: Batch, dec_length: int = 300, num_seeds: int = 1) -> Batch: + x = torch.sigmoid(x) + data_list = batch.to_data_list() + x_list = unbatch(x, batch.batch) + + for data, x_data in zip(data_list, x_list): + is_size_list = [] + + for seed in range(num_seeds): + + order = torch.argsort(x_data, dim=0, descending=True) + c = torch.zeros_like(x_data) + + edge_index = remove_self_loops(data.edge_index)[0] + src, dst = edge_index[0], edge_index[1] + + c[order[seed]] = 1 + for idx in range(seed, min(dec_length, data.num_nodes)): + c[order[idx]] = 1 + + cTWc = torch.sum(c[src] * c[dst]) + if cTWc != 0: + c[order[idx]] = 0 + + is_size_list.append(c.sum()) + + data.is_size = max(is_size_list) + + return Batch.from_data_list(data_list) + + +def max_cut_size(x: Tensor, data: Batch) -> Tensor: + x = (x > 0).float() + x = (x - 0.5) * 2 + + x_list = unbatch(x, data.batch) + edge_index_list = unbatch_edge_index(data.edge_index, data.batch) + + cut_list = [] + for x, edge_index in zip(x_list, edge_index_list): + cut_list.append(torch.sum(x[edge_index[0]] * x[edge_index[1]] == -1.0) / 2) + + return Tensor(cut_list).mean() + + +# TODO: double-check implementation +def num_colors_used(x: Tensor, batch: Batch, num_seeds: int = 1) -> Tensor: + batch = graph_coloring_decoder(x, batch, num_seeds) + + data_list = batch.to_data_list() + + num_colors_used_list = [] + for data in data_list: + num_colors_used = data.colors.unique().size(0) + num_colors_used_list.append(num_colors_used) + + return torch.tensor(num_colors_used_list).mean(dtype=torch.float) + + +# TODO: double-check implementation +def graph_coloring_decoder(x: Tensor, batch: Batch, num_seeds: int = 1) -> Batch: + max_num_colors = x.size(1) + x = torch.sigmoid(x) + data_list = batch.to_data_list() + x_list = unbatch(x, batch.batch) + + for data, x_data in zip(data_list, x_list): + edge_index = remove_self_loops(data.edge_index)[0] + src, dst = edge_index[0], edge_index[1] + + best_colors = None + min_colors_used = max_num_colors + 1 # upper bound + + for seed in range(num_seeds): + order = torch.argsort(x_data.max(dim=1).values, descending=True) + colors = torch.full((data.num_nodes,), -1, dtype=torch.long, device=x_data.device) + + for idx in range(data.num_nodes): + node = order[(seed + idx) % data.num_nodes] + # Find available colors for this node + used = torch.zeros(max_num_colors, dtype=torch.bool, device=x_data.device) + neighbors = dst[src == node] + for neighbor in neighbors: + c = colors[neighbor] + if c >= 0: + used[c] = True + + # Assign the available color with the highest score in x_data for this node + available_color_indices = (~used).nonzero(as_tuple=True)[0] + max_idx = torch.argmax(x_data[node, available_color_indices]) + colors[node] = available_color_indices[max_idx] + + num_colors_used = colors.unique().size(0) + if num_colors_used < min_colors_used: + min_colors_used = num_colors_used + best_colors = colors + + data.colors = best_colors + + return Batch.from_data_list(data_list) + + +UNSUPERVISED_CO_METRICS = { + "mis_unsupervised": mis_size, + "cut_unsupervised": max_cut_size, + "chrom_unsupervised": num_colors_used, +} + +UNSUPERVISED_CO_METRIC_NAMES = { + "mis_unsupervised": "mis_size", + "cut_unsupervised": "max_cut_size", + "chrom_unsupervised": "chromatic_number", +} diff --git a/graphbench/helpers/unsupervised_losses.py b/graphbench/helpers/unsupervised_losses.py new file mode 100644 index 0000000..ea0bf71 --- /dev/null +++ b/graphbench/helpers/unsupervised_losses.py @@ -0,0 +1,49 @@ +# Source: https://github.com/WenkelF/copt/blob/main/graphgym/loss/copt_loss.py + +import torch +from torch import Tensor +from torch_geometric.data import Batch +from torch_geometric.utils import unbatch + + +def mis_loss(x: Tensor, batch: Batch, beta: float = 1.0) -> Tensor: + x = torch.sigmoid(x) + data_list = batch.to_data_list() + x_list = unbatch(x, batch.batch) + + loss = 0.0 + for data, x_data in zip(data_list, x_list): + src, dst = data.edge_index[0], data.edge_index[1] + + loss1 = torch.sum(x_data[src] * x_data[dst]) + loss2 = x_data.sum() ** 2 - loss1 - torch.sum(x_data ** 2) + loss += (- loss2 + beta * loss1) * data.num_nodes + + return loss / batch.size(0) + + +def max_cut_loss(x: Tensor, batch: Batch) -> Tensor: + x = torch.sigmoid(x) + x = (x - 0.5) * 2 + src, dst = batch.edge_index[0], batch.edge_index[1] + return torch.sum(x[src] * x[dst]) / len(batch.batch.unique()) + + +# Adapted from GCON. GCON implements this for a node feature matrix X of size [num_nodes, num_colors] and an adjacency +# matrix A of size [num_nodes, num_nodes]. It basically calculates sum(diag(X^T A X)) - 4 * sum(abs(X)). +# The implementation here replaces the adjacency matrix with a pytorch geometric graph. +# TODO double check implementation +def graph_coloring_loss(x: Tensor, batch: Batch) -> Tensor: + x = torch.sigmoid(x) + x = (x - 0.5) * 2 + src, dst = batch.edge_index + edge_loss = torch.sum(x[src] * x[dst]) + node_loss = 4 * torch.abs(x).sum() + return edge_loss - node_loss + + +UNSUPERVISED_CO_LOSSES = { + "mis_unsupervised": mis_loss, + "cut_unsupervised": max_cut_loss, + "chrom_unsupervised": graph_coloring_loss, +} diff --git a/graphbench/helpers/validate_solution.py b/graphbench/helpers/validate_solution.py new file mode 100644 index 0000000..ad94478 --- /dev/null +++ b/graphbench/helpers/validate_solution.py @@ -0,0 +1,59 @@ +from torch import Tensor +from torch_geometric.data import Data + + +# TODO We provide this functionality for users of this benchmark, but it's currently very difficult to discover that +# this exists. Needs to be documented somewhere! + + +def validate_mis_solution(graph: Data, solution: Tensor) -> bool: + """ + Checks whether the given solution is a valid independent set for the provided graph. + That is, no two nodes in the set are adjacent to each other. + + Parameters: + - `graph`: The problem graph + - `solution`: The independent set, as a binary vector where a 1 indicates that the node is in the set. + Size `[graph.num_nodes]` + """ + # check if solution is a binary vector of correct length + if solution.size() != (graph.num_nodes,) or not ((solution == 0) | (solution == 1)).all(): + return False + + # for each edge, see if both the source node and the destination node are in the set + src = graph.edge_index[0] + dst = graph.edge_index[1] + # this is non-zero if both nodes are in the set + both_in_set = solution[src] * solution[dst] + + # if any edge connects two nodes in the set, it is not a valid independent set + return not both_in_set.any() + + +def validate_max_cut_solution(graph: Data, solution: Tensor) -> bool: + """ + Always returns `True`, since any node subset leads to a valid cut. + """ + return True + + +def validate_chrom_solution(graph: Data, solution: Tensor) -> bool: + """ + Checks whether the given solution is a valid graph coloring for the provided graph. + That is, no two adjacent nodes are assigned the same color. + + Parameters: + - `graph`: The problem graph + - `solution`: The graph coloring, as a vector where each entry indicates the color assigned to the corresponding + node. Size `[graph.num_nodes]` + """ + if solution.size() != (graph.num_nodes,): + return False + + # for each edge, see if both nodes were assigned the same color + src = graph.edge_index[0] + dst = graph.edge_index[1] + same_color = solution[src] == solution[dst] + + # if any edge connects two nodes with the same color, then this is not a valid graph coloring + return not same_color.any()