probinet.model_selection.cross_validation

Main function to implement cross-validation for a given number of communities.

  • Hold out part of the dataset (pairs of edges labeled by unordered pairs (i,j), so that (i,j) and (j,i) are held out together);

  • Infer parameters on the training set;

  • Calculate performance measures on the test set (AUC).
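The three steps above can be sketched as a small NumPy loop. This is a minimal illustration under stated assumptions, not probinet's implementation. The helpers `pair_folds` and `auc` are hypothetical. The toy graph is random. The "inference" step is a placeholder that draws random scores where a real model would predict edge probabilities. AUC is computed here as the probability that a held-out edge outranks a held-out non-edge, with ties counting one half.

```python
import numpy as np


def pair_folds(n_nodes, n_folds, seed=0):
    """Hypothetical helper: shuffle unordered pairs (i, j), i < j, into n_folds near-equal groups."""
    iu, ju = np.triu_indices(n_nodes, k=1)
    order = np.random.default_rng(seed).permutation(iu.size)
    return [(iu[idx], ju[idx]) for idx in np.array_split(order, n_folds)]


def auc(scores, labels):
    """Probability that a random positive outranks a random negative (ties count 1/2)."""
    pos, neg = scores[labels == 1], scores[labels == 0]
    if pos.size == 0 or neg.size == 0:
        return float("nan")  # AUC is undefined without both classes
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (pos.size * neg.size)


# Toy undirected graph (illustrative data only).
rng = np.random.default_rng(1)
n = 20
A = np.triu((rng.random((n, n)) < 0.2).astype(int), k=1)
A = A + A.T

for fold, (i, j) in enumerate(pair_folds(n, n_folds=5)):
    # Step 1: hide both (i, j) and (j, i) for the held-out unordered pairs.
    A_train = A.copy()
    A_train[i, j] = 0
    A_train[j, i] = 0
    # Step 2 (placeholder): a real model would be fit on A_train here.
    scores = rng.random((n, n))
    # Step 3: score the held-out pairs against their true labels.
    print(fold, auc(scores[i, j], A[i, j]))
```

Splitting with `np.array_split` keeps the folds within one pair of each other in size. Each unordered pair lands in exactly one fold.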

Classes

CrossValidation(algorithm, model_parameters, ...)

Abstract class to implement cross-validation for a given algorithm.

class probinet.model_selection.cross_validation.CrossValidation(algorithm, model_parameters, cv_parameters, numerical_parameters=None)

Abstract class to implement cross-validation for a given algorithm.

static define_grid(**kwargs)

Define the grid of parameters to be tested.
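A common way to build such a grid from keyword arguments is a Cartesian product over lists of candidate values. The sketch below assumes that each keyword maps to a list and that the grid is returned as a list of dictionaries. The real `define_grid` may use a different return type. `make_grid` is a hypothetical name, and `K` and `rseed` are illustrative parameter names.

```python
from itertools import product


def make_grid(**kwargs):
    """Hypothetical stand-in for define_grid: every combination of the candidate values."""
    keys = list(kwargs)
    # product() varies the last key fastest, so combinations come out in a stable order.
    return [dict(zip(keys, combo)) for combo in product(*(kwargs[k] for k in keys))]


grid = make_grid(K=[2, 3, 4], rseed=[0, 1])
for params in grid:
    print(params)
```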

abstract extract_mask(fold)

Extract the mask for the current fold.
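Because `extract_mask` is abstract, each concrete subclass decides which entries are held out in a given fold. The sketch below shows that pattern with Python's `abc` module. The class names, the constructor arguments, and the random assignment of each unordered pair to a fold are assumptions made for illustration. They are not probinet's API.

```python
from abc import ABC, abstractmethod

import numpy as np


class FoldMaskBase(ABC):
    """Hypothetical base class: subclasses decide which entries are held out in each fold."""

    def __init__(self, n_nodes: int, n_folds: int, seed: int = 0):
        self.n_nodes = n_nodes
        self.n_folds = n_folds
        self.rng = np.random.default_rng(seed)

    @abstractmethod
    def extract_mask(self, fold: int) -> np.ndarray:
        """Return a boolean (n_nodes, n_nodes) array that is True on held-out entries."""


class UnorderedPairCV(FoldMaskBase):
    """Hypothetical subclass: assigns each unordered pair to a random fold; (i, j) and (j, i) go together."""

    def __init__(self, n_nodes, n_folds, seed=0):
        super().__init__(n_nodes, n_folds, seed)
        self.iu, self.ju = np.triu_indices(n_nodes, k=1)
        self.fold_of_pair = self.rng.integers(0, n_folds, size=self.iu.size)

    def extract_mask(self, fold):
        sel = self.fold_of_pair == fold
        mask = np.zeros((self.n_nodes, self.n_nodes), dtype=bool)
        mask[self.iu[sel], self.ju[sel]] = True
        mask[self.ju[sel], self.iu[sel]] = True  # mirror so the mask stays symmetric
        return mask


cv = UnorderedPairCV(n_nodes=10, n_folds=5)
masks = [cv.extract_mask(f) for f in range(cv.n_folds)]
```

Trying to instantiate `FoldMaskBase` directly raises `TypeError`, because it does not implement `extract_mask`.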

load_data()

Auxiliary method to load data from the input folder.

prepare_and_run(mask: ndarray)

Prepare the data for training and run the algorithm.

Parameters:

mask (np.ndarray) – The mask to apply to the data.

Returns:

  • tuple – The outputs of the algorithm.

  • object – The algorithm object.
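The documented contract takes a mask, prepares the training data, runs the algorithm, and returns both the outputs tuple and the algorithm object. It can be illustrated with a toy model. In this sketch, `DensityModel` and its `fit(A_train, observed)` method are hypothetical stand-ins for a real probinet algorithm. The mask is assumed to be True on held-out entries.

```python
import numpy as np


class DensityModel:
    """Hypothetical toy algorithm: every pair gets the edge density seen in training."""

    def fit(self, A_train, observed):
        self.p_ = float(A_train[observed].mean())
        return (self.p_,)  # tuple of outputs, mirroring the documented return shape


def prepare_and_run(A, mask, algorithm):
    """Hide masked entries, fit the algorithm, and return (outputs, algorithm)."""
    observed = ~mask  # entries the model is allowed to see
    np.fill_diagonal(observed, False)  # ignore self-loops
    A_train = np.where(mask, 0, A)  # held-out entries are zeroed so the model cannot read them
    outputs = algorithm.fit(A_train, observed)
    return outputs, algorithm


A = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]])
mask = np.zeros_like(A, dtype=bool)
mask[0, 1] = mask[1, 0] = True  # hold out the unordered pair (0, 1)
outputs, model = prepare_and_run(A, mask, DensityModel())
```

Returning the fitted object alongside its outputs lets the caller score the held-out entries afterwards without refitting.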

prepare_file_name()

prepare_output_directory()

Prepare the output directory to save the results.
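A typical way to prepare an output directory is shown below with `pathlib`. It creates any missing parent folders and does not fail if the directory already exists. This is a sketch under that assumption. The function signature and the `cv/results` path are illustrative only and do not come from probinet.

```python
import tempfile
from pathlib import Path


def prepare_output_directory(out_folder) -> Path:
    """Hypothetical helper: create the folder (and parents) if missing, then return it."""
    out = Path(out_folder)
    out.mkdir(parents=True, exist_ok=True)  # no error if it already exists
    return out


with tempfile.TemporaryDirectory() as tmp:
    out = prepare_output_directory(Path(tmp) / "cv" / "results")
    out_is_dir = out.is_dir()
    again = prepare_output_directory(out)  # calling twice does not raise
```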

run_single_iteration()

Run the cross-validation procedure.

save_results()