Commit 428f0c88 authored by Jiri Borovec's avatar Jiri Borovec

new inits with Gaussian/Otsu and Watershed

parent f62b4de8
......@@ -158,7 +158,7 @@ def extract_image_largest_element(img_binary, labeled=None):
labeled, _ = ndimage.label(img_binary)
areas = [(j, np.sum(labeled == j)) for j in np.unique(labeled)]
areas = sorted(areas, key=lambda x: x[1], reverse=True)
logger.debug('... elements area: %s', repr(areas))
# logger.debug('... elements area: %s', repr(areas))
img_ptn = img_binary.copy()
if len(areas) > 1:
img_ptn = np.zeros_like(img_binary)
......
......@@ -7,7 +7,7 @@ Copyright (C) 2015-2016 Jiri Borovec <jiri.borovec@fel.cvut.cz>
import logging
from scipy import ndimage
from skimage import morphology, feature
from skimage import morphology, feature, filters
from scipy import ndimage as ndi
import numpy as np
......@@ -29,14 +29,14 @@ def initialise_atlas_random(im_size, label_max):
:return: np.array<w, h>
"""
logging.debug('initialise atlas %s as random labeling', repr(im_size))
nb_lbs = label_max + 1
nb_labels = label_max + 1
np.random.seed() # reinit seed to have random samples even in the same time
im = np.random.randint(1, nb_lbs, im_size)
return np.array(im, dtype=np.int)
img_init = np.random.randint(1, nb_labels, im_size)
return np.array(img_init, dtype=np.int)
def initialise_atlas_mosaic(im_size, nb_labels, coef=1.):
""" generate grids trusture and into each rectangle plase a label,
""" generate grids texture and into each rectangle plase a label,
each row contains all labels (permutation)
:param im_size: (w, h) size og image
......@@ -62,9 +62,61 @@ def initialise_atlas_mosaic(im_size, nb_labels, coef=1.):
else: mosaic = np.vstack((mosaic, row))
logging.debug('generated mosaic %s with labeling %s',
repr(mosaic.shape), repr(np.unique(mosaic).tolist()))
im = mosaic[:im_size[0], :im_size[1]]
im = np.remainder(im, nb_labels)
return np.array(im, dtype=np.int)
img_init = mosaic[:im_size[0], :im_size[1]]
img_init = np.remainder(img_init, nb_labels)
return np.array(img_init, dtype=np.int)
def initialise_atlas_otsu_watershed_2d(imgs, nb_labels, bg='none'):
""" do some simple operations to get better initialisation
1] sum over all images, 2] Otsu thresholding, 3] watershed
:param imgs: [np.array<w, h>]
:param nb_labels: int
:param bg: str, set weather the Otsu backround sould be filled randomly
:return: np.array<w, h>
"""
logging.debug('initialise atlas for %i labels from %i images of shape %s '
'with Otsu-Watershed', nb_labels, len(imgs), repr(imgs[0].shape))
img_sum = np.sum(np.asarray(imgs), axis=0) / float(len(imgs))
img_gauss = filters.gaussian_filter(img_sum, 1)
# http://scikit-image.org/docs/dev/auto_examples/plot_otsu.html
thresh = filters.threshold_otsu(img_gauss)
img_otsu = (img_gauss >= thresh)
# http://scikit-image.org/docs/dev/auto_examples/plot_watershed.html
img_dist = ndi.distance_transform_edt(img_otsu)
local_maxi = feature.peak_local_max(img_dist, labels=img_otsu,
footprint=np.ones((2, 2)))
seeds = np.zeros_like(img_sum)
seeds[local_maxi[:,0], local_maxi[:,1]] = range(1, len(local_maxi) + 1)
labels = morphology.watershed(-img_dist, seeds)
img_init = np.remainder(labels, nb_labels)
if bg == 'rand':
# add random labels on the potential backgound
img_rand = np.random.randint(1, nb_labels, img_sum.shape)
img_init[img_otsu == 0] = img_rand[img_otsu == 0]
return img_init.astype(np.int)
def initialise_atlas_gauss_watershed_2d(imgs, nb_labels):
""" do some simple operations to get better initialisation
1] sum over all images, 2]watershed
:param imgs: [np.array<w, h>]
:param nb_labels: int
:return: np.array<w, h>
"""
logging.debug('initialise atlas for %i labels from %i images of shape %s '
'with Gauss-Watershed', nb_labels, len(imgs), repr(imgs[0].shape))
img_sum = np.sum(np.asarray(imgs), axis=0) / float(len(imgs))
img_gauss = filters.gaussian_filter(img_sum, 1)
local_maxi = feature.peak_local_max(img_gauss, footprint=np.ones((2, 2)))
seeds = np.zeros_like(img_sum)
seeds[local_maxi[:,0], local_maxi[:,1]] = range(1, len(local_maxi) + 1)
# http://scikit-image.org/docs/dev/auto_examples/plot_watershed.html
labels = morphology.watershed(-img_gauss, seeds) # , mask=im_diff
img_init = np.remainder(labels, nb_labels)
return img_init.astype(np.int)
def initialise_atlas_deform_original(atlas):
......@@ -127,7 +179,7 @@ def prototype_new_pattern(imgs, imgs_reconst, diffs, atlas,
# if ptn_size < 0.01:
# logging.debug('new patterns was too small %f', ptn_size)
# ptn = data.extract_image_largest_element(im_diff)
img_ptn = np.logical_and(img_ptn == True)
img_ptn = (img_ptn == True)
# img_ptn = np.logical_and(img_ptn == True, atlas == 0)
return img_ptn
......@@ -148,7 +200,7 @@ def insert_new_pattern(imgs, imgs_reconst, atlas, label,
# logging.debug('new im_ptn: {}'.format(np.sum(im_ptn) / np.prod(im_ptn.shape)))
# plt.imshow(im_ptn), plt.title('im_ptn'), plt.show()
atlas[im_ptn == True] = label
logging.debug('area of new pattern %i is %i', im_ptn, np.sum(atlas == label))
logging.debug('area of new pattern %i is %i', label, np.sum(atlas == label))
return atlas
......
......@@ -52,17 +52,19 @@ DICT_ATLAS_INIT = {
'msc1': partial(ptn_dict.initialise_atlas_mosaic, coef=1.5),
'msc2': partial(ptn_dict.initialise_atlas_mosaic, coef=2),
'rnd': ptn_dict.initialise_atlas_random,
'OWS': ptn_dict.initialise_atlas_otsu_watershed_2d,
'OWSr': partial(ptn_dict.initialise_atlas_otsu_watershed_2d, bg='rand'),
'GWS': ptn_dict.initialise_atlas_gauss_watershed_2d,
'GT': None, # init by Ground Truth, require GT atlas
'GTd': None, # init by deformed Ground Truth, require GT atlas
# 'OWS': None,
}
# SIMPLE RUN
INIT_TYPES = ['rnd', 'msc', 'msc2']
# GRAPHCUT_REGUL = [1e-2, 1e-1, 0.]
INIT_TYPES = ['OWS', 'OWSr', 'GWS']
GRAPHCUT_REGUL = [0., 1e-9, 1e-3]
# COMPLEX RUN
# INIT_TYPES = DICT_ATLAS_INIT.keys()
GRAPHCUT_REGUL = [0., 0e-12, 1e-9, 1e-6, 1e-3, 1e-1]
# GRAPHCUT_REGUL = [0., 0e-12, 1e-9, 1e-6, 1e-3, 1e-1]
def test_simple_show_case():
......@@ -130,7 +132,7 @@ class ExperimentAPDL_base(expt_apd.ExperimentAPD):
the main real experiment or our Atlas Learning Pattern Encoding
"""
def _init_atlas(self, nb_labels, init_tp):
def _init_atlas(self, nb_labels, init_tp, imgs):
""" init atlas according an param
:param nb_lbs: int
......@@ -138,9 +140,13 @@ class ExperimentAPDL_base(expt_apd.ExperimentAPD):
:return: np.array<w, h>
"""
im_size = self.imgs[0].shape
assert init_tp in DICT_ATLAS_INIT
fn_init_atlas = DICT_ATLAS_INIT[init_tp]
if fn_init_atlas is not None:
if init_tp.startswith('OWS') or init_tp == 'GWS':
assert init_tp in DICT_ATLAS_INIT
fn_init_atlas = DICT_ATLAS_INIT[init_tp]
init_atlas = fn_init_atlas(imgs, nb_labels)
elif init_tp.startswith('msc') or init_tp == 'rnd':
assert init_tp in DICT_ATLAS_INIT
fn_init_atlas = DICT_ATLAS_INIT[init_tp]
init_atlas = fn_init_atlas(im_size, nb_labels)
elif init_tp == 'GT':
assert hasattr(self, 'gt_atlas')
......@@ -149,6 +155,9 @@ class ExperimentAPDL_base(expt_apd.ExperimentAPD):
assert hasattr(self, 'gt_atlas')
init_atlas = np.remainder(self.gt_atlas, nb_labels)
init_atlas = ptn_dict.initialise_atlas_deform_original(init_atlas)
assert init_atlas.max() < nb_labels
assert init_atlas.shape == im_size
assert init_atlas.dtype == np.int
return init_atlas
def _estimate_atlas(self, v):
......@@ -161,7 +170,7 @@ class ExperimentAPDL_base(expt_apd.ExperimentAPD):
logging.debug(' -> estimate atlas...')
self.params[self.iter_var_name] = v
logging.debug('PARAMS: %s', repr(self.params))
init_atlas = self._init_atlas(self.params['nb_labels'], self.params['init_tp'])
init_atlas = self._init_atlas(self.params['nb_labels'], self.params['init_tp'], self.imgs)
# prefix = 'expt_{}'.format(p['init_tp'])
path_out = os.path.join(self.params['path_exp'],
'debug_{}_{}'.format(self.iter_var_name, v))
......@@ -269,9 +278,9 @@ def experiments_synthetic(params=SYNTH_PARAMS):
expt = ExperimentAPDL_base(params)
expt.run(iter_var='case', iter_vals=range(params['nb_runs']))
# exp.run(iter_var='nb_labels', iter_vals=ptn_range)
del expt
except:
logging.error(traceback.format_exc())
del expt
tqdm_bar.update(1)
gc.collect(), time.sleep(1)
......@@ -306,6 +315,7 @@ def experiments_real(params=REAL_PARAMS):
expt = ExperimentAPDL_base(params)
# exp.run(gt=False, iter_var='case', iter_values=range(params['nb_runs']))
expt.run(gt=False, iter_var='nb_labels', iter_vals=NB_PATTERNS_REAL)
del expt
tqdm_bar.update(1)
gc.collect(), time.sleep(1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment