Commit f62b4de8 authored by Jiri Borovec's avatar Jiri Borovec

update, according paper ACCV

parent 6f84e4cb
......@@ -51,16 +51,16 @@ PATH_RESULTS = '/datagrid/Medical/microscopy/drosophila/TEMPORARY/experiments_AP
DEFAULT_PARAMS = {
'computer': os.uname(),
'nb_samples': None,
'tol': 1e-2,
'init_tp': 'msc2', # msc. rnd
'max_iter': 250,
'tol': 1e-3,
'init_tp': 'msc', # msc. rnd
'max_iter': 25, # 250
'gc_regul': 0.,
'nb_labels': 2,
'nb_runs': NB_THREADS, # 500
'gc_reinit': True,
'ptn_split': False,
'ptn_compact': False,
'overlap_mj': False,
'ptn_compact': True,
'overlap_mj': True,
}
SYNTH_DATASET_NAME = 'atomicPatternDictionary_v0'
......
......@@ -127,7 +127,8 @@ def prototype_new_pattern(imgs, imgs_reconst, diffs, atlas,
# if ptn_size < 0.01:
# logging.debug('new patterns was too small %f', ptn_size)
# ptn = data.extract_image_largest_element(im_diff)
img_ptn = np.logical_and(img_ptn == True, atlas == 0)
img_ptn = np.logical_and(img_ptn == True)
# img_ptn = np.logical_and(img_ptn == True, atlas == 0)
return img_ptn
......
......@@ -52,7 +52,7 @@ def weights_image_atlas_overlap_major(img, atlas):
:return: [int] * nb_lbs of values {0, 1}
"""
# logging.debug('weights input image according given atlas')
weights = weights_image_atlas_overlap_threshold(img, atlas, 0.5)
weights = weights_image_atlas_overlap_threshold(img, atlas, 1.)
return weights
......@@ -64,11 +64,12 @@ def weights_image_atlas_overlap_partial(img, atlas):
"""
# logging.debug('weights input image according given atlas')
labels = np.unique(atlas).tolist()
weights = weights_image_atlas_overlap_threshold(img, atlas, (1. / len(labels)))
weights = weights_image_atlas_overlap_threshold(img, atlas,
(1. / np.max(labels)))
return weights
def weights_image_atlas_overlap_threshold(img, atlas, threshold=0.5):
def weights_image_atlas_overlap_threshold(img, atlas, threshold=1.):
""" estimate what patterns are activated with given atlas and input image
compute overlap matrix and eval nr of overlapping and non pixels and threshold
......@@ -85,9 +86,11 @@ def weights_image_atlas_overlap_threshold(img, atlas, threshold=0.5):
labels.remove(0)
weight = [0] * np.max(atlas)
for lb in labels:
equal = np.sum(img[atlas == lb])
total = np.sum(atlas == lb)
score = equal / float(total)
nequal = np.sum(abs(1 - img[atlas == lb]))
score = total / float(nequal) - 1.
# equal = np.sum(img[atlas == lb])
# score = equal / float(total)
if score >= threshold:
weight[lb - 1] = 1
return weight
......
......@@ -6,6 +6,13 @@ the reconstruction error to evaluate he parameters and export visualisation
EXAMPLE:
>> python run_apd_reconstruction.py \
--path_in /datagrid/Medical/microscopy/drosophila/TEMPORARY/experiments_APDL_real
>> python run_apd_reconstruction.py \
--path_in /datagrid/Medical/microscopy/drosophila/TEMPORARY/experiments_APD_real
>> python run_apd_reconstruction.py \
--path_in /datagrid/Medical/microscopy/drosophila/RESULTS/experiments_APD_real
>> python run_apd_reconstruction.py \
--path_in /datagrid/Medical/microscopy/drosophila/TEMPORARY/experiments_APD_temp \
--names_expt ExperimentALPE_mp_real_type_3_segm_reg_binary_gene_ssmall_20160509-155333 \
......@@ -33,12 +40,13 @@ import matplotlib.pyplot as plt
import pattern_weights as ptn_weight
import dataset_utils as gen_data
VISUAL = False
VISUAL = True
NB_THREADS = int(mproc.cpu_count() * .9)
# PATH_EXPERIMENTS = '/datagrid/Medical/microscopy/drosophila/RESULTS/' \
# 'experiments_APD_real'
PATH_EXPERIMENTS = '/datagrid/Medical/microscopy/drosophila/TEMPORARY/' \
PATH_EXPERIMENTS = '/datagrid/Medical/microscopy/drosophila/RESULTS/' \
'experiments_APD_real'
# PATH_EXPERIMENTS = '/datagrid/Medical/microscopy/drosophila/TEMPORARY/' \
# 'experiments_APD_real'
PATH_IMAGES_RGB = '/datagrid/Medical/microscopy/drosophila/TEMPORARY'
DICT_PARAMS = {
'path_in': PATH_EXPERIMENTS,
......@@ -94,7 +102,8 @@ def parse_arg_params(parser):
return args
def export_fig_reconstruction(path_out, name, segm_orig, segm_reconst, img_atlas):
def export_fig_reconstruction(path_out, name, segm_orig, segm_reconst, img_atlas,
img_rgb=None):
""" visualise reconstruction together with the original segmentation
:param path_out: str
......@@ -104,26 +113,41 @@ def export_fig_reconstruction(path_out, name, segm_orig, segm_reconst, img_atlas
:param img_atlas: np.array<height, width>
"""
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(20, 8))
ax[0].set_title('original segmentation')
ax[0].imshow(1 - segm_orig, cmap='Greys', alpha=0.9)
ax[0].imshow(segm_reconst, alpha=0.1)
ax[0].set_title('original')
if img_rgb is not None:
ax[0].imshow(img_rgb, alpha=0.9)
else:
ax[0].imshow(1 - segm_orig, cmap='Greys', alpha=0.7)
ax[0].imshow(img_atlas, alpha=0.1)
ax[0].axis('off')
ax[0].contour(segm_reconst > 0, linewidth=2)
ax[0].contour(segm_reconst, linewidth=2, cmap=plt.cm.jet)
ax[0].contour(img_atlas > 0, linewidth=2)
ax[0].contour(img_atlas, levels=np.unique(img_atlas),
linewidth=2, cmap=plt.cm.jet)
ax[1].set_title('reconstructed segmentation')
ax[1].imshow(segm_reconst)
ax[1].axis('off')
ax[1].contour(segm_orig, linewidth=2, colors='w')
lut = plt.cm.get_cmap('jet', segm_reconst.max())(range(segm_reconst.max() + 1))
im = lut[segm_reconst]
im[segm_reconst == 0, :] = (1., 1., 1., 1.)
ax[1].imshow(im, alpha=0.8)
ax[1].axes.get_xaxis().set_ticklabels([])
ax[1].axes.get_yaxis().set_ticklabels([])
# ax[1].axis('off')
ax[1].contour(segm_orig, levels=np.unique(segm_orig), linewidth=2, colors='k')
segm_select = np.array(segm_reconst > 0, dtype=np.int)
segm_select[np.logical_and(img_atlas > 0, segm_reconst == 0)] = -1
segm_select[0, :2] = [-1, 1]
ax[2].set_title('selected vs not selected segm.')
ax[2].imshow(segm_select, cmap=plt.cm.RdYlGn)
ax[2].axis('off')
ax[2].contour(segm_orig, linewidth=3, colors='b')
ax[2].contour(img_atlas, linewidth=1, colors='w')
# segm_select = np.array(segm_reconst > 0, dtype=np.int)
# segm_select[np.logical_and(img_atlas > 0, segm_reconst == 0)] = -1
# segm_select[0, :2] = [-1, 1]
# ax[2].imshow(segm_select, cmap=plt.cm.RdYlGn)
im = np.ones(segm_reconst.shape + (4,))
im[np.logical_and(img_atlas > 0, segm_reconst > 0), :] = (0, 1, 0, 1)
im[np.logical_and(img_atlas > 0, segm_reconst == 0), :] = (1, 0, 0, 1)
ax[2].imshow(im, alpha=0.7)
ax[2].axes.get_xaxis().set_ticklabels([])
ax[2].axes.get_yaxis().set_ticklabels([])
# ax[2].axis('off')
ax[2].contour(img_atlas, levels=np.unique(img_atlas), linewidth=1, colors='w')
ax[2].contour(segm_orig, levels=np.unique(segm_orig), linewidth=3, colors='k')
p_fig = os.path.join(path_out, name + '.png')
fig.savefig(p_fig, bbox_inches='tight')
......@@ -144,6 +168,7 @@ def compute_reconstruction(img_name, dict_params, path_out, im_atlas,
:return: {str: float}
"""
segm_orig = load_segmentation(dict_params, img_name)
img_rgb = load_image_rgb(dict_params, img_name)
if weights is None:
# recompute encoding and then does the reconstruction
weights = ptn_weight.weights_image_atlas_overlap_major(segm_orig, im_atlas)
......@@ -155,7 +180,8 @@ def compute_reconstruction(img_name, dict_params, path_out, im_atlas,
logging.debug('segm unique: %s', repr(np.unique(segm_rect)))
if b_visu:
try:
export_fig_reconstruction(path_out, img_name, segm_orig, segm_rect, im_atlas)
export_fig_reconstruction(path_out, img_name, segm_orig, segm_rect,
im_atlas)
except:
# logging.error(traceback.format_exc())
logging.warning('drawing fail for "%s"...', img_name)
......@@ -214,6 +240,24 @@ def load_segmentation(dict_params, img_name):
return img
def load_image_rgb(dict_params, img_name):
""" load the segmenattion with values {0, 1}
:param dict_params: {str: values}
:param img_name: str
:return: np.array<height, width>
"""
path_img = os.path.join(dict_params['path_rgb'], img_name + '.png')
if img_name == 'mean':
return None
if os.path.exists(path_img):
img = np.array(Image.open(path_img))
else:
logging.warning('particular RGB image not exists "%s"', path_img)
img = None
return img
def export_fig_atlas(img_atlas, path_out, name, max_label=None):
""" export the atlas to given folder and specific name
......@@ -259,16 +303,22 @@ def perform_reconstruction_mproc(dict_params, name_csv, name_atlas, df_encode, i
assert np.max(img_atlas) <= len(list_patterns)
wrapper_reconstruction = partial(compute_reconstruction, dict_params=dict_params,
path_out=path_out, im_atlas=img_atlas, b_visu=VISUAL)
path_out=path_out, im_atlas=img_atlas, b_visu=b_visu)
results = []
tqdm_bar = tqdm.tqdm(total=len(df_encode))
if nb_jobs > 1:
logging.debug('computing %i samples in %i threads', len(df_encode), nb_jobs)
mproc_pool = mproc.Pool(nb_jobs)
results = mproc_pool.map(wrapper_reconstruction, df_encode.index)
for res in mproc_pool.imap_unordered(wrapper_reconstruction, df_encode.index):
results.append(res)
tqdm_bar.update()
mproc_pool.close()
mproc_pool.join()
else:
results = map(wrapper_reconstruction, df_encode.index)
for res in map(wrapper_reconstruction, df_encode.index):
results.append(res)
tqdm_bar.update()
df_diff = pd.DataFrame(results, columns=['image', name_csv])
df_diff.set_index('image', inplace=True)
......@@ -319,24 +369,28 @@ def recompute_encoding(config, atlas):
return df
def process_experiment(path_expt, nb_jobs=NB_THREADS):
def process_experiment(path_expt, nb_jobs=NB_THREADS, path_img_rgb=PATH_IMAGES_RGB):
""" process complete folder with experiment
:param path_expt: str
"""
logging.info('Experiment folder: \n "%s"', path_expt)
dict_params = load_config_json(path_expt)
dir_im = os.path.basename(dict_params['path_in'])
dict_params['path_rgb'] = os.path.join(path_img_rgb,
dir_im.replace('_segm_reg_binary', '_RGB_reg'))
atlas_names = [os.path.basename(p) for p
in glob.glob(os.path.join(path_expt, PREFIX_ATLAS + '*.png'))]
list_csv = [p for p in glob.glob(os.path.join(path_expt, PREFIX_ENCODE + '*.csv'))
if not p.endswith(POSIX_CSV_NEW)]
df_diffs_all = pd.DataFrame()
for path_csv in list_csv:
for i, path_csv in enumerate(list_csv):
name_csv = os.path.basename(path_csv)
name_atlas = find_relevant_atlas(name_csv, atlas_names)
logging.info('# %i / %i for Atlas: "%s" -> Encoding: "%s"',
i + 1, len(list_csv), name_atlas, name_csv)
if name_atlas is None:
continue
logging.info('Atlas: "%s" -> Encoding: "%s"', name_atlas, name_csv)
# load the atlas
path_atlas = os.path.join(path_expt, name_atlas)
img_atlas = load_atlas_image(path_atlas)
......@@ -367,11 +421,12 @@ def main():
list_expt = [os.path.join(arg_params['path_in'], n) for n in arg_params['names_expt']]
assert len(list_expt) > 0, 'No experiments found!'
tqdm_bar = tqdm.tqdm(total=len(list_expt))
for path_expt in list_expt:
# tqdm_bar = tqdm.tqdm(total=len(list_expt))
for i, path_expt in enumerate(list_expt):
logging.info('experiment %i / %i -> %s', i + 1, len(list_expt), path_expt)
process_experiment(path_expt, arg_params['nb_jobs'])
gc.collect(), time.sleep(1)
tqdm_bar.update(1)
# tqdm_bar.update(1)
logging.info('DONE')
......
......@@ -6,7 +6,12 @@ Example run:
>> python run_experiment_apd_all.py \
-in /datagrid/Medical/microscopy/drosophila/synthetic_data/atomicPatternDictionary_v1 \
-out /datagrid/Medical/microscopy/drosophila/TEMPORARY/experiments_APD-sta
-out /datagrid/Medical/microscopy/drosophila/TEMPORARY/experiments_APD
>> python run_experiment_apd_all.py \
-in /datagrid/Medical/microscopy/drosophila/synthetic_data/atomicPatternDictionary_v1 \
-out /datagrid/Medical/microscopy/drosophila/TEMPORARY/experiments_APDL_synth2
--method APDL
>> python run_experiment_apd_all.py --type real \
-in /datagrid/Medical/microscopy/drosophila/TEMPORARY/type_1_segm_reg_binary \
......
......@@ -58,11 +58,11 @@ DICT_ATLAS_INIT = {
}
# SIMPLE RUN
INIT_TYPES = ['GT', 'GTd']
GRAPHCUT_REGUL = [0.0, 1e-15, 1e-3]
INIT_TYPES = ['rnd', 'msc', 'msc2']
# GRAPHCUT_REGUL = [1e-2, 1e-1, 0.]
# COMPLEX RUN
# INIT_TYPES = DICT_ATLAS_INIT.keys()
# GRAPHCUT_REGUL = [0., 0e-12, 1e-9, 1e-6, 1e-3, 1e-1, 1.0]
GRAPHCUT_REGUL = [0., 0e-12, 1e-9, 1e-6, 1e-3, 1e-1]
def test_simple_show_case():
......@@ -290,8 +290,9 @@ def experiments_real(params=REAL_PARAMS):
l_params = [copy.deepcopy(params)]
if isinstance(params['dataset'], list):
l_params = expt_apd.extend_list_params(l_params, 'dataset', params['dataset'])
# l_params = expt_apd.extend_list_params(l_params, 'init_tp', ['msc', 'rnd'])
l_params = expt_apd.extend_list_params(l_params, 'init_tp', INIT_TYPES)
# l_params = expt_apd.extend_list_params(l_params, 'ptn_split', [True, False])
l_params = expt_apd.extend_list_params(l_params, 'ptn_compact', [True, False])
l_params = expt_apd.extend_list_params(l_params, 'gc_regul', GRAPHCUT_REGUL)
# l_params = expt_apd.extend_list_params(l_params, 'nb_labels',
# [5, 9, 12, 15, 20, 25, 30, 40])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment