Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@

# These names will also be recognized as section headers in Google-style docstrings,
# in addition to the default ones like "Args", "Returns", etc.
napoleon_custom_sections = ["Overview", "Graph Attributes", "List of Available Datasets", "Splits", "Usage Notes"]
napoleon_custom_sections = ["Overview", "Helpers", "Graph Attributes", "List of Available Datasets", "Splits", "Usage Notes"]


# -- MathJax configuration ---------------------------------------------------
Expand Down
3 changes: 2 additions & 1 deletion graphbench/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ._loader import Loader
from ._evaluator import Evaluator
#from ._optimize import Optimizer
from . import helpers


__all__ = ["datasets", "Loader", "Evaluator"]
__all__ = ["datasets", "helpers", "Loader", "Evaluator"]
118 changes: 118 additions & 0 deletions graphbench/datasets/_combinatorial_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,124 @@ class CODataset(GraphDataset):

Please refer to the `GraphBench paper <https://arxiv.org/abs/2512.04475>`__ for the exact parameters used for graph generation.

Helpers:
The following helper functions are available under ``graphbench.helpers``.
They are optional but reduce boilerplate for unsupervised CO training and
evaluation. Each loss refers to its matching decoder and metric.

**Losses**

- :func:`graphbench.helpers.mis_loss` - Unsupervised loss function to
train a model for MIS.

At test time, use :func:`graphbench.helpers.mis_decoder` to convert the
model's soft output to a discrete solution, and
:func:`graphbench.helpers.mis_size` to evaluate the model's performance.

Parameters:
- ``x`` (Tensor): Soft model output of shape ``[num_nodes]``.
- ``batch`` (Batch): PyG batch with the input graphs.
- ``beta`` (float, optional): Edge penalty weight, default to 1.0.

- :func:`graphbench.helpers.max_cut_loss` - Unsupervised loss function to
train a model for max-cut.

At test time, use :func:`graphbench.helpers.max_cut_size` to evaluate the
model's performance.

Parameters:
- ``x`` (Tensor): Soft model output of shape ``[num_nodes]``.
- ``batch`` (Batch): PyG batch with the input graphs.

- :func:`graphbench.helpers.graph_coloring_loss` - Unsupervised loss
function to train a model for graph coloring.

At test time, use :func:`graphbench.helpers.graph_coloring_decoder` to
convert the model's soft output to a discrete solution, and
:func:`graphbench.helpers.num_colors_used` to evaluate the model's
performance.

Parameters:
- ``x`` (Tensor): Soft model output of shape ``[num_nodes, num_colors]``.
- ``batch`` (Batch): PyG batch with the input graphs.

**Decoders**

- :func:`graphbench.helpers.mis_decoder` - Converts the model's soft
prediction to a discrete solution to the MIS problem.

This can be used at test time for models trained with
:func:`graphbench.helpers.mis_loss`.

Parameters:
- ``x`` (Tensor): Soft model output of shape ``[num_nodes]``.
- ``batch`` (Batch): PyG batch with the input graphs.
- ``dec_length`` (int, optional): Number of decoding steps, default to 300.
- ``num_seeds`` (int, optional): Number of decoding restarts, default to 1.

- :func:`graphbench.helpers.graph_coloring_decoder` - Converts the model's
soft prediction to a discrete solution to the graph coloring problem.

This can be used at test time for models trained with
:func:`graphbench.helpers.graph_coloring_loss`.

Parameters:
- ``x`` (Tensor): Soft model output of shape ``[num_nodes, num_colors]``.
- ``batch`` (Batch): PyG batch with the input graphs.
- ``num_seeds`` (int, optional): Number of decoding restarts, default to 1.

**Metrics**

- :func:`graphbench.helpers.mis_size` - Computes MIS size from a decoded
solution.

Parameters:
- ``x`` (Tensor): Soft model output of shape ``[num_nodes]``.
- ``batch`` (Batch): PyG batch with the input graphs.
- ``dec_length`` (int, optional): Number of decoding steps, default to 300.
- ``num_seeds`` (int, optional): Number of decoding restarts, default to 1.

- :func:`graphbench.helpers.max_cut_size` - Computes max-cut size from a
thresholded cut assignment.

Parameters:
- ``x`` (Tensor): Soft model output of shape ``[num_nodes]``.
- ``batch`` (Batch): PyG batch with the input graphs.

- :func:`graphbench.helpers.num_colors_used` - Computes the number of
colors used by a decoded coloring. Uses
:func:`graphbench.helpers.graph_coloring_decoder` internally.

Parameters:
- ``x`` (Tensor): Soft model output of shape ``[num_nodes, num_colors]``.
- ``batch`` (Batch): PyG batch with the input graphs.
- ``num_seeds`` (int, optional): Number of decoding restarts, default to 1.

**Validators**

- :func:`graphbench.helpers.validate_mis_solution` - Checks whether a
the given solution is a valid independent set for the provided graph.

Parameters:
- ``graph`` (Data): The problem graph.
- ``solution`` (Tensor): The independent set of shape ``[num_nodes]``, as a binary vector where a 1 indicates that the node is in the set.

- :func:`graphbench.helpers.validate_max_cut_solution` - Always returns
``True`` because any partition defines a valid cut.

Parameters:
- ``graph`` (Data): The problem graph.
- ``solution`` (Tensor): Binary node indicators.

- :func:`graphbench.helpers.validate_chrom_solution` - Checks whether the
given solution is a valid graph coloring for the provided graph.

Parameters:
- ``graph`` (Data): The problem graph.
- ``solution`` (Tensor): The graph coloring of shape ``[num_nodes]``, as a vector where each entry indicates the color assigned to the corresponding
node.


Splits:
All datasets use a 70% / 15% / 15% split for training, validation,
and testing.
Expand Down
50 changes: 50 additions & 0 deletions graphbench/helpers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from graphbench._helpers import (
download_and_unpack,
get_logger,
split_dataset,
SourceSpec,
VectorizedCircuitSimulator,
)
from .decoders import (
graph_coloring_decoder,
mis_decoder,
mis_size,
max_cut_size,
num_colors_used,
UNSUPERVISED_CO_METRICS,
UNSUPERVISED_CO_METRIC_NAMES,
)
from .unsupervised_losses import (
graph_coloring_loss,
max_cut_loss,
mis_loss,
UNSUPERVISED_CO_LOSSES,
)
from .validate_solution import (
validate_chrom_solution,
validate_max_cut_solution,
validate_mis_solution,
)


__all__ = [
"download_and_unpack",
"get_logger",
"split_dataset",
"SourceSpec",
"VectorizedCircuitSimulator",
"graph_coloring_decoder",
"mis_decoder",
"mis_size",
"max_cut_size",
"num_colors_used",
"UNSUPERVISED_CO_METRICS",
"UNSUPERVISED_CO_METRIC_NAMES",
"graph_coloring_loss",
"max_cut_loss",
"mis_loss",
"UNSUPERVISED_CO_LOSSES",
"validate_chrom_solution",
"validate_max_cut_solution",
"validate_mis_solution",
]
131 changes: 131 additions & 0 deletions graphbench/helpers/decoders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Source for mis and max cut: https://github.com/WenkelF/copt/blob/main/utils/metrics.py

import torch
from torch import Tensor
from torch_geometric.data import Batch
from torch_geometric.utils import unbatch, unbatch_edge_index, remove_self_loops


def mis_size(x: Tensor, batch: Batch, dec_length: int = 300, num_seeds: int = 1) -> Tensor:
batch = mis_decoder(x, batch, dec_length, num_seeds)

data_list = batch.to_data_list()

size_list = [data.is_size for data in data_list]

return Tensor(size_list).mean()


def mis_decoder(x: Tensor, batch: Batch, dec_length: int = 300, num_seeds: int = 1) -> Batch:
x = torch.sigmoid(x)
data_list = batch.to_data_list()
x_list = unbatch(x, batch.batch)

for data, x_data in zip(data_list, x_list):
is_size_list = []

for seed in range(num_seeds):

order = torch.argsort(x_data, dim=0, descending=True)
c = torch.zeros_like(x_data)

edge_index = remove_self_loops(data.edge_index)[0]
src, dst = edge_index[0], edge_index[1]

c[order[seed]] = 1
for idx in range(seed, min(dec_length, data.num_nodes)):
c[order[idx]] = 1

cTWc = torch.sum(c[src] * c[dst])
if cTWc != 0:
c[order[idx]] = 0

is_size_list.append(c.sum())

data.is_size = max(is_size_list)

return Batch.from_data_list(data_list)


def max_cut_size(x: Tensor, data: Batch) -> Tensor:
x = (x > 0).float()
x = (x - 0.5) * 2

x_list = unbatch(x, data.batch)
edge_index_list = unbatch_edge_index(data.edge_index, data.batch)

cut_list = []
for x, edge_index in zip(x_list, edge_index_list):
cut_list.append(torch.sum(x[edge_index[0]] * x[edge_index[1]] == -1.0) / 2)

return Tensor(cut_list).mean()


# TODO: double-check implementation
def num_colors_used(x: Tensor, batch: Batch, num_seeds: int = 1) -> Tensor:
batch = graph_coloring_decoder(x, batch, num_seeds)

data_list = batch.to_data_list()

num_colors_used_list = []
for data in data_list:
num_colors_used = data.colors.unique().size(0)
num_colors_used_list.append(num_colors_used)

return torch.tensor(num_colors_used_list).mean(dtype=torch.float)


# TODO: double-check implementation
def graph_coloring_decoder(x: Tensor, batch: Batch, num_seeds: int = 1) -> Batch:
max_num_colors = x.size(1)
x = torch.sigmoid(x)
data_list = batch.to_data_list()
x_list = unbatch(x, batch.batch)

for data, x_data in zip(data_list, x_list):
edge_index = remove_self_loops(data.edge_index)[0]
src, dst = edge_index[0], edge_index[1]

best_colors = None
min_colors_used = max_num_colors + 1 # upper bound

for seed in range(num_seeds):
order = torch.argsort(x_data.max(dim=1).values, descending=True)
colors = torch.full((data.num_nodes,), -1, dtype=torch.long, device=x_data.device)

for idx in range(data.num_nodes):
node = order[(seed + idx) % data.num_nodes]
# Find available colors for this node
used = torch.zeros(max_num_colors, dtype=torch.bool, device=x_data.device)
neighbors = dst[src == node]
for neighbor in neighbors:
c = colors[neighbor]
if c >= 0:
used[c] = True

# Assign the available color with the highest score in x_data for this node
available_color_indices = (~used).nonzero(as_tuple=True)[0]
max_idx = torch.argmax(x_data[node, available_color_indices])
colors[node] = available_color_indices[max_idx]

num_colors_used = colors.unique().size(0)
if num_colors_used < min_colors_used:
min_colors_used = num_colors_used
best_colors = colors

data.colors = best_colors

return Batch.from_data_list(data_list)


UNSUPERVISED_CO_METRICS = {
"mis_unsupervised": mis_size,
"cut_unsupervised": max_cut_size,
"chrom_unsupervised": num_colors_used,
}

UNSUPERVISED_CO_METRIC_NAMES = {
"mis_unsupervised": "mis_size",
"cut_unsupervised": "max_cut_size",
"chrom_unsupervised": "chromatic_number",
}
49 changes: 49 additions & 0 deletions graphbench/helpers/unsupervised_losses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Source: https://github.com/WenkelF/copt/blob/main/graphgym/loss/copt_loss.py

import torch
from torch import Tensor
from torch_geometric.data import Batch
from torch_geometric.utils import unbatch


def mis_loss(x: Tensor, batch: Batch, beta: float = 1.0) -> Tensor:
x = torch.sigmoid(x)
data_list = batch.to_data_list()
x_list = unbatch(x, batch.batch)

loss = 0.0
for data, x_data in zip(data_list, x_list):
src, dst = data.edge_index[0], data.edge_index[1]

loss1 = torch.sum(x_data[src] * x_data[dst])
loss2 = x_data.sum() ** 2 - loss1 - torch.sum(x_data ** 2)
loss += (- loss2 + beta * loss1) * data.num_nodes

return loss / batch.size(0)


def max_cut_loss(x: Tensor, batch: Batch) -> Tensor:
x = torch.sigmoid(x)
x = (x - 0.5) * 2
src, dst = batch.edge_index[0], batch.edge_index[1]
return torch.sum(x[src] * x[dst]) / len(batch.batch.unique())


# Adapted from GCON. GCON implements this for a node feature matrix X of size [num_nodes, num_colors] and an adjacency
# matrix A of size [num_nodes, num_nodes]. It basically calculates sum(diag(X^T A X)) - 4 * sum(abs(X)).
# The implementation here replaces the adjacency matrix with a pytorch geometric graph.
# TODO double check implementation
def graph_coloring_loss(x: Tensor, batch: Batch) -> Tensor:
x = torch.sigmoid(x)
x = (x - 0.5) * 2
src, dst = batch.edge_index
edge_loss = torch.sum(x[src] * x[dst])
node_loss = 4 * torch.abs(x).sum()
return edge_loss - node_loss


UNSUPERVISED_CO_LOSSES = {
"mis_unsupervised": mis_loss,
"cut_unsupervised": max_cut_loss,
"chrom_unsupervised": graph_coloring_loss,
}
Loading