run_apd_hierarchical_cluster.py 9.96 KB
Newer Older
Jiri Borovec committed
1 2 3 4 5 6 7 8
"""
Perform the hierarchical clustering on estimated atlas such as merging patterns
together with the smallest reconstruction error.
We assume going from all patterns presented to a single pattern in the atlas

EXAMPLE:
>> python run_apd_hierarchical_cluster.py \
    --path_in /datagrid/Medical/microscopy/drosophila/TEMPORARY/experiments_APDL_real
9 10 11 12

>> python run_apd_hierarchical_cluster.py \
    --path_in ~/Medical-data/microscopy/drosophila/RESULTS/experiments_APD_real

Jiri Borovec committed
13 14 15 16 17 18 19 20 21 22 23 24
>> python run_apd_hierarchical_cluster.py \
    --path_in /datagrid/Medical/microscopy/drosophila/TEMPORARY/experiments_APD_temp \
    --names_expt ExperimentALPE_mp_real_type_3_segm_reg_binary_gene_ssmall_20160509-155333 \
    --nb_jobs 1

Copyright (C) 2015-2016 Jiri Borovec <jiri.borovec@fel.cvut.cz>
"""

import os
import glob
import logging
import gc
Jiri Borovec committed
25
import sys
Jiri Borovec committed
26 27 28 29 30 31 32
import time
import multiprocessing as mproc
from functools import partial

import tqdm
import numpy as np
import pandas as pd
33
import matplotlib.pyplot as plt
Jiri Borovec committed
34 35 36

import dataset_utils as gen_data
import run_apd_reconstruction as r_reconst
Jiri Borovec committed
37 38
sys.path.append(os.path.abspath(os.path.join('..','..')))  # Add path to root
import src.segmentation.tool_superpixels as tl_spx
Jiri Borovec committed
39

Jiri Borovec committed
40
CONNECT_PATTERNS = True
Jiri Borovec committed
41 42 43 44 45 46
NB_THREADS = int(mproc.cpu_count() * .9)
NAME_CONFIG = 'config.json'
PREFIX_ATLAS = 'atlas_'
PREFIX_ENCODE = 'encoding_'
PREFIX_RECONST = 'reconstruct_'
CSV_RECONT_DIFF = 'reconstruct_hierarchical_clustering.csv'
47
FIG_RECONT_DIFF = 'reconstruct_hierarchical_clustering.pdf'
Jiri Borovec committed
48 49
POSIX_CSV_SKIP = r_reconst.POSIX_CSV_NEW
DIR_PREFIX = 'hierarchical_clustering_'
50
POSIX_MERGED = 'merged_nb_labels_%i'
Jiri Borovec committed
51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76


def compute_merged_reconst_diff(ptn_comb, dict_params, path_out, atlas, img_names):
    """ merge pattern pair and compute reconstruction diff

    :param ptn_comb: (int, int)
    :param dict_params: {str: int}
    :param path_out: str
    :param atlas: np.array<height, width>
    :param img_names: [str]
    :return: (int, int), float
    """
    atlas_merge = atlas.copy()
    # merge pattern pair
    for lb in ptn_comb[1:]:
        atlas_merge[atlas_merge == lb] = ptn_comb[0]
    wrapper_reconstruction = partial(r_reconst.compute_reconstruction,
                                     dict_params=dict_params, path_out=path_out,
                                     im_atlas=atlas_merge)
    tuples_name_diff = map(wrapper_reconstruction, img_names)
    # compute mean diff over all reconst
    diff = np.mean(np.asarray(tuples_name_diff)[:, 1].astype(np.float))
    return ptn_comb, diff


def hierarchical_clustering_merge_patterns(dict_params, path_out, img_names,
Jiri Borovec committed
77
                                           atlas, nb_jobs=NB_THREADS, connect_ptns=CONNECT_PATTERNS):
Jiri Borovec committed
78 79 80 81 82 83 84 85 86 87 88
    """ using hierarchical clustering merge pattern pair and return partial results

    :param dict_params: {str: ...}
    :param path_out: str
    :param img_names: [str]
    :param atlas: np.array<height, width>
    :param nb_jobs: int
    :return: np.array<height, width>, (int, int), float
    """
    labels = sorted(np.unique(atlas).tolist())
    # generate combinations as list as skipping the 0 assuming on first position
Jiri Borovec committed
89 90 91 92 93
    if connect_ptns:
        _, ptn_combines = tl_spx.make_graph_segm_connect2d_conn4(atlas)
    else:
        ptn_combines = [(labels[i], labels[j])
                        for i in range(1, len(labels)) for j in range(1, i)]
Jiri Borovec committed
94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
    assert len(ptn_combines) > 0 and \
           not any(len(set(ptn)) == 1 for ptn in ptn_combines)
    # parallel compute reconstructions
    wrapper_compute_merged = partial(compute_merged_reconst_diff,
                                     dict_params=dict_params, path_out=path_out,
                                     atlas=atlas, img_names=img_names)
    if nb_jobs > 1:
        mproc_pool = mproc.Pool(nb_jobs)
        tuples_ptn_diff = mproc_pool.map(wrapper_compute_merged, ptn_combines)
        mproc_pool.close()
        mproc_pool.join()
    else:
        tuples_ptn_diff = map(wrapper_compute_merged, ptn_combines)
    logging.debug('computed merged diffs: %s', repr(tuples_ptn_diff))
    idx_min = np.argmin(tuples_ptn_diff, axis=0)[1]
    ptn_comb, diff = tuples_ptn_diff[idx_min]
    logging.debug('found minimal pn pos %i for diff %f and patterns %s',
                  idx_min, diff, repr(ptn_comb))
    atlas_merged = atlas.copy()
    for lb in ptn_comb[1:]:
        atlas_merged[atlas_merged == lb] = ptn_comb[0]
    return atlas_merged, ptn_comb, diff


def export_partial_atlas_encode(dict_params, path_out, df_merged, max_label,
                                nb, atlas, ptn_comb, diff):
    """ export partial results such as atlas, encoding and reconstruct diff

    :param dict_params: {str: ...}
    :param path_out: str
    :param df_merged: DF
    :param max_label: int
    :param nb: int
    :param atlas: np.array<height, width>
    :param ptn_comb: (int, int)
    :param diff: float
    :return: DF
    """
132
    gen_data.export_image(path_out, atlas, PREFIX_ATLAS + POSIX_MERGED % nb)
Jiri Borovec committed
133
    r_reconst.export_fig_atlas(atlas, path_out,
134
                               PREFIX_ATLAS + POSIX_MERGED % nb, max_label)
Jiri Borovec committed
135
    df_encode = r_reconst.recompute_encoding(dict_params, atlas)
136
    df_encode.to_csv(os.path.join(path_out, PREFIX_ENCODE + POSIX_MERGED % nb + '.csv'))
Jiri Borovec committed
137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
    df_merged = df_merged.append({
        'nb_labels': nb,
        'merged': ptn_comb,
        'reconst_diff': diff}, ignore_index=True)
    return df_merged


def sequence_hierarchical_clustering(dict_params, path_out, img_names, atlas,
                                     nb_jobs=NB_THREADS):
    """ sequance if hierarchical clustering which decrease number of patterns
    by partial merging pattern pairs and exporting partial results

    :param dict_params: {str, ...}
    :param path_out: str
    :param img_names: [str]
    :param atlas: np.array<height, width>
    :param nb_jobs:
    :return: DF
    """
    if not os.path.exists(path_out):
        os.mkdir(path_out)
    nb_labels = len(np.unique(atlas))
    max_label = atlas.max()
    df_merged = pd.DataFrame()
    ptn_comb, diff = compute_merged_reconst_diff((0, 0), dict_params, path_out,
                                                 atlas, img_names)
    df_merged = export_partial_atlas_encode(dict_params, path_out, df_merged,
                                            max_label, nb_labels, atlas, ptn_comb, diff)
    # recursively merge patterns
    for nb in reversed(range(2, nb_labels)):
        atlas, ptn_comb, diff = hierarchical_clustering_merge_patterns(
                            dict_params, path_out, img_names, atlas, nb_jobs)
        df_merged = export_partial_atlas_encode(dict_params, path_out, df_merged,
                                                max_label, nb, atlas, ptn_comb, diff)
171
        df_merged.to_csv(os.path.join(path_out, CSV_RECONT_DIFF))
Jiri Borovec committed
172 173 174 175 176
    df_merged.set_index('nb_labels', inplace=True)
    df_merged.to_csv(os.path.join(path_out, CSV_RECONT_DIFF))
    return df_merged


177
def plot_reconst_diff(path_csv):
Jiri Borovec committed
178
    """ plot the reconst. diff from PD and expoert to a figure
179 180 181 182 183 184 185 186

    :param path_csv: str
    """
    assert os.path.exists(path_csv)
    df_diff = pd.DataFrame.from_csv(path_csv)
    fig = plt.figure()
    df_diff.plot(ax=fig.gca(), title='Reconstruction diff. for estimated atlases')
    fig.gca().set_ylabel('reconst. diff [%]')
Jiri Borovec committed
187
    fig.gca().grid()
188 189 190 191
    path_fig = os.path.join(os.path.dirname(path_csv), FIG_RECONT_DIFF)
    fig.savefig(path_fig)


Jiri Borovec committed
192 193 194
def process_experiment(path_expt, nb_jobs=NB_THREADS):
    """ process complete folder with experiment

Jiri Borovec committed
195
    :param nb_jobs: int
Jiri Borovec committed
196 197 198 199 200 201 202 203 204 205 206 207 208
    :param path_expt: str
    """
    logging.info('Experiment folder: \n "%s"', path_expt)
    dict_params = r_reconst.load_config_json(path_expt)
    atlas_names = [os.path.basename(p) for p
                   in glob.glob(os.path.join(path_expt, PREFIX_ATLAS + '*.png'))]
    list_csv = [p for p in glob.glob(os.path.join(path_expt, PREFIX_ENCODE + '*.csv'))
                if not p.endswith(POSIX_CSV_SKIP)]
    logging.debug('found %i CSV files: %s', len(list_csv), repr(list_csv))
    df_diffs_all = pd.DataFrame()
    for path_csv in sorted(list_csv):
        name_csv = os.path.basename(path_csv)
        name_atlas = r_reconst.find_relevant_atlas(name_csv, atlas_names)
209 210 211 212
        if name_atlas is None:
            logging.warning('nor related atlas for particular csv encoding "%s"',
                            name_csv)
            continue
Jiri Borovec committed
213 214 215 216 217 218 219 220 221 222 223 224
        logging.info('Atlas: "%s" -> Encoding: "%s"', name_atlas, name_csv)
        path_atlas = os.path.join(path_expt, name_atlas)
        atlas = r_reconst.load_atlas_image(path_atlas)
        img_names = pd.DataFrame.from_csv(path_csv).index.tolist()
        path_out = os.path.join(path_expt, DIR_PREFIX + os.path.splitext(name_atlas)[0])
        df_diff = sequence_hierarchical_clustering(dict_params, path_out,
                                                   img_names, atlas, nb_jobs)
        # separet jut the recont dif and name it after atlas
        df_diff = df_diff['reconst_diff']
        df_diff.name = os.path.splitext(name_atlas)[0]
        logging.debug('records: %i for "%s"', len(df_diff), df_diff.name)
        df_diffs_all = pd.concat([df_diffs_all, df_diff], axis=1)
225 226
        df_diffs_all.to_csv(os.path.join(path_expt, CSV_RECONT_DIFF))
    logging.info('processed files: %s', repr(df_diffs_all.columns))
Jiri Borovec committed
227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242


def main():
    """ process complete list of experiments """
    logging.basicConfig(level=logging.INFO)
    logging.info('running...')

    arg_params = r_reconst.parse_arg_params(r_reconst.create_args_parser())
    logging.info('PARAMS: \n%s', '\n'.join(['"{}": \n\t {}'.format(k, v)
                                            for k, v in arg_params.iteritems()]))
    list_expt = [os.path.join(arg_params['path_in'], n) for n in arg_params['names_expt']]
    assert len(list_expt) > 0, 'No experiments found!'

    tqdm_bar = tqdm.tqdm(total=len(list_expt))
    for path_expt in list_expt:
        process_experiment(path_expt, arg_params['nb_jobs'])
243
        plot_reconst_diff(os.path.join(path_expt, CSV_RECONT_DIFF))
Jiri Borovec committed
244 245 246 247 248 249 250 251
        gc.collect(), time.sleep(1)
        tqdm_bar.update(1)

    logging.info('DONE')


if __name__ == '__main__':
    main()