Commit 6f84e4cb by Jiri Borovec

update, refactor, APDL reinit. patterns

parent 8475cb0c
@@ -147,22 +147,24 @@ def create_clean_folder(path_dir):
return path_dir
def extract_image_largest_element(img):
def extract_image_largest_element(img_binary, labeled=None):
""" take a binary image and find all independent segments,
then keep just the largest segment and set the rest to 0
:param img: np.array<w, h> of values {0, 1}
:param img_binary: np.array<w, h> of values {0, 1}
:return: np.array<w, h> of values {0, 1}
"""
labeled, nbObjects = ndimage.label(img)
if labeled is None or len(np.unique(labeled)) < 2:
labeled, _ = ndimage.label(img_binary)
areas = [(j, np.sum(labeled == j)) for j in np.unique(labeled)]
areas = sorted(areas, key=lambda x: x[1], reverse=True)
logger.debug('... elements area: %s', repr(areas))
img_ptn = img_binary.copy()
if len(areas) > 1:
img[:] = 0
img_ptn = np.zeros_like(img_binary)
# skip largest, assuming to be background
img[labeled == areas[1][0]] = 1
return img
img_ptn[labeled == areas[1][0]] = 1
return img_ptn
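A minimal, self-contained sketch of what the refactored helper does (toy mask and inlined logic only, not a call into the repo): label the binary mask, rank component areas, and keep the largest non-background component without mutating the input.

import numpy as np
from scipy import ndimage

mask = np.zeros((8, 8), dtype=int)
mask[1:3, 1:3] = 1                     # small blob
mask[4:8, 4:8] = 1                     # larger blob, the one to keep
labeled, _ = ndimage.label(mask)
areas = sorted(((lb, np.sum(labeled == lb)) for lb in np.unique(labeled)),
               key=lambda x: x[1], reverse=True)
# areas[0] is label 0 (the background covers most pixels), so the second
# largest component is the largest true segment
largest = np.zeros_like(mask)
largest[labeled == areas[1][0]] = 1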
def atlas_filter_larges_components(atlas):
@@ -51,13 +51,15 @@ PATH_RESULTS = '/datagrid/Medical/microscopy/drosophila/TEMPORARY/experiments_AP
DEFAULT_PARAMS = {
'computer': os.uname(),
'nb_samples': None,
'init_tp': 'msc1', # msc. rnd
'tol': 1e-2,
'init_tp': 'msc2', # msc. rnd
'max_iter': 250,
'gc_regul': 0.,
'nb_labels': 20,
'nb_labels': 2,
'nb_runs': NB_THREADS, # 500
'gc_reinit': True,
'ptn_split': True,
'ptn_split': False,
'ptn_compact': False,
'overlap_mj': False,
}
@@ -77,7 +79,7 @@ SYNTH_PTN_RANGE = {
'atomicPatternDictionary_00': range(5),
'atomicPatternDictionary_v0': range(3, 15, 1),
'atomicPatternDictionary_v1': range(5, 20, 1),
'atomicPatternDictionary_v2': range(9, 40, 2),
'atomicPatternDictionary_v2': range(10, 40, 2) + [23],
'atomicPatternDictionary_v3': range(10, 40, 2),
'atomicPatternDictionary3D_v0': range(2, 14, 1),
'atomicPatternDictionary3D_v1': range(6, 30, 2),
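Note that range(10, 40, 2) + [23] only works under Python 2, where range returns a list; a short sketch of the Python 3 equivalent (assuming the same values are intended):

# Python 3 equivalent of the concatenated range pattern
ptn_range = list(range(10, 40, 2)) + [23]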
@@ -383,6 +385,7 @@ class ExperimentAPD_parallel(ExperimentAPD):
mproc_pool = mproc.Pool(self.nb_jobs)
for stat in mproc_pool.map(self._warp_perform_once, self.iter_values):
self.l_stat.append(stat)
self._evaluate()
# tqdm_bar.update(1)
mproc_pool.close()
mproc_pool.join()
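Pool.map blocks until every result is ready, so the commented-out tqdm update above could only advance in one jump; a hedged sketch (generic names, not the class method) of a per-result progress loop using imap_unordered:

import multiprocessing as mproc
import tqdm

def run_parallel(func, values, nb_jobs=4):
    # func must be a picklable, module-level callable
    pool = mproc.Pool(nb_jobs)
    results = []
    for res in tqdm.tqdm(pool.imap_unordered(func, values), total=len(values)):
        results.append(res)
    pool.close()
    pool.join()
    return results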
@@ -8,19 +8,19 @@ import logging
import numpy as np
def initialise_weights_random(nb_imgs, nb_lbs, ratio_sel=0.2):
def initialise_weights_random(nb_imgs, nb_labels, ratio_select=0.2):
"""
:param nb_imgs: int, number of all images
:param nb_lbs: int, number of all available labels
:param ratio_sel: float<0, 1> defining how many should be set on,
:param nb_labels: int, number of all available labels
:param ratio_select: float<0, 1> defining how many should be set on,
1 means all and 0 means none
:return: np.array<nb_imgs, nb_lbs>
:return: np.array<nb_imgs, nb_labels>
"""
logging.debug('initialise weights for %i images and %i labels '
'as random selection', nb_imgs, nb_lbs)
prob = np.random.random((nb_imgs, nb_lbs))
'as random selection', nb_imgs, nb_labels)
prob = np.random.random((nb_imgs, nb_labels))
weights = np.zeros_like(prob)
weights[prob <= ratio_sel] = 1
weights[prob <= ratio_select] = 1
return weights
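A small usage sketch of the renamed initialiser logic (inlined here so it runs on its own): on average, about ratio_select of the weights end up switched on.

import numpy as np

nb_imgs, nb_labels, ratio_select = 100, 10, 0.2
prob = np.random.random((nb_imgs, nb_labels))
weights = np.zeros_like(prob)
weights[prob <= ratio_select] = 1
print(weights.mean())   # roughly 0.2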
@@ -68,13 +68,13 @@ def weights_image_atlas_overlap_partial(img, atlas):
return weights
def weights_image_atlas_overlap_threshold(img, atlas, thr=0.5):
def weights_image_atlas_overlap_threshold(img, atlas, threshold=0.5):
""" estimate what patterns are activated with given atlas and input image
compute overlap matrix and eval nr of overlapping and non pixels and threshold
:param img: np.array<w, h>
:param atlas: np.array<w, h>
:param thr: float, represent the ration between overlapping and non pixels
:param threshold: float, represent the ration between overlapping and non pixels
:return: [int] * nb_lbs of values {0, 1}
"""
# logging.debug('weights input image according given atlas')
@@ -88,7 +88,7 @@ def weights_image_atlas_overlap_threshold(img, atlas, thr=0.5):
equal = np.sum(img[atlas == lb])
total = np.sum(atlas == lb)
score = equal / float(total)
if score >= thr:
if score >= threshold:
weight[lb - 1] = 1
return weight
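A worked toy example of the thresholding rule (shapes and values are illustrative only): for each atlas label, the score is the fraction of that label's pixels which are on in the image, and the pattern is activated when the score reaches the threshold.

import numpy as np

atlas = np.array([[0, 1, 1],
                  [0, 2, 2],
                  [0, 2, 2]])
img = np.array([[0, 0, 0],
                [0, 1, 1],
                [0, 1, 1]])
threshold = 0.5
weight = np.zeros(atlas.max(), dtype=int)
for lb in range(1, atlas.max() + 1):
    score = np.sum(img[atlas == lb]) / float(np.sum(atlas == lb))
    if score >= threshold:
        weight[lb - 1] = 1
# label 1 overlaps 0/2 -> off, label 2 overlaps 4/4 -> on, so weight == [0, 1]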
@@ -35,8 +35,10 @@ import dataset_utils as gen_data
VISUAL = False
NB_THREADS = int(mproc.cpu_count() * .9)
PATH_EXPERIMENTS = '/datagrid/Medical/microscopy/drosophila/' \
'TEMPORARY/experiments_APDL_real'
# PATH_EXPERIMENTS = '/datagrid/Medical/microscopy/drosophila/RESULTS/' \
# 'experiments_APD_real'
PATH_EXPERIMENTS = '/datagrid/Medical/microscopy/drosophila/TEMPORARY/' \
'experiments_APD_real'
DICT_PARAMS = {
'path_in': PATH_EXPERIMENTS,
@@ -103,19 +105,26 @@ def export_fig_reconstruction(path_out, name, segm_orig, segm_reconst, img_atlas
"""
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(20, 8))
ax[0].set_title('original segmentation')
ax[0].imshow(1 - segm_orig, cmap='Greys'), ax[0].axis('off')
ax[0].imshow(1 - segm_orig, cmap='Greys', alpha=0.9)
ax[0].imshow(segm_reconst, alpha=0.1)
ax[0].axis('off')
ax[0].contour(segm_reconst > 0, linewidth=2)
ax[0].contour(segm_reconst, linewidth=2, cmap=plt.cm.jet)
ax[1].set_title('reconstructed segmentation')
ax[1].imshow(segm_reconst), ax[1].axis('off')
ax[1].contour(segm_orig, linewidth=2, cmap=plt.cm.Greens)
ax[1].imshow(segm_reconst)
ax[1].axis('off')
ax[1].contour(segm_orig, linewidth=2, colors='w')
segm_select = np.array(segm_reconst > 0, dtype=np.int)
segm_select[np.logical_and(img_atlas > 0, segm_reconst == 0)] = -1
segm_select[0, :2] = [-1, 1]
ax[2].set_title('selected vs not selected segm.')
ax[2].imshow(segm_select, cmap=plt.cm.RdYlGn), ax[2].axis('off')
ax[2].contour(segm_orig, linewidth=3, cmap=plt.cm.Blues)
ax[2].contour(img_atlas, linewidth=1, cmap=plt.cm.Pastel1)
ax[2].imshow(segm_select, cmap=plt.cm.RdYlGn)
ax[2].axis('off')
ax[2].contour(segm_orig, linewidth=3, colors='b')
ax[2].contour(img_atlas, linewidth=1, colors='w')
p_fig = os.path.join(path_out, name + '.png')
fig.savefig(p_fig, bbox_inches='tight')
plt.close(fig)
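The corner assignment segm_select[0, :2] = [-1, 1] pins the value range so the diverging colormap stays consistent even when one class is absent; below is a hedged sketch of the same panel using vmin/vmax instead (random data, hypothetical output name). As a side note, Axes.contour documents the keyword as linewidths (plural), so linewidth= in the calls above may not take effect.

import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

segm_select = np.random.choice([-1, 0, 1], size=(64, 64))
fig, ax = plt.subplots()
ax.imshow(segm_select, cmap=plt.cm.RdYlGn, vmin=-1, vmax=1)   # fixed colour range
ax.axis('off')
fig.savefig('selected_vs_not.png', bbox_inches='tight')       # placeholder path
plt.close(fig)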
@@ -224,7 +233,7 @@ def export_fig_atlas(img_atlas, path_out, name, max_label=None):
plt.close(fig)
def perform_reconstruction_mproc(dict_params, name_csv, df_encode, img_atlas,
def perform_reconstruction_mproc(dict_params, name_csv, name_atlas, df_encode, img_atlas,
nb_jobs=NB_THREADS, b_visu=False):
""" perform the reconstruction in multi process mode
@@ -236,6 +245,7 @@ def perform_reconstruction_mproc(dict_params, name_csv, df_encode, img_atlas,
:param b_visu: bool
:return: DF
"""
export_fig_atlas(img_atlas, dict_params['path_exp'], name_atlas)
path_out = os.path.join(dict_params['path_exp'],
name_csv.replace(PREFIX_ENCODE, PREFIX_RECONST))
if b_visu:
@@ -324,12 +334,15 @@ def process_experiment(path_expt, nb_jobs=NB_THREADS):
for path_csv in list_csv:
name_csv = os.path.basename(path_csv)
name_atlas = find_relevant_atlas(name_csv, atlas_names)
if name_atlas is None:
continue
logging.info('Atlas: "%s" -> Encoding: "%s"', name_atlas, name_csv)
# load the atlas
path_atlas = os.path.join(path_expt, name_atlas)
img_atlas = load_atlas_image(path_atlas)
df_encode = pd.DataFrame.from_csv(path_csv)
df_diff = perform_reconstruction_mproc(dict_params, name_csv.replace('.csv', ''),
os.path.splitext(name_atlas)[0],
df_encode, img_atlas, nb_jobs, VISUAL)
df_diffs_all = pd.concat([df_diffs_all, df_diff], axis=1)
df_diffs_all.to_csv(os.path.join(path_expt, CSV_RECONT_DIFF))
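Side note on the loading call above: DataFrame.from_csv was later deprecated and removed from pandas; under newer pandas the equivalent (assuming the first CSV column is the index, as from_csv defaulted to) would be:

import pandas as pd

df_encode = pd.read_csv('encoding.csv', index_col=0)   # 'encoding.csv' stands in for path_csv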
@@ -119,9 +119,9 @@ def experiment_pipeline_alpe_showcase(path_out):
init_atlas_msc = ptn_dict.initialise_atlas_mosaic(atlas.shape, np.max(atlas))
init_encode_rnd = ptn_weigth.initialise_weights_random(len(imgs), np.max(atlas))
atlas, w_bins = dl.alpe_pipe_atlas_learning_ptn_weights(imgs,
out_prefix='mosaic', init_atlas=init_atlas_msc,
max_iter=9, out_dir=path_out)
atlas, w_bins = dl.apdl_pipe_atlas_learning_ptn_weights(imgs,
out_prefix='mosaic', init_atlas=init_atlas_msc,
max_iter=9, out_dir=path_out)
return atlas, w_bins
@@ -130,7 +130,7 @@ class ExperimentAPDL_base(expt_apd.ExperimentAPD):
the main real experiment or our Atlas Learning Pattern Encoding
"""
def _init_atlas(self, nb_lbs, init_tp):
def _init_atlas(self, nb_labels, init_tp):
""" init atlas according an param
:param nb_lbs: int
@@ -141,13 +141,13 @@ class ExperimentAPDL_base(expt_apd.ExperimentAPD):
assert init_tp in DICT_ATLAS_INIT
fn_init_atlas = DICT_ATLAS_INIT[init_tp]
if fn_init_atlas is not None:
init_atlas = fn_init_atlas(im_size, nb_lbs)
init_atlas = fn_init_atlas(im_size, nb_labels)
elif init_tp == 'GT':
assert hasattr(self, 'gt_atlas')
init_atlas = np.remainder(self.gt_atlas, nb_lbs)
init_atlas = np.remainder(self.gt_atlas, nb_labels)
elif init_tp == 'GTd':
assert hasattr(self, 'gt_atlas')
init_atlas = np.remainder(self.gt_atlas, nb_lbs)
init_atlas = np.remainder(self.gt_atlas, nb_labels)
init_atlas = ptn_dict.initialise_atlas_deform_original(init_atlas)
return init_atlas
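A tiny illustration of the 'GT' branch (made-up label values): np.remainder folds the ground-truth atlas labels into the requested number of labels.

import numpy as np

gt_atlas = np.array([[0, 1, 2, 3],
                     [4, 5, 6, 7]])
nb_labels = 4
init_atlas = np.remainder(gt_atlas, nb_labels)
# array([[0, 1, 2, 3],
#        [0, 1, 2, 3]])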
@@ -168,13 +168,15 @@ class ExperimentAPDL_base(expt_apd.ExperimentAPD):
if isinstance(self.params['nb_samples'], float):
self.params['nb_samples'] = int(len(self.imgs) * self.params['nb_samples'])
try:
atlas, w_bins = dl.alpe_pipe_atlas_learning_ptn_weights(
atlas, w_bins = dl.apdl_pipe_atlas_learning_ptn_weights(
self.imgs[:self.params['nb_samples']],
init_atlas=init_atlas,
tol=self.params['tol'],
gc_reinit=self.params['gc_reinit'],
gc_coef=self.params['gc_regul'],
max_iter=self.params['max_iter'],
ptn_split=self.params['ptn_split'],
ptn_compact=self.params['ptn_compact'],
overlap_major=self.params['overlap_mj'],
out_dir=path_out) # , out_prefix=prefix
except:
@@ -244,12 +246,14 @@ def experiments_synthetic(params=SYNTH_PARAMS):
logging.info('PARAMS: \n%s', '\n'.join(['"{}": \n\t {}'.format(k, v)
for k, v in arg_params.iteritems()]))
params.update(arg_params)
params.update({'max_iter': 25})
l_params = [params]
if isinstance(params['dataset'], list):
l_params = expt_apd.extend_list_params(l_params, 'dataset', params['dataset'])
l_params = expt_apd.extend_list_params(l_params, 'init_tp', INIT_TYPES)
l_params = expt_apd.extend_list_params(l_params, 'ptn_split', [True, False])
l_params = expt_apd.extend_list_params(l_params, 'ptn_compact', [True, False])
l_params = expt_apd.extend_list_params(l_params, 'gc_regul', GRAPHCUT_REGUL)
ptn_range = SYNTH_PTN_RANGE[os.path.basename(params['path_in'])]
l_params = expt_apd.extend_list_params(l_params, 'nb_labels', ptn_range)
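A hedged stand-in for the grid expansion driven by extend_list_params (the real helper lives in expt_apd; this sketch only assumes it copies every existing configuration once per listed value):

def extend_list_params(list_params, name, values):
    # assumed behaviour: cartesian expansion of existing configs over the values
    extended = []
    for params in list_params:
        for val in values:
            new_params = dict(params)
            new_params[name] = val
            extended.append(new_params)
    return extended

l_params = [{'gc_regul': 0.}]
l_params = extend_list_params(l_params, 'ptn_split', [True, False])
l_params = extend_list_params(l_params, 'ptn_compact', [True, False])
# 2 x 2 = 4 parameter combinations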
@@ -16,8 +16,10 @@ import time
import logging
import traceback
import multiprocessing as mproc
from functools import partial
# to suppress all visu, has to be on the beginning
import tqdm
import matplotlib
matplotlib.use('Agg')
import numpy as np
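The comment about placing this "on the beginning" matters because, at least with the matplotlib versions of that era, the backend has to be selected before pyplot is imported anywhere in the process; a minimal sketch of the required order:

import matplotlib
matplotlib.use('Agg')             # choose the non-interactive backend first
import matplotlib.pyplot as plt   # only then import pyplot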
@@ -311,16 +313,16 @@ def export_debug_images(p_out, n_img, d_debug):
io.imsave(os.path.join(p_debug, 'im_uc_{}.png'.format(i)), im_uc)
def segment_image(params, dict_paths, p_im):
def segment_image(path_img, params, dict_paths):
""" segment individual image
:param params: {str: ...}
:param p_im: str
:param path_img: str
:param dict_paths: {str: ...}
:param visu: bool
"""
n_img = os.path.basename(p_im)
img_raw = io.imread(p_im)
n_img = os.path.basename(path_img)
img_raw = io.imread(path_img)
img = preprocessing_image(img_raw)
io.imsave(os.path.join(dict_paths['path_norm'], n_img), img)
@@ -350,16 +352,6 @@ def segment_image(params, dict_paths, p_im):
visual_pipeline(dict_paths, img_raw, img, seg_raw, seg, n_img)
def wrapper_segment_image(mp_set):
if b_debug:
segment_image(*mp_set)
else:
try:
logging.debug('run segment. in try-catch mode...')
segment_image(*mp_set)
except: pass
def segment_image_folder(params=DEFAULT_PARAMS, dict_paths=DEFAULT_PATHS,
im_pattern='*.png', nb_jobs=1):
""" segment complete image folder
@@ -372,19 +364,24 @@ def segment_image_folder(params=DEFAULT_PARAMS, dict_paths=DEFAULT_PATHS,
check_create_dirs(dict_paths, ['path_in'], ['path_norm', 'path_result'])
with open(os.path.join(dict_paths['path_result'], 'config.txt'), 'w') as f:
f.write(tl_expt.string_dict(params))
p_imgs = sorted(glob.glob(os.path.join(dict_paths['path_in'], im_pattern)))
logger.info('found %i images', len(p_imgs))
paths_img = sorted(glob.glob(os.path.join(dict_paths['path_in'], im_pattern)))
logger.info('found %i images', len(paths_img))
tqdm_bar = tqdm.tqdm(total=len(paths_img))
wrapper_segment_image = partial(segment_image,
params=params, dict_paths=dict_paths)
# TODO: use functools.partial and tqdm bar
mp_set = ((params, dict_paths, p_im) for p_im in p_imgs)
if nb_jobs > 1:
logger.debug('perform_sequence in %i threads', nb_jobs)
logger.debug('perform_sequence in %i threads', nb_jobs)
mproc_pool = mproc.Pool(nb_jobs)
mproc_pool.map(wrapper_segment_image, mp_set)
for x in mproc_pool.map(wrapper_segment_image, paths_img):
tqdm_bar.update()
mproc_pool.close()
mproc_pool.join()
else:
map(wrapper_segment_image, mp_set)
for path_img in paths_img:
segment_image(path_img, params, dict_paths)
tqdm_bar.update()
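A hedged sketch of the dispatch pattern introduced here: functools.partial binds the fixed arguments, and imap_unordered (rather than map, which returns only after all jobs finish) lets the tqdm bar advance per image. The segment_image below is a stand-in stub, not the module's implementation.

import multiprocessing as mproc
from functools import partial
import tqdm

def segment_image(path_img, params, dict_paths):
    pass   # stand-in stub for the real per-image segmentation

def segment_folder(paths_img, params, dict_paths, nb_jobs=1):
    wrapper = partial(segment_image, params=params, dict_paths=dict_paths)
    if nb_jobs > 1:
        pool = mproc.Pool(nb_jobs)
        for _ in tqdm.tqdm(pool.imap_unordered(wrapper, paths_img),
                           total=len(paths_img)):
            pass
        pool.close()
        pool.join()
    else:
        for path_img in tqdm.tqdm(paths_img):
            wrapper(path_img)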
def mproc_visual_pair_orig_segm(mp_set):