glycontact.learning module
GINSweetNet(lib_size: int, num_classes: int = 1, hidden_dim: int = 128, num_components: int = 5)
Bases: Module
Given glycan graphs as input, predicts properties via a graph neural network.
forward(x: torch.Tensor, edge_index: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]
Forward pass through the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Tensor | Input node features [batch_size, num_nodes, hidden_dim] | required |
edge_index | Tensor | Edge indices for the graph [2, num_edges] | required |
Returns:
A tuple[tuple[Tensor, Tensor], Tensor, Tensor]: a nested tuple of mixture parameters, followed by sasa_pred and flex_pred.
Name | Type | Description |
---|---|---|
weights_logits | Tensor | Logits for mixture weights [batch_size, 2, num_components] |
means | Tensor | Mean angles in degrees [batch_size, 2, num_components] |
kappas | Tensor | Concentration parameters [batch_size, 2, num_components] |
sasa_pred | Tensor | Predicted SASA values [batch_size] |
flex_pred | Tensor | Predicted flexibility values [batch_size] |
VonMisesSweetNet(lib_size: int, num_classes: int = 1, hidden_dim: int = 128, num_components: int = 5)
Bases: Module
Given glycan graphs as input, predicts properties via a graph neural network.
forward(x: torch.Tensor, edge_index: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]
Forward pass through the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Tensor | Input node features [batch_size, num_nodes, hidden_dim] | required |
edge_index | Tensor | Edge indices for the graph [2, num_edges] | required |
Returns:
A tuple[tuple[Tensor, Tensor, Tensor], Tensor, Tensor]: the mixture parameters (weights_logits, means, kappas), followed by sasa_pred and flex_pred.
Name | Type | Description |
---|---|---|
weights_logits | Tensor | Logits for mixture weights [batch_size, 2, num_components] |
means | Tensor | Mean angles in degrees [batch_size, 2, num_components] |
kappas | Tensor | Concentration parameters [batch_size, 2, num_components] |
sasa_pred | Tensor | Predicted SASA values [batch_size] |
flex_pred | Tensor | Predicted flexibility values [batch_size] |
predict_von_mises_parameters(x: torch.Tensor, head: torch.nn.Module, fc_weights: torch.nn.Module, fc_means: torch.nn.Module, fc_kappas: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Predict mixture parameters for a given input tensor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Tensor | Input tensor [batch_size, hidden_dim] | required |
head | Module | Head module for the mixture model | required |
fc_weights | Module | Fully connected layer for weights | required |
fc_means | Module | Fully connected layer for means | required |
fc_kappas | Module | Fully connected layer for kappas | required |
Returns:
A tuple[Tensor, Tensor, Tensor] of (weights_logits, means, kappas).
Name | Type | Description |
---|---|---|
weights_logits | Tensor | Logits for mixture weights [batch_size, 2, num_components] |
means | Tensor | Mean angles in degrees [batch_size, 2, num_components] |
kappas | Tensor | Concentration parameters [batch_size, 2, num_components] |
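The weights_logits output is documented as logits, and mixture_von_mises_nll calls its counterpart "raw logits", so they must be normalized before they can be read as mixture weights. The following is a minimal pure-Python sketch, assuming the usual softmax over the component axis (this page does not state which normalization the library applies); `mixture_weights_from_logits` is a hypothetical helper that handles one row of the [batch_size, 2, num_components] tensor.

```python
import math


def mixture_weights_from_logits(logits: list[float]) -> list[float]:
    """Softmax over one row of weights_logits (one angle, all mixture components)."""
    # Subtracting the largest logit does not change the softmax result,
    # but it keeps math.exp from overflowing when logits are large.
    largest = max(logits)
    exps = [math.exp(value - largest) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for one angle (e.g. phi) with num_components = 5.
weights = mixture_weights_from_logits([2.0, 0.5, -1.0, 0.0, 0.0])
```

In a tensor library the same step would be a softmax along the last dimension of weights_logits.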
angular_rmse(predicted_graphs: list[nx.DiGraph], true_graphs: list[nx.DiGraph]) -> tuple[float, float]
Calculate the root mean square error (RMSE) for phi and psi angles.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
predicted_graphs | list[DiGraph] | List of predicted structure graphs | required |
true_graphs | list[DiGraph] | List of true structure graphs | required |
Returns:
Type | Description |
---|---|
tuple[float, float] | Tuple of RMSE for phi and psi angles |
build_baselines(data: list[nx.DiGraph], fn: callable = np.mean) -> tuple[callable, callable, callable, callable]
Build baseline functions to predict SASA, flexibility, phi, and psi angles based on monosaccharides.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data | list[DiGraph] | List of structure graphs. | required |
fn | callable | Function to aggregate values (e.g., np.mean, np.median). | np.mean |
Returns:
Type | Description |
---|---|
tuple[callable, callable, callable, callable] | Tuple of functions for phi, psi, SASA, and flexibility. |
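build_baselines returns per-monosaccharide predictors built by aggregating observed values with fn. The sketch below illustrates that idea on plain (monosaccharide, value) pairs instead of structure graphs. `build_lookup`, the input layout, and the fall-back to the global aggregate for unseen monosaccharides are all assumptions for illustration, not the library's behavior; it uses statistics.mean in place of np.mean so that it runs without NumPy.

```python
import statistics
from collections import defaultdict
from typing import Callable


def build_lookup(
    records: list[tuple[str, float]],
    fn: Callable[[list[float]], float] = statistics.mean,
) -> Callable[[str], float]:
    """Hypothetical helper: aggregate observed values per monosaccharide with fn."""
    # Group observed values (e.g. SASA) by monosaccharide name.
    grouped: dict[str, list[float]] = defaultdict(list)
    for mono, value in records:
        grouped[mono].append(value)
    # Aggregate each group once, up front.
    table = {mono: fn(values) for mono, values in grouped.items()}
    # Assumption: unseen monosaccharides fall back to the aggregate over all values.
    fallback = fn([value for _, value in records])

    def predict(mono: str) -> float:
        return table.get(mono, fallback)

    return predict


sasa = build_lookup([("GlcNAc", 120.0), ("GlcNAc", 100.0), ("Man", 80.0)])
sasa("GlcNAc")  # 110.0
```

Passing fn=statistics.median gives a median baseline instead, mirroring the np.median option in the table above.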
clean_split(split: list[tuple[torch_geometric.data.Data, nx.DiGraph]], mode: Literal['mean', 'max'] = 'max') -> tuple[torch_geometric.data.Data, nx.DiGraph]
Clean the split data by condensing it to one conformer per glycan.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
split | list[tuple[Data, DiGraph]] | A list of tuples, each containing a PyTorch Geometric Data object and its structure graph. | required |
mode | Literal['mean', 'max'] | How to condense the data: "mean" for the mean conformer, "max" for the maximum-weight conformer. | 'max' |
Returns:
Type | Description |
---|---|
list[tuple[Data, DiGraph]] | A list of tuples, each containing a condensed PyTorch Geometric Data object and its structure graph. |
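With mode="max", clean_split keeps a single conformer per glycan, the one with the largest weight. A minimal sketch of that selection step, assuming a hypothetical flat record layout (glycan_id, weight, payload) in which payload stands in for the (Data, DiGraph) pair; where the real function reads the glycan identity and conformer weight from is not documented on this page.

```python
from collections import defaultdict
from typing import Any


def keep_max_weight(records: list[tuple[str, float, Any]]) -> list[tuple[str, Any]]:
    """Hypothetical helper: keep the highest-weight conformer of each glycan."""
    # Bucket conformers by glycan identifier (dicts preserve first-seen order).
    by_glycan: dict[str, list[tuple[float, Any]]] = defaultdict(list)
    for glycan, weight, payload in records:
        by_glycan[glycan].append((weight, payload))
    # For each glycan, return only the payload of its highest-weight conformer.
    return [
        (glycan, max(items, key=lambda item: item[0])[1])
        for glycan, items in by_glycan.items()
    ]


condensed = keep_max_weight([("A", 0.2, "a1"), ("A", 0.7, "a2"), ("B", 0.5, "b1")])
```

The "mean" mode instead merges all conformers of a glycan, which is the job of mean_conformer.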
create_dataset(fresh: bool = True)
Create a dataset of PyTorch Geometric Data objects from the structure graphs of glycans.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
fresh | bool | If True, fetches the latest data. If False, uses cached data. | True |
Returns:
Type | Description |
---|---|
tuple | A tuple containing the training and testing datasets. |
eval_baseline(nxgraphs: list[nx.DiGraph], phi_pred: callable, psi_pred: callable, sasa_pred: callable, flex_pred: callable) -> list[nx.DiGraph]
Evaluate the baseline model by predicting angles and properties for each graph.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
nxgraphs | list[DiGraph] | List of structure graphs | required |
phi_pred | callable | Function to predict phi angles | required |
psi_pred | callable | Function to predict psi angles | required |
sasa_pred | callable | Function to predict SASA | required |
flex_pred | callable | Function to predict flexibility | required |
Returns:
Type | Description |
---|---|
list[DiGraph] | List of predicted structure graphs |
evaluate_model(model: torch.nn.Module | tuple[callable, callable, callable, callable], structures, count: int = 10)
Evaluate the model by sampling angles and properties from the structure graphs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Module or tuple[callable, callable, callable, callable] | The trained model: either a trained SweetNet or a tuple of baseline predictors for phi, psi, SASA, and flexibility. | required |
structures | | List of structure graphs | required |
count | int | Number of samples to generate for each graph | 10 |
Returns:
Type | Description |
---|---|
tuple | Tuple of RMSE values for phi, psi, SASA, and flexibility |
get_all_structure_graphs(glycan, stereo=None, libr=None)
Get all structure graphs for a given glycan.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
glycan | str | The glycan name. | required |
stereo | str | The stereochemistry. If None, both alpha and beta are returned. | None |
libr | HashableDict | A library of structures. If None, the default library is used. | None |
Returns:
Type | Description |
---|---|
list | A list of tuples containing the PDB file name and the corresponding structure graph. |
graph2pyg(g, weight, iupac, conformer)
Convert a structure graph to a PyTorch Geometric Data object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
g | Graph | The structure graph. | required |
weight | float | The weight of the graph. | required |
iupac | str | The IUPAC name of the glycan. | required |
conformer | str | The conformer name. | required |
Returns:
Type | Description |
---|---|
torch_geometric.data.Data | The PyTorch Geometric Data object. |
mean_conformer(conformers: list[tuple[float, tuple[torch_geometric.data.Data, nx.DiGraph]]]) -> tuple[torch_geometric.data.Data, nx.DiGraph]
Calculate the mean conformer from a list of conformers.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
conformers | list[tuple[float, tuple[Data, DiGraph]]] | A list of tuples, each containing a weight and a (Data, structure graph) pair. | required |
Returns:
Type | Description |
---|---|
tuple[Data, DiGraph] | A tuple containing the mean PyTorch Geometric Data object and the mean structure graph. |
mixture_von_mises_nll(angles: torch.Tensor, weights_logits: torch.Tensor, mus: torch.Tensor, kappas: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]
Negative log-likelihood for a mixture of von Mises distributions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
angles | Tensor | True angles in degrees [batch_size, 2] (phi, psi) | required |
weights_logits | Tensor | Raw logits for mixture weights [batch_size, 2, n_components] | required |
mus | Tensor | Mean angles in degrees [batch_size, 2, n_components] | required |
kappas | Tensor | Concentration parameters [batch_size, 2, n_components] | required |
Returns:
Type | Description |
---|---|
tuple[Tensor, Tensor] | Negative log-likelihood |
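For one angle θ with mixture weights w_k = softmax(weights_logits)_k, a von Mises mixture has density p(θ) = Σ_k w_k exp(κ_k cos(θ - μ_k)) / (2π I0(κ_k)), where I0 is the modified Bessel function of order zero, and the loss is -log p(θ). The sketch below evaluates this standard formula for a single angle in pure Python, in log space with log-sum-exp, computing log I0 from its power series. It illustrates the math only; the library function works on batched [batch_size, 2, n_components] tensors and returns a pair of tensors. `log_i0` and `mixture_nll_one_angle` are hypothetical names.

```python
import math


def log_i0(kappa: float, max_terms: int = 500) -> float:
    """log I0(kappa) via the power series I0(k) = sum_m ((k/2)^m / m!)^2."""
    q = (kappa / 2.0) ** 2
    total, term = 1.0, 1.0
    for m in range(1, max_terms):
        term *= q / (m * m)  # ratio between consecutive series terms
        total += term
        if term < 1e-17 * total:
            break
    return math.log(total)


def mixture_nll_one_angle(
    angle_deg: float,
    weights_logits: list[float],
    mus_deg: list[float],
    kappas: list[float],
) -> float:
    """-log p(angle) under a von Mises mixture, for a single angle."""
    # Log-softmax turns the raw logits into log mixture weights.
    top = max(weights_logits)
    log_norm = top + math.log(sum(math.exp(x - top) for x in weights_logits))
    theta = math.radians(angle_deg)
    log_terms = []
    for logit, mu, kappa in zip(weights_logits, mus_deg, kappas):
        # log of exp(kappa * cos(theta - mu)) / (2 * pi * I0(kappa))
        log_pdf = kappa * math.cos(theta - math.radians(mu)) - math.log(2 * math.pi) - log_i0(kappa)
        log_terms.append(logit - log_norm + log_pdf)
    # Log-sum-exp over components, then negate.
    peak = max(log_terms)
    return -(peak + math.log(sum(math.exp(v - peak) for v in log_terms)))
```

The plain power series is adequate for moderate concentrations; very large kappas would need a scaled Bessel routine such as scipy.special.i0e to avoid overflow.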
node2y(attr)
Extract ML task labels from node attributes.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
attr | dict | Node attributes. | required |
Returns:
Type | Description |
---|---|
list or None | A list of labels for the node, or None if all labels are zero. |
periodic_mse(pred: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]
Calculate the periodic mean squared error (MSE) for angles.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pred | Tensor | Predicted angles in degrees [batch_size, 2] | required |
target | Tensor | True angles in degrees [batch_size, 2] | required |
Returns:
Type | Description |
---|---|
tuple[Tensor, Tensor] | Tuple of MSE for phi and psi angles |
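A periodic error has to treat 179° and -179° as 2° apart rather than 358°. One common way to achieve that, used in the minimal sketch below, is to wrap each difference into [-180, 180) with a modulo before squaring; this is an assumption about the approach, and the library may wrap differently (for example via atan2 of the sine and cosine of the difference). Rows of pred and target are (phi, psi) pairs in degrees, mirroring the [batch_size, 2] layout, and one MSE is returned per angle. `wrap_degrees` and `periodic_mse_sketch` are hypothetical names.

```python
def wrap_degrees(delta: float) -> float:
    """Map an angular difference in degrees into [-180, 180)."""
    return (delta + 180.0) % 360.0 - 180.0


def periodic_mse_sketch(
    pred: list[tuple[float, float]],
    target: list[tuple[float, float]],
) -> tuple[float, float]:
    """Per-angle MSE over (phi, psi) rows, using wrapped differences."""
    n = len(pred)
    phi_se = sum(wrap_degrees(p[0] - t[0]) ** 2 for p, t in zip(pred, target))
    psi_se = sum(wrap_degrees(p[1] - t[1]) ** 2 for p, t in zip(pred, target))
    return phi_se / n, psi_se / n


# phi differs by 358 degrees numerically but only 2 degrees on the circle.
mse_phi, mse_psi = periodic_mse_sketch([(179.0, 0.0)], [(-179.0, 10.0)])
```

Taking the square root of each component turns these into per-angle RMSE values.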
periodic_rmse(pred: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]
Calculate the periodic root mean square error (RMSE) for angles.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pred | Tensor | Predicted angles in degrees [batch_size, 2] | required |
target | Tensor | True angles in degrees [batch_size, 2] | required |
Returns:
Type | Description |
---|---|
tuple[Tensor, Tensor] | Tuple of RMSE for phi and psi angles |
sample_angle(weights: torch.Tensor, mus: torch.Tensor, kappas: torch.Tensor) -> torch.Tensor
Sample an angle from a mixture of von Mises distributions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
weights | Tensor | Mixture weights [n_components] | required |
mus | Tensor | Mean angles in degrees [n_components] | required |
kappas | Tensor | Concentration parameters [n_components] | required |
Returns:
Type | Description |
---|---|
Tensor | Sampled angle in degrees |
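Sampling from a mixture is a two-step process: choose a component with probability given by its weight, then draw from that component's von Mises distribution. The sketch below does this with Python's standard random module (random.choices, plus random.vonmisesvariate, which takes the mean in radians and returns values in [0, 2π)). It assumes the weights are non-negative (they need not sum to 1) and that the result should be reported in [-180, 180); `sample_angle_sketch` is a hypothetical name and this is an illustration of the technique, not the library's implementation.

```python
import math
import random
from typing import Optional


def sample_angle_sketch(
    weights: list[float],
    mus_deg: list[float],
    kappas: list[float],
    rng: Optional[random.Random] = None,
) -> float:
    """Draw one angle in degrees, in [-180, 180), from a von Mises mixture."""
    if rng is None:
        rng = random.Random()
    # 1) Pick a component index with probability proportional to its weight.
    k = rng.choices(range(len(weights)), weights=weights, k=1)[0]
    # 2) Sample that component; vonmisesvariate works in radians.
    theta = rng.vonmisesvariate(math.radians(mus_deg[k]), kappas[k])
    # 3) Convert back to degrees and wrap into [-180, 180).
    return (math.degrees(theta) + 180.0) % 360.0 - 180.0


rng = random.Random(0)
draws = [sample_angle_sketch([0.7, 0.3], [-60.0, 60.0], [8.0, 8.0], rng) for _ in range(5)]
```

Passing a seeded random.Random makes draws reproducible, which helps when comparing repeated samples across runs.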
sample_from_model(model: torch.nn.Module, structures: list[torch_geometric.data.Data, nx.DiGraph], count: int = 10)
Sample from the model using the provided structures.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Module | The trained model | required |
structures | list[Data, DiGraph] | List of structure graphs | required |
count | int | Number of samples to generate for each graph | 10 |
Returns:
Type | Description |
---|---|
list | List of sampled angles |
value_rmse(predicted_graphs: list[nx.DiGraph], true_graphs: list[nx.DiGraph], name: Literal['SASA', 'flexibility']) -> float
Calculate the root mean square error (RMSE) for a specific property (SASA or flexibility).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
predicted_graphs | list[DiGraph] | List of predicted structure graphs | required |
true_graphs | list[DiGraph] | List of true structure graphs | required |
name | Literal['SASA', 'flexibility'] | The property to calculate RMSE for ("SASA" or "flexibility") | required |
Returns:
Type | Description |
---|---|
float | RMSE value |
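Unlike phi and psi, SASA and flexibility are not periodic, so value_rmse can use the ordinary RMSE, sqrt(mean((pred - true)^2)), with no angle wrapping. A minimal sketch over two equal-length lists of per-node values, assuming the values have already been read from matching nodes of the predicted and true graphs (the node attribute names are not documented here); `value_rmse_sketch` is a hypothetical name.

```python
import math


def value_rmse_sketch(predicted: list[float], true: list[float]) -> float:
    """Plain (non-periodic) RMSE between matched per-node values."""
    if len(predicted) != len(true) or not true:
        raise ValueError("need two non-empty sequences of equal length")
    squared = sum((p - t) ** 2 for p, t in zip(predicted, true))
    return math.sqrt(squared / len(true))


rmse = value_rmse_sketch([1.0, 2.0], [1.0, 4.0])
```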