# Source code for conn2res.connectivity

"""
Functionality for connectivity matrix
"""
import os
import numpy as np
import warnings
from scipy.linalg import eigh
from bct.algorithms.clustering import get_components
from bct.algorithms.distance import distance_bin

from .utils import *


# Project root: two directory levels above this module's absolute path.
PROJ_DIR = os.path.abspath(
    os.path.join(os.path.abspath(__file__), os.pardir, os.pardir)
)
# Folder that `load_file` reads the bundled example arrays from.
DATA_DIR = os.path.join(PROJ_DIR, *('examples', 'data'))


class Conn:
    """
    Class that represents a connectivity matrix representing either weighted
    or unweighted connectivity data

    Attributes
    ----------
    w : numpy.ndarray
        connectivity matrix (float), restricted to the currently active
        nodes
    n_nodes : int
        number of active nodes
    n_edges : int
        number of non-zero entries of `w` (in symmetric networks every
        edge is therefore counted twice)
    symmetric : bool
        result of `check_symmetric(w)` evaluated once in `__init__`; it is
        not re-evaluated by the other methods
    density : float
        `n_edges / (n_nodes * (n_nodes - 1))`
    idx_node : numpy.ndarray of bool
        mask over the nodes of the initial matrix marking the nodes that
        are still active
    modules : object
        module assignment passed to the constructor (stored as is)

    Methods
    ----------
    scale_and_normalize, scale, normalize, binarize, add_weights,
    subset_nodes, get_nodes
    """

    def __init__(self, filename=None, subj_id=None, w=None, modules=None,
                 density=None):
        """
        Build the connectivity object from an array or from a file

        Parameters
        ----------
        filename : str, optional
            path of a `.npy` file; only used when `w` is None. When both
            `w` and `filename` are None, 'connectivity.npy' is loaded with
            `load_file`
        subj_id : int, optional
            index along the third axis used to pick one subject when the
            connectivity data is 3-dimensional
        w : array-like, optional
            connectivity data; it is copied, the caller's array is never
            modified
        modules : object, optional
            module assignment, stored unchanged in `self.modules`
        density : float, optional
            when given, only the strongest edges are kept so that the
            matrix has (approximately) this density
        """
        if w is not None:
            # assign provided connectivity data
            data = np.asarray(w)
        elif filename is not None:
            # load connectivity data
            data = np.load(filename)
        else:
            data = load_file('connectivity.npy')

        # select one subject from group connectivity data
        if subj_id is not None and data.ndim == 3:
            data = data[:, :, subj_id]

        # work on a private float copy so that the caller's array is not
        # altered by the in-place clean-up below
        self.w = np.array(data, dtype=float)

        # set zero diagonal
        np.fill_diagonal(self.w, 0)

        # remove inf and nan
        self.w[~np.isfinite(self.w)] = 0

        # number of all active nodes
        self.n_nodes = len(self.w)

        # check if network is symmetric (needed e.g. for checking
        # connectedness)
        self.symmetric = check_symmetric(self.w)

        # use fixed density if set
        if density is not None:
            n_pairs = self.n_nodes * (self.n_nodes - 1)
            if self.symmetric:
                # rank only the upper triangle, every edge counted once
                nedges = int(n_pairs * density // 2)
                id_ = np.argsort(np.triu(self.w, 1), axis=None)
            else:
                nedges = int(n_pairs * density)
                id_ = np.argsort(self.w, axis=None)

            # zero everything but the `nedges` largest entries; an explicit
            # stop index is used because `id_[:-0]` would select nothing
            n_drop = max(id_.size - nedges, 0)
            self.w[np.unravel_index(id_[:n_drop], self.w.shape)] = 0

            if self.symmetric:
                # presumably mirrors the kept upper triangle into the lower
                # one -- see `make_symmetric` in utils
                self.w = make_symmetric(self.w, copy_lower=False)

        # number of edges (in symmetric networks edges are counted twice!)
        self.n_edges = np.sum(self.w != 0)

        # density of network
        self.density = self.n_edges / (self.n_nodes * (self.n_nodes - 1))

        # indexes of set of active nodes
        self.idx_node = np.full(self.n_nodes, True)

        # make sure that all nodes are connected to the rest of the network
        self.subset_nodes(idx_node=np.logical_or(
            np.any(self.w != 0, axis=0), np.any(self.w != 0, axis=1)))

        # assign modules
        self.modules = modules

    def scale_and_normalize(self):
        """
        Scales the connectivity matrix between [0, 1] and divides by spectral
        radius
        """
        # scale connectivity matrix between [0, 1]
        self.scale()

        # divide connectivity matrix by spectral radius
        self.normalize()

    def scale(self):
        """
        Scales the connectivity matrix between [0, 1]

        Min-max scaling over all entries of `w`; a constant matrix leads to
        a division by zero.
        """
        w_min = self.w.min()
        self.w = (self.w - w_min) / (self.w.max() - w_min)

    def normalize(self):
        """
        Normalizes the connectivity matrix with spectral radius

        The spectral radius is the largest absolute eigenvalue of `w`.
        """
        if np.allclose(self.w, self.w.T):
            # symmetric matrix: the symmetric eigensolver is valid
            ew = eigh(self.w, eigvals_only=True)
        else:
            # `eigh` reads one triangle only, which would give the wrong
            # spectrum for an asymmetric matrix
            ew = np.linalg.eigvals(self.w)

        # divide connectivity matrix by spectral radius
        self.w = self.w / np.abs(ew).max()

    def binarize(self):
        """
        Binarizes the connectivity matrix

        Every non-zero entry becomes 1.0, every zero entry stays 0.0.
        """
        self.w = self.w.astype(bool).astype(float)

    def add_weights(self, w, mask='triu', order='random'):
        """
        Add weights to either a binary or weighted connectivity matrix

        Parameters
        ----------
        w : array-like
            new weights, one per existing edge of the selected mask
        mask : {'triu', 'full'}
            'full' replaces every non-zero entry of the matrix; 'triu'
            replaces the non-zero entries above the diagonal and then calls
            `make_symmetric(..., copy_lower=False)`
        order : {'random', 'absrank', 'rank'}
            'random' assigns `w` in the given order to the non-zero
            entries; 'rank' gives the largest new weight to the largest
            existing entry and so on; 'absrank' does the same on absolute
            values

        Raises
        ------
        ValueError
            if `mask` or `order` is unknown, if `mask='triu'` is used on a
            matrix that was not symmetric at construction, or if the number
            of weights does not match the number of edges
        """
        w = np.asarray(w)

        if mask not in ('full', 'triu'):
            raise ValueError("mask must be either 'full' or 'triu'")
        if order not in ('random', 'absrank', 'rank'):
            raise ValueError(
                "order must be one of 'random', 'absrank' or 'rank'")

        if mask == 'full':
            if w.size != self.n_edges:
                raise ValueError(
                    'number of elements in mask and w do not match')
            target = self.w
        else:
            if not self.symmetric:
                raise ValueError(
                    "add_weights(w, mask='triu') needs a symmetric "
                    "connectivity matrix")
            target = np.triu(self.w, 1)
            if w.size != np.sum(target != 0):
                raise ValueError(
                    'number of elements in mask and w do not match')

        if order == 'random':
            self.w[target != 0] = w
        else:
            # 'absrank' ranks on magnitudes, 'rank' on signed values
            key = np.abs if order == 'absrank' else np.asarray

            # new weights from largest to smallest
            w = w[np.argsort(key(w))[::-1]]

            # positions of the `w.size` largest existing entries, also from
            # largest to smallest
            id_ = np.argsort(key(target), axis=None)
            self.w[np.unravel_index(id_[:-w.size - 1:-1], self.w.shape)] = w

        if mask == 'triu':
            # copy weights to lower diagonal
            self.w = make_symmetric(self.w, copy_lower=False)

    def subset_nodes(self, node_set='all', idx_node=None, **kwargs):
        """
        Defines subset of nodes of the connectivity matrix and reduces the
        connectivity matrix to this subset

        Parameters
        ----------
        node_set : str
            forwarded to `get_nodes` when `idx_node` is None
        idx_node : numpy.ndarray of bool, optional
            mask over the currently active nodes; when given, `node_set`
            and `kwargs` are ignored
        **kwargs
            forwarded to `get_nodes`

        Notes
        -----
        After the reduction only the largest connected component is kept.
        For an asymmetric matrix the components are computed on
        `w OR w.T` and a warning is emitted.
        """
        # get nodes
        if idx_node is None:
            idx_node = np.isin(np.arange(self.n_nodes),
                               self.get_nodes(node_set, **kwargs))

        # update class attributes
        self._update_attributes(idx_node)

        # update component
        if self.symmetric:
            self._get_largest_component(self.w)
        else:
            self._get_largest_component(np.logical_or(self.w, self.w.T))
            warnings.warn("Asymmetric connectivity matrix is only weakly checked for connectedness.")

    def get_nodes(self, node_set, nodes_from=None, nodes_without=None,
                  filename=None, n_nodes=1, seed=None, **kwargs):
        """
        Gets a set of nodes of the connectivity matrix without changing the
        connectivity matrix itself

        Parameters
        ----------
        node_set : str or module id
            'all', 'ctx', 'subctx', 'random', 'shortest_path', or any other
            value, which is then looked up as a module id in the
            resting-state network mapping
        nodes_from : array-like of int, optional
            nodes to select from; defaults to all active nodes
        nodes_without : array-like of int, optional
            nodes that must not be selected
        filename : str, optional
            `.npy` file with the cortex labels ('ctx'/'subctx') or the
            module mapping; defaults to 'cortical.npy' and
            'rsn_mapping.npy' loaded with `load_file`
        n_nodes : int
            number of nodes for 'random' and 'shortest_path'
        seed : int, optional
            seed of the random number generator used by 'random'
        **kwargs
            'shortest_path' must be given for node_set='shortest_path':
            a string selects the node pairs with the largest distance, an
            integer the pairs with exactly that distance

        Returns
        -------
        selected_nodes : numpy.ndarray
            indexes of the selected nodes

        Raises
        ------
        ValueError
            if not enough nodes exist for the requested shortest path or
            if `node_set` is not a known module id
        TypeError
            if kwargs['shortest_path'] is neither a string nor an integer
        """
        # initialize fuller set of nodes we want to select from
        if nodes_from is None:
            nodes_from = np.arange(self.n_nodes)
        else:
            nodes_from = np.asarray(nodes_from)

        # "nothing excluded" is made explicit instead of handing None to
        # the set operations below
        if nodes_without is None:
            nodes_without = np.array([], dtype=int)

        if node_set == 'all':
            # select all nodes without the ones we do not want to select from
            selected_nodes = np.setdiff1d(nodes_from, nodes_without)

        elif node_set in ['ctx', 'subctx']:
            # load cortex and filter to active nodes
            if filename is not None:
                ctx = np.load(filename)
            else:
                ctx = load_file('cortical.npy')
            ctx = ctx[self.idx_node]

            # label 1 marks cortex, label 0 subcortex
            label = 1 if node_set == 'ctx' else 0

            # map back to node indexes; the bare position inside
            # `nodes_from` is only correct when `nodes_from` is the default
            selected_nodes = nodes_from[ctx[nodes_from] == label]

            # remove nodes we do not want to select from
            selected_nodes = np.setdiff1d(selected_nodes, nodes_without)

        elif node_set == 'random':
            # nodes we want to select from
            nodes_from = np.setdiff1d(nodes_from, nodes_without)

            # use random number generator for reproducibility
            rng = np.random.default_rng(seed=seed)

            # select random nodes
            selected_nodes = rng.choice(nodes_from, size=n_nodes,
                                        replace=False)

        elif node_set == 'shortest_path':
            # calculate shortest paths between all nodes
            D = distance_bin(self.w)
            D = np.triu(D)  # remove repetitions

            # nodes we want to select from
            nodes_from = np.setdiff1d(nodes_from, nodes_without)

            # shortest paths between all nodes of interest
            D = D[np.ix_(nodes_from, nodes_from)]

            # select all node pairs with requested shortest path from each
            # other
            requested = kwargs['shortest_path']
            if isinstance(requested, str):
                node_pairs = np.argwhere(D == np.amax(D))
            elif isinstance(requested, (int, np.integer)):
                node_pairs = np.argwhere(D == requested)
            else:
                raise TypeError(
                    'shortest_path must be either a string or an integer')

            # select requested number of nodes from the set above
            if len(np.unique(node_pairs)) >= n_nodes:
                # take as few leading pairs as needed to reach n_nodes
                i = 1
                while len(np.unique(node_pairs[:i, :])) < n_nodes:
                    i += 1
                selected_nodes = nodes_from[np.unique(node_pairs[:i, :])]
            else:
                raise ValueError(
                    'n_nodes do not exist with given shortest_path')

        else:
            # nodes we want to select from
            nodes_from = np.setdiff1d(nodes_from, nodes_without)

            # load resting-state networks and filter to active nodes
            if filename is not None:
                rsn_mapping = np.load(filename)
            else:
                rsn_mapping = load_file('rsn_mapping.npy')
            rsn_mapping = rsn_mapping[self.idx_node]

            # get modules
            module_ids, modules = get_modules(rsn_mapping)

            if node_set in module_ids:
                # select all nodes in the requested module
                position = int(np.flatnonzero(module_ids == node_set)[0])
                selected_nodes = modules[position]

                # intersection of nodes we want to select from
                selected_nodes = np.intersect1d(selected_nodes, nodes_from)
            else:
                raise ValueError('node_set does not exist with given value')

        return selected_nodes

    def _get_largest_component(self, w):
        """
        Updates a set of nodes so that they belong to one connected component

        Parameters
        ----------
        w : numpy.ndarray
            matrix handed to `get_components`; must cover the currently
            active nodes
        """
        # get all components of the connectivity matrix
        comps, comp_sizes = get_components(w)

        # get indexes pertaining to the largest component; the `+ 1`
        # matches component labels that start at 1
        idx_node = comps == np.argmax(comp_sizes) + 1

        # update class attributes
        self._update_attributes(idx_node)

    def _update_attributes(self, idx_node):
        """
        Updates network attributes

        Parameters
        ----------
        idx_node : numpy.ndarray of bool
            mask over the currently active nodes; True keeps the node

        Raises
        ------
        NotImplementedError
            if `idx_node` is not a boolean numpy array
        """
        # `np.bool_` is used because the `np.bool` alias is missing in
        # NumPy 1.24-1.26
        if not (isinstance(idx_node, np.ndarray)
                and idx_node.dtype == np.bool_):
            raise NotImplementedError

        # update node attributes
        self.n_nodes = int(np.count_nonzero(idx_node))
        # only the entries that were still active are overwritten, so
        # `idx_node` keeps referring to the nodes of the initial matrix
        self.idx_node[self.idx_node] = idx_node

        # update edge attributes
        self.w = self.w[np.ix_(idx_node, idx_node)]
        self.n_edges = np.sum(self.w != 0)

        # update density
        self.density = self.n_edges / (self.n_nodes * (self.n_nodes - 1))
def load_file(filename):
    """
    Load a NumPy file from the example data directory

    Parameters
    ----------
    filename : str
        name of the file, relative to `DATA_DIR`

    Returns
    -------
    object
        whatever `numpy.load` returns for that file
    """
    return np.load(os.path.join(DATA_DIR, filename))


def get_modules(module_assignment):
    """
    Group node indexes by module

    Parameters
    ----------
    module_assignment : numpy.ndarray
        module label of every node

    Returns
    -------
    module_ids : numpy.ndarray
        sorted unique module labels
    readout_modules : list of numpy.ndarray
        for every entry of `module_ids`, the indexes of the nodes carrying
        that label
    """
    # get module ids
    module_ids = np.unique(module_assignment)
    readout_modules = [np.where(module_assignment == i)[0]
                       for i in module_ids]

    return module_ids, readout_modules


def get_readout_nodes(readout_modules):
    """
    Return a list with the set(s) of nodes in each module in
    'readout_modules', plus a set of module ids

    Parameters
    ----------
    readout_modules : (N,) list, tuple, numpy.ndarray or dict
        Can be a 1D array-like that assigns modules to each node. Can
        be a list of lists, where each sublist corresponds to the
        indexes of subsets of nodes. Can be a dictionary key:val pairs,
        where the keys correspond to modules and the values correspond
        to list/tuple that contains the subset of nodes in each module.

    Returns
    -------
    readout_nodes : list
        list that contains lists with indexes of subsets of nodes in
        'readout_modules'
    ids : list
        identifier of every entry of `readout_nodes`: the sorted unique
        labels for a 1D assignment, the position of each sublist for a
        list of lists, the keys for a dictionary

    Raises
    ------
    TypeError
        if `readout_modules` is not a list, tuple, numpy.ndarray or dict
    """
    if isinstance(readout_modules, dict):
        ids = list(readout_modules.keys())
        readout_nodes = list(readout_modules.values())

    elif isinstance(readout_modules, (list, tuple, np.ndarray)):
        nested = all(isinstance(i, (list, tuple, np.ndarray))
                     for i in readout_modules)
        if nested:
            # one entry per subset of nodes, identified by its position
            ids = list(range(len(readout_modules)))
            readout_nodes = list(readout_modules)
        else:
            # one label per node; sorted unique labels give a reproducible
            # order, in line with `get_modules`
            labels = np.asarray(readout_modules)
            ids = np.unique(labels).tolist()
            readout_nodes = [np.where(labels == i)[0] for i in ids]

    else:
        raise TypeError(
            'readout_modules must be a list, tuple, numpy.ndarray or dict, '
            f'not {type(readout_modules).__name__}')

    return readout_nodes, ids