dictionary_learning.py 15.6 KB
Newer Older
Jiri Borovec committed
1 2 3 4 5 6 7
"""
The main module for the Atomic Pattern Dictionary, joining the atlas estimation
and the computation of the encoding / weights

Copyright (C) 2015-2016 Jiri Borovec <jiri.borovec@fel.cvut.cz>
"""

Jiri Borovec committed
8
import os
Jiri Borovec committed
9
import logging
Jiri Borovec committed
10

Jiri Borovec committed
11
import numpy as np
Jiri Borovec committed
12 13 14
import matplotlib.pyplot as plt
import skimage.segmentation as sk_image

Jiri Borovec committed
15 16 17
import pattern_disctionary as ptn_dict
import pattern_weights as ptn_weight
import metric_similarity as sim_metric
Jiri Borovec committed
18
import dataset_utils as gen_data
Jiri Borovec committed
19

20 21
# constant unary value written to the background label 0 by
# ``compute_positive_cost_images_weights``
UNARY_BACKGROUND = 1
# number of iterations passed as ``n_iter`` to the general GraphCut solver
NB_GRAPH_CUT_ITER = 5
Jiri Borovec committed
22

Jiri Borovec committed
23 24 25
# TRY: init: spatial clustering
# TRY: init: use ICA
# TRY: init: greedy
# It might rather work with a "greedy" method: take the first image at
# random, or conversely the "most active" one = the one with the largest
# energy. Then subtract a suitably chosen multiple of it from all the other
# images and compute the "residuals". From the residuals again take the one
# with the largest energy, etc. The subtraction can be done algebraically or
# logically; in either case only positive values would be kept.


Jiri Borovec committed
34
def compute_relative_penaly_images_weights(imgs, weights):
    """ compute the relative penalty for all pixels and choosing each label
    on that particular position

    Label 0 is the background and gets weight 0 in every image. For a pixel
    (i, j) and label l the penalty is the mean absolute difference between
    the image values and the (extended) weights:

        cost[i, j, l] = 1/N * sum_k |w_ext[k, l] - img_k[i, j]|

    :param imgs: [np.array<w, h>] list of N input images
    :param weights: np.array<nb_imgs, nb_lbs>
    :return: np.array<w, h, nb_lbs + 1>, index 0 is the background label
    """
    logging.debug('compute unary cost from images and related weights')
    nb_lbs = weights.shape[1] + 1
    assert len(imgs) == weights.shape[0]
    pott_sum = np.zeros(imgs[0].shape + (nb_lbs,))
    # extend the weights by background value 0
    weights_ext = np.append(np.zeros((weights.shape[0], 1)), weights, axis=1)
    imgs = np.array(imgs)
    logging.debug('DIMS potts: %s, imgs %s, w_bin: %s', repr(pott_sum.shape),
                  repr(imgs.shape), repr(weights_ext.shape))
    logging.debug('... walk over all pixels in each image')
    # accumulate image by image using broadcasting <w, h, 1> vs <nb_lbs>
    # instead of a Python loop over pixels; memory stays at one
    # <w, h, nb_lbs> temporary per image
    for img, w_ext in zip(imgs, weights_ext):
        pott_sum += np.abs(w_ext - img[..., np.newaxis])
    pott_sum_norm = pott_sum / float(len(imgs))
    return pott_sum_norm


def compute_positive_cost_images_weights(imgs, weights):
    """ compute the positive (accumulated) unary term for each pixel and label

    Used by ``estimate_atlas_graphcut_simple``. For every image, each pixel
    with value 1 adds one to every label index that
    ``ptn_weight.convert_weights_binary2indexes`` returns for that image.
    Finally the background label 0 is set to the constant ``UNARY_BACKGROUND``.

    :param imgs: [np.array<w, h>]
    :param weights: np.array<nb_imgs, nb_lbs>
    :return: np.array<w, h, nb_lbs + 1>
    """
    logging.debug('compute unary cost from images and related weights')
    w_idx = ptn_weight.convert_weights_binary2indexes(weights)
    nb_lbs = weights.shape[1] + 1
    assert len(imgs) == len(w_idx)
    pott_sum = np.zeros(imgs[0].shape + (nb_lbs,))
    logging.debug('... walk over all pixels in each image')
    for img, lb_idxs in zip(imgs, w_idx):
        # mask of active pixels of this image, added to each selected label
        # at once (a repeated label index is counted repeatedly, as before)
        active = (np.asarray(img) == 1)
        for lb in lb_idxs:
            pott_sum[..., lb] += active
    # set also the background values (overrides anything added to label 0)
    pott_sum[..., 0] = UNARY_BACKGROUND
    return pott_sum


Jiri Borovec committed
95
def edges_in_image_plane(im_size):
    """ create list of edges for uniform image plane

    Pixels are numbered in row-major order; every pixel is connected to its
    neighbour in the next column and to its neighbour in the next row.
    All "column" edges are listed first, then all "row" edges.

    :param im_size: (w, h) size of image
    :return: np.array<nb_edges, 2> with pairs [int<e1>, int<e2>]

    >>> edges_in_image_plane((2, 3)).tolist()
    [[0, 1], [1, 2], [3, 4], [4, 5], [0, 3], [1, 4], [2, 5]]
    """
    # np.prod instead of the deprecated alias np.product (removed in NumPy 2)
    idxs = np.arange(np.prod(im_size)).reshape(im_size)
    e_a = idxs[:, :-1].ravel().tolist() + idxs[:-1, :].ravel().tolist()
    e_b = idxs[:, 1:].ravel().tolist() + idxs[1:, :].ravel().tolist()
    edges = np.array([e_a, e_b]).transpose()
    logging.debug('edges for image plane are shape %s', repr(edges.shape))
    return edges


Jiri Borovec committed
111
def estimate_atlas_graphcut_simple(imgs, encoding, coef=1.):
    """ estimate the atlas by GraphCut on a regular grid with integer costs

    The unary term is the negated ``compute_positive_cost_images_weights``,
    so a label supported by more active images is cheaper. The pairwise term
    is a Potts matrix scaled by ``coef``. Both are cast to int32, so a
    non-integer ``coef`` gets truncated.
    source: https://github.com/yujiali/pygco

    :param imgs: [np.array<w, h>] list of input binary images
    :param encoding: np.array<nb_imgs, nb_lbs> binary ptn selection
    :param coef: coefficient for graphcut (Potts pairwise weight)
    :return: np.array<w, h> resulting labelling
    """
    logging.debug('estimate atlas via GraphCut from Potts model')
    # source: https://github.com/yujiali/pygco
    from src.wrappers.GraphCut.pygco import cut_grid_graph_simple

    lbs_sum = compute_positive_cost_images_weights(imgs, encoding)
    nb_labels = lbs_sum.shape[-1]
    unaries = np.array(-1 * lbs_sum, dtype=np.int32)
    logging.debug('graph unaries potentials %s: \n %s', repr(unaries.shape),
                  repr(np.histogram(unaries, bins=10)))
    # Potts model: zero on the diagonal, ``coef`` for any label change
    potts = np.array((1 - np.eye(nb_labels)) * coef, dtype=np.int32)
    logging.debug('graph pairwise coefs %s', repr(potts.shape))

    segm = cut_grid_graph_simple(unaries, potts, algorithm='expansion')
    segm = segm.reshape(lbs_sum.shape[:2])
    logging.debug('resulting labelling %s: \n %s', repr(segm.shape), repr(segm))
    return segm


140
def estimate_atlas_graphcut_general(imgs, encoding, coef=0., init_atlas=None):
    """ run the graphcut on the unary costs with specific pairwise cost
    source: https://github.com/yujiali/pygco

    The unary cost comes from ``compute_relative_penaly_images_weights``,
    with the background label 0 included. The graph is the image grid from
    ``edges_in_image_plane`` with unit edge weights. The pairwise cost is a
    Potts matrix scaled by ``coef``.

    :param imgs: [np.array<w, h>] list of input binary images
    :param encoding: np.array<nb_imgs, nb_lbs> binary ptn selection
    :param coef: float, coefficient (Potts weight) for graphcut
    :param init_atlas: np.array<w, h> initial labelling,
        while None it takes the arg-min of the unary costs
    :return: np.array<w, h> resulting labelling
    """
    logging.debug('estimate atlas via GraphCut from Potts model')
    # source: https://github.com/yujiali/pygco
    from src.wrappers.GraphCut.pygco import cut_general_graph

    u_cost = compute_relative_penaly_images_weights(imgs, encoding)
    # flatten the <w, h, nb_lbs> costs to one row per pixel
    unary_cost = np.array(u_cost, dtype=np.float64)
    unary_cost = unary_cost.reshape(-1, u_cost.shape[-1])
    logging.debug('graph unaries potentials %s: \n %s', repr(unary_cost.shape),
                  repr(np.histogram(unary_cost, bins=10)))

    edges = edges_in_image_plane(u_cost.shape[:2])
    logging.debug('edges for image plane are shape %s', repr(edges.shape))
    edge_weights = np.ones(edges.shape[0])
    logging.debug('edges weights are shape %s', repr(edge_weights.shape))

    # Potts model: zero for equal labels, ``coef`` for any label change
    pairwise = (1 - np.eye(u_cost.shape[-1])) * coef
    pairwise_cost = np.array(pairwise, dtype=np.float64)
    logging.debug('graph pairwise coefs %s', repr(pairwise_cost.shape))

    if init_atlas is None:
        init_labels = np.argmin(unary_cost, axis=1)
    else:
        init_labels = np.asarray(init_atlas).ravel()
    logging.debug('graph initial labels %s', repr(init_labels.shape))

    # run GraphCut
    labels = cut_general_graph(edges, edge_weights, unary_cost, pairwise_cost,
                               algorithm='expansion', init_labels=init_labels,
                               n_iter=NB_GRAPH_CUT_ITER)
    # reshape labels back to the image plane
    labels = labels.reshape(u_cost.shape[:2])
    logging.debug('resulting labelling %s of %s', repr(labels.shape),
                  np.unique(labels).tolist())
    return labels


189
def export_visualization_image(img, idx, out_dir, prefix='debug', name='',
                               ration=None, labels=('', '')):
    """ export visualisation as an image with some special desc.

    The figure is saved as ``APDL_<prefix>_<name>_iter_<idx:04d>.png``
    inside ``out_dir``.

    :param img: np.array<w, h>
    :param idx: int, iteration to be shown in the img name
    :param out_dir: str, path to the resulting folder
    :param prefix: str
    :param name: str, name of this particular visual
    :param ration: str, aspect passed to ``plt.imshow`` (e.g. 'auto'),
        mainly for weights to be stretched
    :param labels: (str<x>, str<y>) labels for axis
    """
    fig = plt.figure()
    try:
        plt.imshow(img, interpolation='none', aspect=ration)
        plt.xlabel(labels[0])
        plt.ylabel(labels[1])
        n_fig = 'APDL_{}_{}_iter_{:04d}'.format(prefix, name, idx)
        p_fig = os.path.join(out_dir, n_fig + '.png')
        logging.debug('.. export Visualization as "%s...%s"', p_fig[:19],
                      p_fig[-19:])
        fig.savefig(p_fig, bbox_inches='tight', pad_inches=0.05)
    finally:
        # close exactly this figure (a bare plt.close() closes the *current*
        # one) and also when saving fails, so figures do not accumulate
        plt.close(fig)


Jiri Borovec committed
212
def export_visual_atlas(i, out_dir, atlas=None, prefix='debug'):
    """ export the atlas to results directory

    The atlas is exported only if both conditions hold:
    * the root logger's effective level is DEBUG;
    * ``out_dir`` exists.
    The image name is ``APDL_<prefix>_atlas_iter_<i:04d>``.

    :param i: int, iteration to be shown in the img name
    :param out_dir: str, path to the resulting folder
    :param atlas: np.array<w, h>, nothing is exported when None
    :param prefix: str
    :return: None
    """
    if logging.getLogger().getEffectiveLevel() != logging.DEBUG:
        return None
    if not os.path.exists(out_dir):
        logging.debug('results path "%s" does not exist', out_dir)
        return None
    if atlas is not None:
        n_img = 'APDL_{}_atlas_iter_{:04d}'.format(prefix, i)
        gen_data.export_image(out_dir, atlas, n_img)
Jiri Borovec committed
233 234


235
def apdl_initialisation(imgs, init_atlas, init_weights, out_dir, out_prefix):
    """ more complex initialisation depending on inputs

    Priority of the atlas initialisation:
    * a given ``init_atlas`` is used as it is;
    * otherwise, given ``init_weights`` are turned into an atlas by GraphCut
      without pairwise regularisation;
    * otherwise a mosaic atlas with ``int(sqrt(nb_imgs))`` labels is created.

    :param imgs: [np.array<w, h>]
    :param init_atlas: np.array<w, h> or None
    :param init_weights: np.array<nb_imgs, nb_lbs> or None, returned unchanged
    :param out_dir: str, path to the results directory
    :param out_prefix: str
    :return: np.array<w, h>, np.array<nb_imgs, nb_lbs>
    """
    if init_atlas is None and init_weights is not None:
        logging.debug('... initialise Atlas from w_bins')
        init_atlas = estimate_atlas_graphcut_general(imgs, init_weights, 0.)
    if init_atlas is None:
        max_nb_lbs = int(np.sqrt(len(imgs)))
        logging.debug('... initialise Atlas with mosaic of %i labels',
                      max_nb_lbs)
        # IDEA: find better way of initialisation
        init_atlas = ptn_dict.initialise_atlas_mosaic(imgs[0].shape, max_nb_lbs)

    atlas = init_atlas
    if len(np.unique(atlas)) == 1:
        logging.error('the init. atlas does not contain any label... %s',
                      repr(np.unique(atlas)))
    # one export of the initial atlas; before, the same file could be
    # written twice with identical content
    export_visual_atlas(0, out_dir, atlas, out_prefix)
    return atlas, init_weights


265
def apdl_update_weights(imgs, atlas, overlap_major=False):
    """ weights step of the block coordinate descent: re-estimate the
    pattern weights of every image for a fixed atlas

    :param imgs: [np.array<w, h>]
    :param atlas: np.array<w, h>
    :param overlap_major: bool, use ``weights_image_atlas_overlap_major``
        instead of ``weights_image_atlas_overlap_partial``
    :return: np.array<nb_imgs, nb_lbs>
    """
    logging.debug('... perform pattern weights')
    if overlap_major:
        fn_weights = ptn_weight.weights_image_atlas_overlap_major
    else:
        fn_weights = ptn_weight.weights_image_atlas_overlap_partial
    weights = [fn_weights(im, atlas) for im in imgs]
    return np.array(weights)
Jiri Borovec committed
283 284


285
def apdl_update_atlas(imgs, atlas, w_bins, label_max, gc_coef, gc_reinit, ptn_split):
    """ atlas step of the block coordinate descent: re-estimate the atlas
    for fixed pattern weights

    :param imgs: [np.array<w, h>]
    :param atlas: np.array<w, h>, current atlas
    :param w_bins: np.array<nb_imgs, nb_lbs>
    :param label_max: int, passed to ``ptn_dict.atlas_split_indep_ptn``
    :param gc_coef: float, graph cut regularisation
    :param gc_reinit: bool, whether to use the atlas from the previous step
        as the initial labelling of GraphCut
    :param ptn_split: bool, post-process the new atlas by
        ``ptn_dict.atlas_split_indep_ptn``
    :return: np.array<w, h>
    """
    if np.sum(w_bins) == 0:
        logging.warning('the w_bins is empty... %s', repr(np.unique(atlas)))
    w_bins = np.array(w_bins)

    logging.debug('... perform Atlas estimation')
    # ``None`` makes GraphCut start from the arg-min of the unary costs
    gc_init = atlas if gc_reinit else None
    atlas_new = estimate_atlas_graphcut_general(imgs, w_bins, gc_coef, gc_init)

    if not ptn_split:
        return atlas_new
    return ptn_dict.atlas_split_indep_ptn(atlas_new, label_max)
Jiri Borovec committed
310 311


312 313
def apdl_pipe_atlas_learning_ptn_weights(imgs, init_atlas=None, init_weights=None,
                                         gc_coef=0.0, tol=1e-3, max_iter=25,
                                         gc_reinit=True, ptn_split=True,
                                         overlap_major=False, ptn_compact=True,
                                         out_prefix='debug', out_dir=''):
    """ the experiments_synthetic pipeline for block coordinate descent
    algo with graphcut...

    Each iteration does three steps:
    * ``apdl_update_weights`` computes weights for a fixed atlas;
    * ``ptn_dict.reinit_atlas_likely_patterns`` re-initialises the atlas;
    * ``apdl_update_atlas`` computes the atlas for fixed weights.
    The loop stops when the diff of two consecutive atlases is <= ``tol``
    and the atlas has more than one label, or after ``max_iter`` iterations.
    The returned weights always use the major overlap.

    :param imgs: [np.array<w, h>] non-empty list of images
    :param init_atlas: np.array<w, h>
    :param init_weights: np.array<nb_imgs, nb_lbs>
    :param gc_coef: float, graph cut regularisation
    :param tol: float, stop if the diff between two consecutive steps
        is less than this given threshold, e.g. -1 runs all ``max_iter`` iters
    :param max_iter: int, max number of iterations
    :param gc_reinit: bool, whether use atlas from previous step as init for act.
    :param ptn_split: bool, post-process each new atlas by
        ``ptn_dict.atlas_split_indep_ptn``
    :param overlap_major: bool, weights by major (instead of partial) overlap
        during the iterations
    :param ptn_compact: bool, passed to ``ptn_dict.reinit_atlas_likely_patterns``
    :param out_prefix: str
    :param out_dir: str, path to the results directory (debug exports only)
    :return: np.array<w, h>, np.array<nb_imgs, nb_lbs>
    :raises ValueError: if ``imgs`` is empty
    """
    logging.debug('compute an Atlas and weights for %i images...', len(imgs))
    if len(imgs) == 0:
        raise ValueError('no images given for the atlas estimation')
    if logging.getLogger().getEffectiveLevel() == logging.DEBUG:
        # the default out_dir='' cannot be created, os.mkdir('') would fail
        if out_dir and not os.path.exists(out_dir):
            os.makedirs(out_dir)

    atlas, w_bins = apdl_initialisation(imgs, init_atlas, init_weights,
                                        out_dir, out_prefix)
    # take the label bound from the initialised atlas, so it is defined also
    # when no ``init_atlas`` was given (it equals np.max(init_atlas) otherwise)
    label_max = int(np.max(atlas))
    logging.debug('max nb labels set: %i', label_max)
    list_crit = []
    # defaults keep the final report valid also when max_iter <= 0
    it, step_diff = 0, np.inf

    for it in range(max_iter):
        if len(np.unique(atlas)) == 1:
            logging.warning('.. iter: %i, no labels in the atlas %s', it,
                            repr(np.unique(atlas).tolist()))
        w_bins = apdl_update_weights(imgs, atlas, overlap_major)
        atlas_reinit, w_bins = ptn_dict.reinit_atlas_likely_patterns(
            imgs, w_bins, atlas, label_max, ptn_compact)
        atlas_new = apdl_update_atlas(imgs, atlas_reinit, w_bins, label_max,
                                      gc_coef, gc_reinit, ptn_split)

        step_diff = sim_metric.compare_atlas_adjusted_rand(atlas, atlas_new)
        list_crit.append(step_diff)
        atlas = sk_image.relabel_sequential(atlas_new)[0]

        logging.debug('-> iter. #%i with Atlas diff %f', (it + 1), step_diff)
        export_visual_atlas(it + 1, out_dir, atlas, out_prefix)

        # stopping criterion
        if step_diff <= tol and len(np.unique(atlas)) > 1:
            logging.debug('>> exit while the atlas diff %f is smaller then %f',
                          step_diff, tol)
            break
    logging.info('APDL: terminated with iter %i / %i and step diff %f <? %f',
                 it, max_iter, step_diff, tol)
    logging.debug('criterion evolved:\n %s', repr(list_crit))
    w_bins = [ptn_weight.weights_image_atlas_overlap_major(img, atlas)
              for img in imgs]
    return atlas, np.array(w_bins)