
glycontact.learning module

glycontact.learning

GINSweetNet(lib_size: int, num_classes: int = 1, hidden_dim: int = 128, num_components: int = 5)

Bases: Module

Given glycan graphs as input, predicts properties via a graph neural network: von Mises mixture parameters for the phi/psi torsion angles, SASA, and flexibility.

forward(x: torch.Tensor, edge_index: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]

Forward pass through the model.

Parameters:

Name Type Description Default
x Tensor

Input node features [batch_size, num_nodes, hidden_dim]

required
edge_index Tensor

Edge indices for the graph [2, num_edges]

required

Returns:

Name Type Description
tuple[tuple[Tensor, Tensor], Tensor, Tensor]

Tuple of

weights_logits Tensor

Logits for mixture weights [batch_size, 2, num_components]

means Tensor

Mean angles in degrees [batch_size, 2, num_components]

kappas Tensor

Concentration parameters [batch_size, 2, num_components]

sasa_pred Tensor

Predicted SASA values [batch_size]

flex_pred Tensor

Predicted flexibility values [batch_size]

VonMisesSweetNet(lib_size: int, num_classes: int = 1, hidden_dim: int = 128, num_components: int = 5)

Bases: Module

Given glycan graphs as input, predicts properties via a graph neural network: von Mises mixture parameters for the phi/psi torsion angles, SASA, and flexibility.

forward(x: torch.Tensor, edge_index: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]

Forward pass through the model.

Parameters:

Name Type Description Default
x Tensor

Input node features [batch_size, num_nodes, hidden_dim]

required
edge_index Tensor

Edge indices for the graph [2, num_edges]

required

Returns:

Name Type Description
tuple[tuple[Tensor, Tensor, Tensor], Tensor, Tensor]

Tuple of ((weights_logits, means, kappas), sasa_pred, flex_pred), where

weights_logits Tensor

Logits for mixture weights [batch_size, 2, num_components]

means Tensor

Mean angles in degrees [batch_size, 2, num_components]

kappas Tensor

Concentration parameters [batch_size, 2, num_components]

sasa_pred Tensor

Predicted SASA values [batch_size]

flex_pred Tensor

Predicted flexibility values [batch_size]
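The mixture weights come back as raw logits. The following is a minimal pure-Python sketch, with nested lists standing in for one batch entry of the tensors, of normalizing them with a softmax over the component axis and reading off the mean of the heaviest component as a point estimate for phi and psi. Taking the dominant component is an assumption for illustration, not necessarily how glycontact summarizes its predictions; the names softmax and dominant_angles are hypothetical.

```python
import math

def softmax(logits: list[float]) -> list[float]:
    """Normalize raw logits into probabilities that sum to 1."""
    top = max(logits)  # subtract the max logit for numerical stability
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def dominant_angles(weights_logits: list[list[float]],
                    means: list[list[float]]) -> tuple[float, ...]:
    """Return the mean of the highest-weight component for phi and psi.

    Both arguments are nested lists shaped [2][num_components] standing in
    for one batch entry (row 0 = phi, row 1 = psi).
    """
    picks = []
    for logits_row, means_row in zip(weights_logits, means):
        weights = softmax(logits_row)
        best = max(range(len(weights)), key=weights.__getitem__)
        picks.append(means_row[best])
    return tuple(picks)

# Made-up values for one entry with num_components = 3
logits = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3]]
mus = [[-60.0, 75.0, 180.0], [120.0, -10.0, 0.0]]
print(dominant_angles(logits, mus))  # (75.0, 120.0)
```

For a multimodal torsion distribution, sampling from the full mixture (as sample_angle does) preserves more information than this single point estimate.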

predict_von_mises_parameters(x: torch.Tensor, head: torch.nn.Module, fc_weights: torch.nn.Module, fc_means: torch.nn.Module, fc_kappas: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Predict mixture parameters for a given input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor [batch_size, hidden_dim]

required
head Module

Head module for the mixture model

required
fc_weights Module

Fully connected layer for weights

required
fc_means Module

Fully connected layer for means

required
fc_kappas Module

Fully connected layer for kappas

required

Returns:

Name Type Description
tuple[Tensor, Tensor, Tensor]

Tuple of (weights_logits, means, kappas), where

weights_logits Tensor

Logits for mixture weights [batch_size, 2, num_components]

means Tensor

Mean angles in degrees [batch_size, 2, num_components]

kappas Tensor

Concentration parameters [batch_size, 2, num_components]

angular_rmse(predicted_graphs: list[nx.DiGraph], true_graphs: list[nx.DiGraph]) -> tuple[float, float]

Calculate the root mean square error (RMSE) for phi and psi angles.

Parameters:

Name Type Description Default
predicted_graphs list[DiGraph]

List of predicted structure graphs

required
true_graphs list[DiGraph]

List of true structure graphs

required

Returns:

Type Description
tuple[float, float]

Tuple of RMSE for phi and psi angles

build_baselines(data: list[nx.DiGraph], fn: callable = np.mean) -> tuple[callable, callable, callable, callable]

Build baseline functions to predict SASA, flexibility, phi, and psi angles based on monosaccharides.

Parameters:

Name Type Description Default
data list[DiGraph]

List of structure graphs.

required
fn callable

Function to aggregate values (e.g., np.mean, np.median).

np.mean

Returns:

Type Description
tuple[callable, callable, callable, callable]

Tuple of functions for phi, psi, SASA, and flexibility.
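A per-monosaccharide baseline amounts to grouping observed values by residue and aggregating each group with fn. The sketch below assumes the values have already been pulled out of the structure graphs as hypothetical (monosaccharide, value) pairs, uses statistics.mean as a stdlib stand-in for np.mean, and falls back to the global aggregate for unseen monosaccharides; that fallback is an assumption, not documented glycontact behavior.

```python
import statistics
from collections import defaultdict
from typing import Callable

def build_lookup(records: list[tuple[str, float]],
                 fn: Callable[[list[float]], float] = statistics.mean) -> Callable[[str], float]:
    """Aggregate observed values per monosaccharide and return a predictor.

    records: hypothetical (monosaccharide, value) pairs, e.g. node SASA values
    pulled from structure graphs.
    """
    grouped: defaultdict[str, list[float]] = defaultdict(list)
    for mono, value in records:
        grouped[mono].append(value)
    table = {mono: fn(values) for mono, values in grouped.items()}
    # Assumed fallback for unseen monosaccharides: aggregate over all records
    fallback = fn([value for _, value in records])
    return lambda mono: table.get(mono, fallback)

records = [("Gal", 120.0), ("Gal", 100.0), ("GlcNAc", 80.0)]
sasa = build_lookup(records)
print(sasa("Gal"), sasa("GlcNAc"), sasa("Fuc"))  # 110.0 80.0 100.0
```

Passing statistics.median instead of the default makes the baseline less sensitive to outlier conformers. For phi and psi, an arithmetic aggregate can misbehave near ±180°, so a circular statistic may be preferable there.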

clean_split(split: list[tuple[torch_geometric.data.Data, nx.DiGraph]], mode: Literal['mean', 'max'] = 'max') -> tuple[torch_geometric.data.Data, nx.DiGraph]

Clean the split data by condensing it to one conformer per glycan.

Parameters:

Name Type Description Default
split list

A list of tuples containing the PyTorch Geometric Data object and the structure graph.

required
mode str

The condensing mode: "mean" keeps the mean conformer, "max" keeps the conformer with the highest weight.

'max'

Returns:

Name Type Description
list tuple[Data, DiGraph]

A list of tuples containing the condensed PyTorch Geometric Data object and the structure graph.
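In "max" mode, condensing reduces to keeping the highest-weight conformer for each glycan. The following minimal sketch assumes every conformer carries its glycan's IUPAC name and a weight (graph2pyg takes both), and it models the split as hypothetical (iupac, weight, structure) triples instead of (Data, DiGraph) pairs. The function name keep_max_weight is illustrative only.

```python
def keep_max_weight(split: list[tuple[str, float, object]]) -> list[tuple[str, float, object]]:
    """Keep one conformer per glycan: the one with the largest weight.

    split: hypothetical (iupac, weight, structure) triples.
    """
    best: dict[str, tuple[str, float, object]] = {}
    for iupac, weight, structure in split:
        # Replace the stored conformer only when this one is heavier
        if iupac not in best or weight > best[iupac][1]:
            best[iupac] = (iupac, weight, structure)
    return list(best.values())

split = [
    ("Gal(b1-4)Glc", 0.2, "conformer_a"),
    ("Gal(b1-4)Glc", 0.7, "conformer_b"),
    ("Fuc(a1-2)Gal", 1.0, "conformer_c"),
]
print(keep_max_weight(split))
```

Ties keep the first conformer encountered, and the output preserves the order in which each glycan first appeared.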

create_dataset(fresh: bool = True)

Create a dataset of PyTorch Geometric Data objects from the structure graphs of glycans.

Parameters:

Name Type Description Default
fresh bool

If True, fetches the latest data. If False, uses cached data.

True

Returns:

Name Type Description
tuple

A tuple containing the training and testing datasets.

eval_baseline(nxgraphs: list[nx.DiGraph], phi_pred: callable, psi_pred: callable, sasa_pred: callable, flex_pred: callable) -> list[nx.DiGraph]

Evaluate the baseline model by predicting angles and properties for each graph.

Parameters:

Name Type Description Default
nxgraphs list[DiGraph]

List of structure graphs

required
phi_pred callable

Function to predict phi angles

required
psi_pred callable

Function to predict psi angles

required
sasa_pred callable

Function to predict SASA

required
flex_pred callable

Function to predict flexibility

required

Returns:

Type Description
list[DiGraph]

List of predicted structure graphs

evaluate_model(model: torch.nn.Module | tuple[callable, callable, callable, callable], structures, count: int = 10)

Evaluate the model by sampling angles and properties from the structure graphs.

Parameters:

Name Type Description Default
model Module | tuple[callable, callable, callable, callable]

The trained model. This can be a trained SweetNet or a tuple of baseline predictors for phi, psi, SASA, and flexibility.

required
structures

List of structure graphs

required
count int

Number of samples to generate for each graph

10

Returns:

Type Description

Tuple of RMSE values for phi, psi, SASA, and flexibility

get_all_structure_graphs(glycan, stereo=None, libr=None)

Get all structure graphs for a given glycan.

Parameters:

Name Type Description Default
glycan str

The glycan name.

required
stereo str

The stereochemistry. If None, both alpha and beta are returned.

None
libr HashableDict

A library of structures. If None, the default library is used.

None

Returns:

Name Type Description
list

A list of tuples containing the PDB file name and the corresponding structure graph.

graph2pyg(g, weight, iupac, conformer)

Convert a structure graph to a PyTorch Geometric Data object.

Parameters:

Name Type Description Default
g Graph

The structure graph.

required
weight float

The weight of the graph.

required
iupac str

The IUPAC name of the glycan.

required
conformer str

The conformer name.

required

Returns:

Type Description
Data

The PyTorch Geometric Data object.

mean_conformer(conformers: list[tuple[float, tuple[torch_geometric.data.Data, nx.DiGraph]]]) -> tuple[torch_geometric.data.Data, nx.DiGraph]

Calculate the mean conformer from a list of conformers.

Parameters:

Name Type Description Default
conformers list

A list of tuples, each containing a conformer weight and a (PyTorch Geometric Data object, structure graph) pair.

required

Returns:

Name Type Description
tuple tuple[Data, DiGraph]

A tuple containing the mean PyTorch Geometric Data object and the mean structure graph.
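Averaging conformers by weight works directly for scalar node properties such as SASA, but torsion angles wrap at ±180°, so averaging them on the unit circle is safer. The sketch below shows both cases on hypothetical plain-list inputs. Whether mean_conformer itself uses a circular mean for angles is not stated here, so treat weighted_circular_mean as an illustrative alternative and not as the library's method.

```python
import math

def weighted_mean_values(conformers: list[tuple[float, list[float]]]) -> list[float]:
    """Weighted average of per-node scalar values (e.g. SASA) across conformers.

    conformers: hypothetical (weight, values) pairs with the same node order.
    """
    total = sum(w for w, _ in conformers)
    n_nodes = len(conformers[0][1])
    return [sum(w * vals[i] for w, vals in conformers) / total for i in range(n_nodes)]

def weighted_circular_mean(pairs: list[tuple[float, float]]) -> float:
    """Weighted mean of (weight, angle in degrees) pairs, computed on the unit circle."""
    s = sum(w * math.sin(math.radians(a)) for w, a in pairs)
    c = sum(w * math.cos(math.radians(a)) for w, a in pairs)
    return math.degrees(math.atan2(s, c))

print(weighted_mean_values([(0.75, [100.0, 40.0]), (0.25, [60.0, 80.0])]))  # [90.0, 50.0]
print(weighted_circular_mean([(0.5, 170.0), (0.5, -170.0)]))  # close to +/-180, not 0
```

An arithmetic mean of 170° and -170° would give 0°, which points in the opposite direction from both inputs.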

mixture_von_mises_nll(angles: torch.Tensor, weights_logits: torch.Tensor, mus: torch.Tensor, kappas: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]

Negative log-likelihood for a mixture of von Mises distributions.

Parameters:

Name Type Description Default
angles Tensor

True angles in degrees [batch_size, 2] (phi, psi)

required
weights_logits Tensor

Raw logits for mixture weights [batch_size, 2, n_components]

required
mus Tensor

Mean angles in degrees [batch_size, 2, n_components]

required
kappas Tensor

Concentration parameters [batch_size, 2, n_components]

required

Returns:

Type Description
tuple[Tensor, Tensor]

Negative log-likelihood
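For a single angle θ, the mixture density is p(θ) = Σ_k w_k exp(κ_k cos(θ − μ_k)) / (2π I0(κ_k)), with weights w = softmax(weights_logits). The following scalar sketch computes −log p(θ) in log space, using log-sum-exp both for the softmax and for the sum over components. It evaluates log I0 with a truncated power series, which is adequate only for moderate κ. The batched [batch_size, 2, n_components] handling of the real function is not reproduced here, and all names are illustrative.

```python
import math

def log_i0(kappa: float, terms: int = 60) -> float:
    """log I0(kappa) from the power series sum_m (kappa/2)^(2m) / (m!)^2.

    Adequate for moderate kappa (roughly kappa <= 20); larger values need a
    numerically stable library routine instead.
    """
    total, term = 0.0, 1.0  # term starts at the m = 0 summand
    for m in range(terms):
        total += term
        term *= (kappa / 2.0) ** 2 / (m + 1) ** 2
    return math.log(total)

def logsumexp(values: list[float]) -> float:
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))

def mixture_vm_nll(angle_deg: float, weights_logits: list[float],
                   mus_deg: list[float], kappas: list[float]) -> float:
    """Negative log-likelihood of one angle under a von Mises mixture."""
    log_norm = logsumexp(weights_logits)  # log-softmax denominator
    theta = math.radians(angle_deg)
    per_component = []
    for logit, mu, kappa in zip(weights_logits, mus_deg, kappas):
        log_weight = logit - log_norm
        log_pdf = (kappa * math.cos(theta - math.radians(mu))
                   - math.log(2.0 * math.pi) - log_i0(kappa))
        per_component.append(log_weight + log_pdf)
    return -logsumexp(per_component)

# kappa = 0 is the uniform circle density 1/(2*pi), so the NLL equals log(2*pi)
print(mixture_vm_nll(42.0, [0.0], [0.0], [0.0]), math.log(2 * math.pi))
```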

node2y(attr)

Extract ML task labels from node attributes.

Parameters:

Name Type Description Default
attr dict

Node attributes.

required

Returns:

Name Type Description
list

A list of labels for the node. If all labels are zero, returns None.

periodic_mse(pred: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]

Calculate the periodic mean squared error (MSE) for angles.

Parameters:

Name Type Description Default
pred Tensor

Predicted angles in degrees [batch_size, 2]

required
target Tensor

True angles in degrees [batch_size, 2]

required

Returns:

Type Description
tuple[Tensor, Tensor]

Tuple of MSE for phi and psi angles

periodic_rmse(pred: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]

Calculate the periodic root mean square error (RMSE) for angles.

Parameters:

Name Type Description Default
pred Tensor

Predicted angles in degrees [batch_size, 2]

required
target Tensor

True angles in degrees [batch_size, 2]

required

Returns:

Type Description
tuple[Tensor, Tensor]

Tuple of RMSE for phi and psi angles
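A periodic error first wraps each angular difference onto [-180°, 180°) so that, for example, 170° versus -170° counts as a 20° error and not 340°. The pure-Python sketch below computes per-column (phi, psi) RMSE on hypothetical lists of (phi, psi) pairs. The mean of the squared wrapped differences, before the square root, is the corresponding periodic MSE. This illustrates the technique and is not the library's tensor implementation.

```python
import math

def wrap_deg(d: float) -> float:
    """Map an angular difference in degrees onto [-180, 180)."""
    return (d + 180.0) % 360.0 - 180.0

def periodic_rmse_sketch(pred: list[tuple[float, float]],
                         target: list[tuple[float, float]]) -> tuple[float, float]:
    """Per-column (phi, psi) RMSE on the circle."""
    n = len(pred)
    out = []
    for col in range(2):
        mse = sum(wrap_deg(p[col] - t[col]) ** 2 for p, t in zip(pred, target)) / n
        out.append(math.sqrt(mse))
    return out[0], out[1]

pred = [(170.0, -60.0), (-175.0, 30.0)]
true = [(-170.0, -50.0), (175.0, 30.0)]
print(periodic_rmse_sketch(pred, true))  # phi ~15.81, psi ~7.07
```

Without wrapping, the phi differences here would be 340° and -350°, which inflates the phi RMSE to about 345°.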

sample_angle(weights: torch.Tensor, mus: torch.Tensor, kappas: torch.Tensor) -> torch.Tensor

Sample an angle from a mixture of von Mises distributions.

Parameters:

Name Type Description Default
weights Tensor

Mixture weights [n_components]

required
mus Tensor

Mean angles in degrees [n_components]

required
kappas Tensor

Concentration parameters [n_components]

required

Returns:

Type Description
Tensor

Sampled angle in degrees
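Sampling from a mixture takes two steps: choose a component with probability proportional to its weight, then draw from that component's von Mises distribution. The following sketch uses only the standard library (random.choices and random.Random.vonmisesvariate, which works in radians and returns values in [0, 2π)) and wraps the result to [-180°, 180°). It assumes the weights are already normalized or at least non-negative, and the function name is hypothetical.

```python
import math
import random

def sample_mixture_angle(weights: list[float], mus_deg: list[float],
                         kappas: list[float], rng: random.Random) -> float:
    """Draw one angle in degrees, wrapped to [-180, 180), from a von Mises mixture."""
    # 1) pick a component with probability proportional to its weight
    k = rng.choices(range(len(weights)), weights=weights, k=1)[0]
    # 2) sample from that component; vonmisesvariate uses radians and returns [0, 2*pi)
    theta = rng.vonmisesvariate(math.radians(mus_deg[k] % 360.0), kappas[k])
    return (math.degrees(theta) + 180.0) % 360.0 - 180.0

rng = random.Random(0)
draws = [sample_mixture_angle([0.8, 0.2], [-60.0, 180.0], [40.0, 10.0], rng) for _ in range(5)]
print(draws)
```

Seeding a dedicated random.Random instance keeps repeated draws (for example, the count samples per graph in evaluate_model) reproducible.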

sample_from_model(model: torch.nn.Module, structures: list[torch_geometric.data.Data, nx.DiGraph], count: int = 10)

Sample from the model using the provided structures.

Parameters:

Name Type Description Default
model Module

The trained model

required
structures list[Data, DiGraph]

List of (PyTorch Geometric Data object, structure graph) pairs

required
count int

Number of samples to generate for each graph

10

Returns:

Type Description

List of sampled angles

value_rmse(predicted_graphs: list[nx.DiGraph], true_graphs: list[nx.DiGraph], name: Literal['SASA', 'flexibility']) -> float

Calculate the root mean square error (RMSE) for a specific property (SASA or flexibility).

Parameters:

Name Type Description Default
predicted_graphs list[DiGraph]

List of predicted structure graphs

required
true_graphs list[DiGraph]

List of true structure graphs

required
name Literal['SASA', 'flexibility']

The property to calculate RMSE for (e.g., "SASA" or "flexibility")

required

Returns:

Type Description
float

RMSE value
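Unlike phi and psi, SASA and flexibility are not periodic, so a plain RMSE over paired nodes is appropriate. The sketch below assumes hypothetical, index-aligned lists of node-attribute dicts in place of DiGraph nodes, with the property stored under the same key as the name argument. That key convention is an assumption.

```python
import math

def property_rmse(pred_nodes: list[dict], true_nodes: list[dict], name: str) -> float:
    """RMSE of one node property over paired nodes.

    pred_nodes / true_nodes: hypothetical node-attribute dicts, aligned so that
    index i refers to the same monosaccharide in both lists.
    """
    sq = [(p[name] - t[name]) ** 2 for p, t in zip(pred_nodes, true_nodes)]
    return math.sqrt(sum(sq) / len(sq))

pred = [{"SASA": 110.0}, {"SASA": 50.0}]
true = [{"SASA": 100.0}, {"SASA": 60.0}]
print(property_rmse(pred, true, "SASA"))  # 10.0
```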