Dataset Overview
The STGNN model uses resting-state fMRI (rs-fMRI) data from the Alzheimer’s Disease Neuroimaging Initiative (ADNI) database. The dataset consists of functional connectivity (FC) matrices derived from brain imaging scans across multiple visits per subject.
ADNI Access: Data must be requested from adni.loni.usc.edu. This project uses preprocessed FC matrices from ADNI’s rs-fMRI scans.
FC Matrix Dataset
Dataset Class
The FC_ADNIDataset class (from FC_ADNIDataset.py) is a PyTorch Geometric InMemoryDataset that loads functional connectivity matrices and converts them to graph structures.
Initialization
From FC_ADNIDataset.py:10-16:
class FC_ADNIDataset(InMemoryDataset):
    """PyTorch Geometric in-memory dataset of ADNI functional-connectivity graphs.

    Loads FC matrices from ``<root>/FC_Matrices`` and caches the collated
    graphs under ``<root>/processed/data.pt``.
    """

    def __init__(self, root, threshold=0.2, label_csv='TADPOLE_TEMPORAL.csv',
                 var_name='arr_0', transform=None, pre_transform=None):
        """
        Args:
            root: Data directory (contains FC_Matrices/ and the label CSV).
            threshold: Absolute-correlation cutoff for keeping edges.
            label_csv: Label file name inside ``root``.
            var_name: Key of the matrix array inside each .npz file.
            transform, pre_transform: Standard PyG dataset hooks.
        """
        # Configuration must be assigned before super().__init__, which may
        # trigger process() on the first run.
        self.threshold = threshold
        self.label_csv = label_csv
        self.var_name = var_name
        self.subject_graph_dict = None
        super().__init__(root, transform, pre_transform)
        # NOTE(review): weights_only=False unpickles arbitrary objects —
        # only load processed files you generated yourself.
        self.data, self.slices = torch.load(self.processed_paths[0],
                                            weights_only=False)
Parameters :
root: Data directory path (contains FC_Matrices/ and TADPOLE_TEMPORAL.csv)
threshold: Edge correlation threshold (default: 0.2)
label_csv: Label file name (default: TADPOLE_TEMPORAL.csv)
var_name: NPZ array key name (default: arr_0)
Usage Example
from FC_ADNIDataset import FC_ADNIDataset

# Load dataset
dataset = FC_ADNIDataset(
    root='data/',
    threshold=0.2,
    label_csv='TADPOLE_TEMPORAL.csv',
    var_name='arr_0',
)

# Access graphs
print(f"Total graphs: {len(dataset)}")
first_graph = dataset[0]
print(f"Nodes: {first_graph.x.shape[0]}")
print(f"Edges: {first_graph.edge_index.shape[1]}")
print(f"Label: {first_graph.y.item()}")
print(f"Subject ID: {first_graph.subj_id}")
print(f"Visit: {first_graph.visit_code}")
print(f"Months from baseline: {first_graph.visit_months}")
FC Matrix Processing
Loading FC Matrices
From FC_ADNIDataset.py:84-115:
Scan directory
base_path = os.path.join( self .root, 'FC_Matrices' )
npz_files = [f for f in os.listdir(base_path) if f.endswith( '_fc_matrix.npz' )]
Searches for all .npz files ending with _fc_matrix.npz
Parse filename
# Format: sub-XXXXXX_run-XX_fc_matrix.npz
base_name = filename.replace( '_fc_matrix.npz' , '' )
parts = base_name.split( '_run-' )
subj_id = parts[ 0 ].replace( 'sub-' , '' )
run_num = parts[ 1 ] if len (parts) > 1 else '01'
Extracts subject ID and run number from filename
Load and symmetrize
data = np.load(file_path)
fc_matrix = data[ self .var_name]
fc_matrix = (fc_matrix + fc_matrix.T) / 2 # Ensure symmetry
Loads matrix and enforces symmetry
Convert to graph
full_id = f"{subj_id}_run{run_num}"
graph = self .fc_to_graph(fc_matrix, subj_id = full_id)
fc_graphs[full_id] = graph
Creates graph representation with subject ID
Graph Conversion
From FC_ADNIDataset.py:66-82:
def fc_to_graph(self, matrix, node_features=None, subj_id=None):
    """Convert an N×N FC matrix into a PyG ``Data`` graph.

    Edges with |correlation| below ``self.threshold`` are dropped; the
    surviving correlations become (E × 1) edge attributes. Node features
    default to a one-hot (identity) encoding, one row per ROI.

    Args:
        matrix: Square numpy FC matrix (assumed symmetric by the caller).
        node_features: Optional (N × F) array; identity matrix if None.
        subj_id: Optional subject identifier stored on the graph.

    Returns:
        torch_geometric.data.Data with x, edge_index, edge_attr
        (and subj_id when provided).
    """
    A = matrix.copy()
    N = A.shape[0]
    # Threshold edges: zero out weak connections so the graph stays sparse.
    A[np.abs(A) < self.threshold] = 0
    # Extract remaining (row, col) pairs as the edge list; note the FC
    # diagonal (self-correlation) also survives thresholding as self-loops.
    edge_index = np.array(np.nonzero(A))
    edge_attr = A[edge_index[0], edge_index[1]]
    # Convert to PyTorch tensors.
    edge_index = torch.tensor(edge_index, dtype=torch.long)
    edge_attr = torch.tensor(edge_attr, dtype=torch.float).unsqueeze(1)
    # Node features: identity matrix (one-hot per ROI) unless provided.
    if node_features is None:
        x = torch.eye(N, dtype=torch.float)
    else:
        x = torch.tensor(node_features, dtype=torch.float)
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    if subj_id is not None:
        data.subj_id = subj_id
    return data
Process :
Thresholding : Remove weak connections (|correlation| < threshold)
Edge extraction : Create edge list from remaining connections
Edge attributes : Correlation coefficients as edge weights
Node features : Identity matrix (one-hot encoding) by default
Graph Properties
Each graph has the following attributes:
| Attribute | Type | Description | Set at |
|---|---|---|---|
| `x` | Tensor | Node feature matrix (N × N identity) | `fc_to_graph` |
| `edge_index` | Tensor | Edge list (2 × E) | `fc_to_graph` |
| `edge_attr` | Tensor | Edge weights/correlations (E × 1) | `fc_to_graph` |
| `y` | Tensor | Subject label (0 or 1) | `process` |
| `subj_id` | string | Full subject ID (e.g., `123456_run01`) | `process` |
| `visit_code` | string | Visit code (e.g., `bl`, `m06`) | `process` |
| `visit_months` | float | Months from baseline | `process` |
| `months_to_next` | float | Months to next visit | `process` |
Label Assignment
From FC_ADNIDataset.py:34-61:
# NOTE(review): excerpt — indentation was lost in extraction and the
# '... fallback logic ...' plus the final collate/save step are elided here.
# Attaches a subject-level label and per-visit metadata to every FC graph.
def process ( self ):
base_path = os.path.join( self .root, 'FC_Matrices' )
label_path = os.path.join( self .root, self .label_csv)
# Load labels and visit information
label_dict, visit_dict = self .load_subject_labels_and_visits(label_path)
fc_graphs = self .load_fc_graphs(base_path)
data_list = []
for subj_id, graph in fc_graphs.items():
# Strip the run suffix: labels are keyed by bare subject ID.
base_id = subj_id.split( '_run' )[ 0 ]
# Unlabeled subjects default to class 0 (stable).
graph.y = torch.tensor([label_dict.get(base_id, 0 )], dtype = torch.long)
graph.subj_id = subj_id
# Add visit information if available
if subj_id in visit_dict:
visit_info = visit_dict[subj_id]
graph.visit_code = visit_info.get( 'visit_code' , 'unknown' )
graph.visit_months = visit_info.get( 'visit_months' , 0 )
# -1 marks "no next visit" downstream.
graph.months_to_next = visit_info.get( 'months_to_next' , - 1 )
# ... fallback logic ...
data_list.append(graph)
Labels are assigned at the subject level (all visits for a subject get the same label from their last chronological visit).
From FC_ADNIDataset.py:123-158:
# NOTE(review): excerpt — indentation was lost in extraction and the final
# return (presumably `return label_dict, visit_dict`, per the caller in
# process()) is not shown here.
# Builds (a) a subject -> label dict using each subject's LAST visit's label
# and (b) a "<subject>_runNN" -> visit-metadata dict in chronological order.
def load_subject_labels_and_visits ( self , label_csv_path , label_col = 'Label_CS_Num' ):
df = pd.read_csv(label_csv_path)
# Normalize subject IDs: drop underscores so they match filename-derived IDs.
df[ 'Subject' ] = df[ 'Subject' ].str.replace( '_' , '' , regex = False )
# Sort by subject and visit order
df = df.sort_values([ 'Subject' , 'Visit_Order' ])
# Create label dictionary (use last visit's label)
label_dict = {}
for subject in df[ 'Subject' ].unique():
subject_data = df[df[ 'Subject' ] == subject]
label_dict[subject] = subject_data.iloc[ - 1 ][label_col]
# Create visit info dictionary
visit_dict = {}
for subject in df[ 'Subject' ].unique():
subject_data = df[df[ 'Subject' ] == subject].sort_values( 'Visit_Order' )
# Map each run number to chronologically ordered visits
for run_idx, (_, visit_row) in enumerate (subject_data.iterrows()):
run_key = f " { subject } _run { run_idx + 1 :02d} " # run01, run02, etc.
visit_dict[run_key] = {
'visit_code' : visit_row[ 'Visit' ],
'visit_months' : visit_row[ 'Months_From_Baseline' ],
'months_to_next' : visit_row.get( 'Months_To_Next_Original' , - 1 )
}
Key points :
Run numbers are mapped to chronologically ordered visits
run01 → first visit, run02 → second visit, etc.
Visit information includes visit code, months from baseline, and time to next visit
ROI Parcellation
While not explicitly defined in the code, ADNI FC matrices typically use standard brain atlases:
- **AAL Atlas** (116 regions) — Automated Anatomical Labeling atlas; the most common parcellation for ADNI preprocessing.
- **Custom parcellation** (variable size) — depends on the preprocessing pipeline; matrix dimensions are determined by the ROI count.
The model automatically detects matrix dimensions from the data (N × N where N = number of ROIs).
Edge Thresholding
Purpose
From FC_ADNIDataset.py:69:
A[np.abs(A) < self .threshold] = 0
Why threshold?
Removes weak or spurious connections
Creates sparse graphs for efficient GNN processing
Focuses on strong functional relationships
Default Threshold
Default value: 0.2 (an absolute-correlation cutoff — correlation values are unitless, not percentages)
From README: threshold=0.2
Interpretation :
Edges with |correlation| < 0.2 are removed
Both positive and negative correlations are considered (absolute value)
Only moderate to strong connections remain
Adjusting Threshold
# Sparse graph (only strong connections)
dataset_sparse = FC_ADNIDataset(root='data/', threshold=0.5)

# Dense graph (include weak connections)
dataset_dense = FC_ADNIDataset(root='data/', threshold=0.1)

# No thresholding (keep all edges)
dataset_full = FC_ADNIDataset(root='data/', threshold=0.0)
Lower thresholds create denser graphs, which increase memory usage and computation time. Higher thresholds may lose important connectivity information.
Data Statistics
Expected Dataset Size
From the README and typical ADNI usage:
Subjects : ~500-1700 subjects with rs-fMRI data
Total visits : ~2000-5000 individual scans
Multi-visit subjects : ~30-40% have longitudinal data
Average visits per subject : 1.5-3 visits
Memory Requirements
Per graph (116 ROIs, threshold=0.2):
Node features: 116 × 116 × 4 bytes ≈ 54 KB
Edge indices: ~1000 edges × 2 × 8 bytes ≈ 16 KB
Edge attributes: ~1000 edges × 4 bytes ≈ 4 KB
Total per graph : ~74 KB
Full dataset (3000 graphs):
In-memory: ~220 MB
Processed file (data.pt): ~300-400 MB
Example: Loading and Inspecting Data
import torch
from FC_ADNIDataset import FC_ADNIDataset
import numpy as np

# Load dataset
print("Loading ADNI FC dataset...")
dataset = FC_ADNIDataset(root='data/', threshold=0.2)

print("\nDataset Statistics:")
print(f"Total graphs: {len(dataset)}")
print(f"Number of node features: {dataset[0].x.shape[1]}")
print(f"Number of ROIs: {dataset[0].x.shape[0]}")

# Analyze graph structure
edges_per_graph = [data.edge_index.shape[1] for data in dataset]
print("\nGraph Connectivity:")
print(f"Average edges per graph: {np.mean(edges_per_graph):.1f}")
print(f"Min edges: {np.min(edges_per_graph)}")
print(f"Max edges: {np.max(edges_per_graph)}")

# Analyze labels
labels = [data.y.item() for data in dataset]
print("\nLabel Distribution:")
print(f"Stable (0): {labels.count(0)} ({100 * labels.count(0) / len(labels):.1f}%)")
print(f"Converter (1): {labels.count(1)} ({100 * labels.count(1) / len(labels):.1f}%)")

# Analyze temporal information (visit_months may be absent on some graphs)
visit_months = [data.visit_months for data in dataset if hasattr(data, 'visit_months')]
print("\nTemporal Information:")
print(f"Average months from baseline: {np.mean(visit_months):.1f}")
print(f"Max follow-up time: {np.max(visit_months):.1f} months")

# Inspect a single graph
print("\nExample Graph (index 0):")
graph = dataset[0]
print(f"Subject ID: {graph.subj_id}")
print(f"Visit code: {graph.visit_code}")
print(f"Label: {'Converter' if graph.y.item() == 1 else 'Stable'}")
print(f"Nodes: {graph.x.shape[0]}")
print(f"Edges: {graph.edge_index.shape[1]}")
print(f"Edge weight range: [{graph.edge_attr.min():.3f}, {graph.edge_attr.max():.3f}]")
Processed Data Caching
From FC_ADNIDataset.py:22-27:
@property
def processed_file_names(self):
    """File name the processed dataset is cached under (root/processed/)."""
    return ['data.pt']

def download(self):
    """No-op: ADNI data cannot be auto-downloaded; obtain it manually."""
    pass
Caching behavior :
First load: Processes all FC matrices and creates data/processed/data.pt
Subsequent loads: Reads from cached data.pt (much faster)
Re-processing: Delete data/processed/ directory to force reprocessing
If you modify FC matrices or labels, delete the data/processed/ directory to regenerate the processed dataset: