Commit 51c4b1c8 authored by Jiri Borovec's avatar Jiri Borovec

fix APDL stop crit

parent ee935371
......@@ -136,14 +136,14 @@ def estimate_atlas_graphcut_simple(imgs, encoding, coef=1.):
return labels
def estimate_atlas_graphcut_general(imgs, encoding, coef=1., init_lbs=None):
def estimate_atlas_graphcut_general(imgs, encoding, coef=1., init_labels=None):
""" run the graphcut on the unary costs with specific pairwise cost
:param imgs: [np.array<w, h>] list of input binary images
:param encoding: np.array<nb_imgs, nb_lbs> binary ptn selection
:param coef: coefficient for graphcut
:param init_lbs: np.array<nb_seg, 1> init labeling
:param init_labels: np.array<nb_seg, 1> init labeling
while None it take the arg ming of the unary costs
:return: np.array<nb_seg, 1>
......@@ -168,18 +168,19 @@ def estimate_atlas_graphcut_general(imgs, encoding, coef=1., init_lbs=None):
pairwise_cost = np.array(pairwise , dtype=np.float64)
logging.debug('graph pairwise coefs %s', repr(pairwise_cost.shape))
if init_lbs is None:
init_lbs = np.argmin(unary_cost, axis=1)
if init_labels is None:
init_labels = np.argmin(unary_cost, axis=1)
init_lbs = init_lbs.ravel()
logging.debug('graph initial labels %s', repr(init_lbs.shape))
init_labels = init_labels.ravel()
logging.debug('graph initial labels %s', repr(init_labels.shape))
# run GraphCut
labels = cut_general_graph(edges, edge_weights, unary_cost, pairwise_cost,
algorithm='expansion', init_labels=init_lbs)
algorithm='expansion', init_labels=init_labels)
# reshape labels
labels = labels.reshape(u_cost.shape[:2])
logging.debug('resulting labelling %s', repr(labels.shape))
logging.debug('resulting labelling %s of %s', repr(labels.shape),
return labels
......@@ -290,43 +291,41 @@ def alpe_repaire_atlas_weights(imgs, atlas, w_bins, label_max):
logging.debug('... perform repairing atlas & weights')
# reinit empty
atlas, w_bins = ptn_dict.reinit_atlas_likely_patterns(imgs, w_bins, atlas,
return atlas, w_bins
def alpe_update_atlas(imgs, atlas, w_bins, lb_max, gc_coef, gc_reinit, ptn_split):
def alpe_update_atlas(imgs, atlas, w_bins, label_max, gc_coef, gc_reinit, ptn_split):
""" single iteration of the block coordinate descent algo
:param imgs: [np.array<w, h>]
:param atlas: np.array<w, h>
:param w_bins: np.array<nb_imgs, nb_lbs>
:param lb_max: int
:param label_max: int
:param gc_coef: float, graph cut regularisation
:param gc_reinit: bool, weather use atlas from previous step as init for act.
:param ptn_split: bool
:return: np.array<w, h>, float
:return: np.array<w, h>
if np.sum(w_bins) == 0:
logging.warning('the w_bins is empty... %s', repr(np.unique(atlas)))
w_bins = np.array(w_bins)
# update atlas
logging.debug('... perform Atlas estimation')
# atlas_new = estimate_atlas_graphcut_simple(imgs, w_bins)
if gc_reinit:
atlas_new = estimate_atlas_graphcut_general(imgs, w_bins, gc_coef, atlas)
atlas_new = estimate_atlas_graphcut_general(imgs, w_bins, gc_coef)
if ptn_split:
atlas_new = ptn_dict.atlas_split_indep_ptn(atlas_new, lb_max)
step_diff = sim_metric.compare_atlas_adjusted_rand(atlas, atlas_new)
return atlas_new, step_diff
atlas_new = ptn_dict.atlas_split_indep_ptn(atlas_new, label_max)
return atlas_new
def alpe_pipe_atlas_learning_ptn_weights(imgs, init_atlas=None, init_weights=None,
gc_coef=0.0, tol=1e-5, max_iter=150,
gc_reinit=True, ptn_split=True, overlap_major=False,
gc_coef=0.0, tol=1e-4, max_iter=25,
gc_reinit=True, ptn_split=True,
out_prefix='debug', out_dir=''):
""" the experiments_synthetic pipeline for block coordinate descent
algo with graphcut...
......@@ -352,30 +351,33 @@ def alpe_pipe_atlas_learning_ptn_weights(imgs, init_atlas=None, init_weights=Non
# initialise
atlas, w_bins = alpe_initialisation(imgs, init_atlas, init_weights,
out_dir, out_prefix)
lb_max = np.max(atlas)
label_max = np.max(atlas)
list_crit = []
for i in range(max_iter):
if len(np.unique(atlas)) == 1:
logging.warning('the atlas does not contain any label... %s',
logging.warning('the atlas does not contain any label... %i',
w_bins = alpe_update_weights(imgs, atlas, overlap_major)
# plt.subplot(221), plt.imshow(atlas, interpolation='nearest')
# plt.subplot(222), plt.imshow(w_bins, aspect='auto')
atlas, w_bins = alpe_repaire_atlas_weights(imgs, atlas, w_bins, lb_max)
atlas_new, w_bins = alpe_repaire_atlas_weights(imgs, atlas, w_bins, label_max)
# plt.subplot(223), plt.imshow(atlas, interpolation='nearest')
# plt.subplot(224), plt.imshow(w_bins, aspect='auto')
atlas, step_diff = alpe_update_atlas(imgs, atlas, w_bins, lb_max,
gc_coef, gc_reinit, ptn_split)
atlas_new = alpe_update_atlas(imgs, atlas_new, w_bins, label_max,
gc_coef, gc_reinit, ptn_split)
step_diff = sim_metric.compare_atlas_adjusted_rand(atlas, atlas_new)
atlas = atlas_new
logging.debug('-> iter. #%i with Atlas diff %f', (i + 1), step_diff)
export_visual_atlas(i + 1, out_dir, atlas, out_prefix)
# stopping criterion
if step_diff <= tol:
logging.debug('>> exiting while the atlas diff %f is smaller then %f',
logging.debug('>> exit while the atlas diff %f is smaller then %f',
step_diff, tol)
break'...terminated with %i / %i iter and step diff %f <? %f',
......@@ -63,7 +63,7 @@ def compare_atlas_adjusted_rand(a1, a2):
assert np.array_equal(a1.shape, a2.shape)
ars = metrics.adjusted_rand_score(a1.ravel(), a2.ravel())
res = 0.5 - (ars / 2.)
res = 1 - abs(ars)
return res
......@@ -163,20 +163,21 @@ def reinit_atlas_likely_patterns(imgs, w_bins, atlas, label_max=None):
logging.debug('IN > sum over weights: %s', repr(np.sum(w_bin_ext, axis=0)))
# add one while indexes does not cover 0 - bg
logging.debug('total nb labels: %i', label_max)
atlas_new = atlas.copy()
for lb in range(1, label_max + 1):
w_index = lb - 1
w_sum = np.sum(w_bins[:, w_index])
logging.debug('reinit lb: %i with weight sum %i', lb, w_sum)
if w_sum > 0:
imgs_rc = reconstruct_samples(atlas, w_bins)
atlas = insert_new_pattern(imgs, imgs_rc, atlas, lb)
imgs_rc = reconstruct_samples(atlas_new, w_bins)
atlas_new = insert_new_pattern(imgs, imgs_rc, atlas_new, lb)
logging.debug('w_bins before: %i', np.sum(w_bins[:, w_index]))
lim_repopulate = 100. /
lim_repopulate = 100. /
w_bins[:, w_index] = ptn_weight.weights_label_atlas_overlap_threshold(imgs,
atlas, lb, lim_repopulate)
atlas_new, lb, lim_repopulate)
logging.debug('w_bins after: %i', np.sum(w_bins[:, w_index]))
return atlas, w_bins
return atlas_new, w_bins
def atlas_split_indep_ptn(atlas, lb_max):
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment