Commit 90235a72 authored by Jiri Borovec's avatar Jiri Borovec

fixing: tune experiment parameters, repair reconstruction visualisation, add single-thread fallback and debug mode

parent 4637da03
......@@ -49,13 +49,13 @@ else:
PATH_OUTPUT = '/home/jirka/TEMP'
NB_THREADS = int(mproc.cpu_count() * .9)
PATH_RESULTS = os.path.join(PATH_OUTPUT, 'experiments_APD')
PATH_RESULTS = os.path.join(PATH_OUTPUT, 'experiments_APD_temp')
DEFAULT_PARAMS = {
'computer': os.uname(),
'init_tp': 'msc', # msc. rnd
'max_iter': 39,
'gc_regul': 0.,
'nb_lbs': 7,
'gc_regul': 1e-2,
'nb_lbs': 12,
'nb_runs': 25, # 500
'gc_reinit': True,
'ptn_split': True,
......@@ -95,7 +95,7 @@ REAL_PARAMS.update({'data_type': 'real',
'dataset': REAL_DATASET_NAME,
'sub_dataset': REAL_SUB_DATASETS[0],
'path_out': PATH_RESULTS,
'max_iter': 150,
'max_iter': 50,
'nb_runs': 10})
# PATH_OUTPUT = os.path.join('..','..','results')
......
......@@ -407,6 +407,7 @@ def dataset_load_images(name=DEFAULT_DATASET, path_base=DEFAULT_PATH_APD,
nb_load_blocks = len(path_imgs) / BLOCK_NB_LOAD_IMAGES
logger.debug('estimated %i loading blocks', nb_load_blocks)
list_path_im = [path_imgs[i::nb_load_blocks] for i in range(nb_load_blocks)]
mproc_pool = mproc.Pool(nb_jobs)
list_names_imgs = mproc_pool.map(mproc_load_images, list_path_im)
mproc_pool.close()
def load_image(path_img):
    """ load a single image and normalise its intensities into [0, 1]

    :param path_img: str, path to the image file
    :return: (str, np.array<height, width>), image name (file stem) and
        the intensity-normalised image as a float array
    """
    n_img = os.path.splitext(os.path.basename(path_img))[0]
    # img = io.imread(path_img)
    img = np.array(Image.open(path_img))
    # NOTE: keep the division OUT-of-place — images load as integer arrays,
    # and in-place `img /= ...` cannot cast the float result back into an
    # int array (NumPy raises a casting error under the 'same_kind' rule)
    img = img / float(img.max())
    return n_img, img
......@@ -479,7 +480,7 @@ def dataset_export_images(p_out, imgs, names=None, nb_jobs=1):
if names is None:
names = range(len(imgs))
mp_set = [(p_out, im, names[i]) for i, im in enumerate(imgs)]
mp_set = [(p_out, im, names[i]) for i, im in enumerate(sorted(imgs))]
if nb_jobs > 1:
logger.debug('running in %i threads...', nb_jobs)
mproc_pool = mproc.Pool(nb_jobs)
......@@ -489,12 +490,12 @@ def dataset_export_images(p_out, imgs, names=None, nb_jobs=1):
else:
logger.debug('running in single thread...')
map(mproc_wrapper, mp_set)
try:
path_npz = os.path.join(p_out, 'input_images.npz')
np.savez(open(path_npz, 'w'), imgs)
except:
logger.error(traceback.format_exc())
os.remove(path_npz)
# try:
# path_npz = os.path.join(p_out, 'input_images.npz')
# np.savez(open(path_npz, 'w'), imgs)
# except:
# logger.error(traceback.format_exc())
# os.remove(path_npz)
def mproc_wrapper(mp_set):
......@@ -512,7 +513,7 @@ def dataset_convert_nifti(path_in, path_out, posix=DEFAULT_IM_POSIX):
logger.info('convert a dataset to Nifti')
p_imgs = glob.glob(os.path.join(path_in, '*' + posix))
create_clean_folder(path_out)
# p_imgs = sorted(p_imgs)
p_imgs = sorted(p_imgs)
for path_im in p_imgs:
name = os.path.splitext(os.path.basename(path_im))[0]
path_out = os.path.join(path_out, name)
......
......@@ -73,6 +73,7 @@ def reconstruct_samples(atlas, w_bins):
imgs = [None] * w_bins.shape[0]
for i, w in enumerate(w_bin_ext):
imgs[i] = np.asarray(w)[np.asarray(atlas)]
assert atlas.shape == imgs[i].shape
return imgs
......
......@@ -15,39 +15,52 @@ import multiprocessing as mproc
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
# import pattern_weights as ptn_weight
logger = logging.getLogger(__name__)
b_debug = False
VISUAL = True
PATH_BASE = '/datagrid/Medical/microscopy/drosophila/'
PATH_EXPERIMENTS = os.path.join(PATH_BASE, 'TEMPORARY', 'experiments_APD_new')
NAME_EXPERIMENT = 'ExperimentALPE_mp_real_type_1_segm_reg_binary_gene_ssmall'
SAMPLE_PATH_EXPERIMENT = os.path.join(PATH_EXPERIMENTS, NAME_EXPERIMENT)
if b_debug:
NB_THREADS = 1
PATH_EXPERIMENTS = os.path.join(PATH_BASE, 'TEMPORARY', 'experiments_APD_temp')
SAMPLE_PATH_EXPERIMENT = os.path.join(PATH_EXPERIMENTS,
'ExperimentALPE_mp_real_type_3_segm_reg_binary_gene_ssmall_20160509-045213')
else:
NB_THREADS = int(mproc.cpu_count() * .8)
PATH_EXPERIMENTS = os.path.join(PATH_BASE, 'TEMPORARY', 'experiments_APD_new')
SAMPLE_PATH_EXPERIMENT = ''
NAME_CONFIG = 'config.json'
PREFIX_ATLAS = 'atlas_'
PREFIX_ENCODE = 'encoding_'
PREFIX_RECONST = 'reconstruct_'
NB_THREADS = int(mproc.cpu_count() * .8)
VISUAL = True
def draw_reconstruction(path_out, name, segm_orig, segm_reconst, img_atlas):
    """ visualise the reconstruction next to the original segmentation and
    show which atlas regions were selected vs. left out, then export the
    figure as a PNG into the given output directory

    :param path_out: str, output directory for the exported figure
    :param name: str, figure file name (without extension)
    :param segm_orig: np.array<height, width>, original segmentation
    :param segm_reconst: np.array<height, width>, reconstructed label image
    :param img_atlas: np.array<height, width>, atlas label image
    """
    fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 8))
    ax[0].set_title('original segmentation')
    ax[0].imshow(1 - segm_orig, cmap='Greys'), ax[0].axis('off')
    # `linewidths` (plural) is the documented Axes.contour keyword
    ax[0].contour(segm_reconst > 0, linewidths=2)
    ax[0].contour(segm_reconst, linewidths=2, cmap=plt.cm.jet)
    ax[1].set_title('reconstructed segmentation')
    ax[1].imshow(segm_reconst), ax[1].axis('off')
    ax[1].contour(segm_orig, linewidths=2)
    # +1 for selected regions, -1 for atlas regions that were not selected,
    # 0 for background; builtin `int` replaces the removed np.int alias
    segm_select = np.array(segm_reconst > 0, dtype=int)
    segm_select[np.logical_and(img_atlas > 0, segm_reconst == 0)] = -1
    ax[2].set_title('selected vs not selected segm.')
    ax[2].imshow(segm_select, cmap=plt.cm.RdYlGn), ax[2].axis('off')
    ax[2].contour(segm_orig, linewidths=2, cmap=plt.cm.Blues)
    p_fig = os.path.join(path_out, name + '.png')
    fig.savefig(p_fig, bbox_inches='tight')
    plt.close(fig)
......@@ -65,14 +78,14 @@ def compute_reconstruction(dict_params, path_out, img_atlas, weights, img_name):
"""
segm_orig = load_segmentation(dict_params, img_name)
segm_rect = np.zeros_like(img_atlas)
for i, v in enumerate(weights):
for i, w in enumerate(weights):
lb = i + 1
if v == 1:
if w == 1:
segm_rect[img_atlas == lb] = lb
logger.debug('segm unique: %s', repr(np.unique(segm_rect)))
if VISUAL:
try:
draw_reconstruction(path_out, img_name, segm_orig, segm_rect)
draw_reconstruction(path_out, img_name, segm_orig, segm_rect, img_atlas)
except:
logger.error(traceback.format_exc())
segm_bin = (segm_rect >= 1)
......@@ -95,17 +108,15 @@ def find_relevant_atlas(name_csv, list_names_atlas):
return list_names_atlas[idx]
def load_atlas_image(path_expt, name_atlas):
def load_atlas_image(path_atlas):
""" load the atlas as norm labels to be small natural ints
:param path_expt: str
:param name_atlas: str
:param path_atlas: str
:return: np.array<height, width>
"""
path_atlas = os.path.join(path_expt, name_atlas)
img_atlas = np.array(Image.open(path_atlas))
# norm image to have labels as [0, 1, 2, ...]
img_atlas /= (img_atlas.max() / len(np.unique(img_atlas)))
img_atlas /= (img_atlas.max() / (len(np.unique(img_atlas)) + 1))
# subtract background (set it as -1)
# img_atlas -= 1
logger.debug('Atlas: %s with labels: %s', repr(img_atlas.shape),
......@@ -150,8 +161,16 @@ def perform_reconstruction(dict_params, df_encode, img_atlas):
def mproc_wrapper(mp_tuple):
    """ pool.map helper: unpack one task tuple (built from the encoding
    read out of the input csv) and run the reconstruction on it """
    task_args = tuple(mp_tuple)
    return compute_reconstruction(*task_args)
# def mproc_wrapper(mp_tuple):
# # recompute encoding and then does the reconstruction
# dict_params, path_out, img_atlas, weights, img_name = mp_tuple
# segm_orig = load_segmentation(dict_params, img_name)
# weights = ptn_weight.weights_image_atlas_overlap_major(segm_orig, img_atlas)
# return compute_reconstruction(dict_params, path_out, img_atlas, weights, img_name)
def perform_reconstruction_mproc(dict_params, name_csv, df_encode, img_atlas):
path_out = os.path.join(dict_params['path_exp'],
......@@ -160,16 +179,19 @@ def perform_reconstruction_mproc(dict_params, name_csv, df_encode, img_atlas):
os.mkdir(path_out)
list_patterns = [col for col in df_encode.columns if col.startswith('ptn ')]
logger.debug('list of pattern names: %s', repr(list_patterns))
list_idxs = [int(col[3:]) for col in list_patterns]
assert list_idxs == sorted(list_idxs)
mp_tuples = ((dict_params, path_out, img_atlas, row[list_patterns].values, idx)
for idx, row in df_encode.iterrows())
mproc_pool = mproc.Pool(NB_THREADS)
results = mproc_pool.map(mproc_wrapper, mp_tuples)
mproc_pool.close()
mproc_pool.join()
# results = map(mproc_wrapper, mp_tuples)
if NB_THREADS > 1:
mproc_pool = mproc.Pool(NB_THREADS)
results = mproc_pool.map(mproc_wrapper, mp_tuples)
mproc_pool.close()
mproc_pool.join()
else:
results = map(mproc_wrapper, mp_tuples)
df_diff = pd.DataFrame(results, columns=['image', name_csv])
df_diff = df_diff.set_index('image')
......@@ -194,8 +216,9 @@ def process_experiment(path_expt=SAMPLE_PATH_EXPERIMENT):
name_csv = os.path.basename(path_csv)
name_atlas = find_relevant_atlas(name_csv, atlas_names)
logger.info('Atlas: "%s" -> Encoding: "%s"', name_atlas, name_csv)
# load the tlas
img_atlas = load_atlas_image(path_expt, name_atlas)
# load the atlas
path_atlas = os.path.join(path_expt, name_atlas)
img_atlas = load_atlas_image(path_atlas)
df_encode = pd.DataFrame.from_csv(path_csv)
df_diff = perform_reconstruction_mproc(dict_params, name_csv.replace('.csv', ''),
df_encode, img_atlas)
......@@ -217,7 +240,7 @@ def main(path_base=PATH_EXPERIMENTS):
list_expt = [p for p in glob.glob(os.path.join(path_base, '*'))
if os.path.isdir(p)]
for i, path_expt in enumerate(list_expt):
logger.info('processing experiment %i / %', i + 1, len(list_expt))
logger.info('processing experiment %i / %i', i + 1, len(list_expt))
process_experiment(path_expt)
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    logger.info('running...')
    # in debug mode process only the single sample experiment,
    # otherwise sweep all experiments found under PATH_EXPERIMENTS
    if b_debug:
        process_experiment()
    else:
        main()
    logger.info('DONE')
......@@ -238,7 +238,7 @@ def experiments_real(dict_data, dict_params=REAL_PARAMS, nb_jobs=NB_THREADS):
l_params = experiment_apd.extend_l_params(l_params, 'init_tp', ['msc', 'rnd'])
l_params = experiment_apd.extend_l_params(l_params, 'ptn_split', [True, False])
l_params = experiment_apd.extend_l_params(l_params, 'gc_regul',
[0., 1e-3, 1e-2, 1e-1, 5e-1])
[0., 1e-3, 1e-2, 1e-1])
# l_params = experiment_apd.extend_l_params(l_params, 'nb_lbs',
# [5, 9, 12, 15, 20, 25, 30, 40])
logger.debug('list params: %i', len(l_params))
......@@ -250,7 +250,7 @@ def experiments_real(dict_data, dict_params=REAL_PARAMS, nb_jobs=NB_THREADS):
exp = ExperimentALPE(params)
# exp.run(gt=False, iter_var='case', iter_values=range(params['nb_runs']))
exp.run(gt=False, iter_var='nb_lbs',
iter_vals=[5, 9, 12, 15, 20, 25, 30, 40, 50])
iter_vals=[9, 12, 15, 20, 25, 30])
def main_real(nb_jobs=NB_THREADS):
......@@ -278,9 +278,9 @@ if __name__ == "__main__":
# test_atlasLearning(atlas, imgs, encoding)
# experiments_test()
# experiments_synthetic()
experiments_synthetic()
main_real(nb_jobs=NB_THREADS)
# main_real(nb_jobs=NB_THREADS)
logger.info('DONE')
# plt.show()
......@@ -170,8 +170,8 @@ def visual_pipeline(d_paths, img, im_norm, seg_raw, seg, n_img):
ax[1, 1].set_title('raw segmentation')
ax[1, 1].imshow(seg_raw), ax[1, 1].axis('off')
# ax[0].imshow(seg_sml), plt.axis('off')
p_fig = os.path.join(d_paths['p_visu'], PREFIX_VISU_PIPELINE + n_img)
fig.savefig(p_fig, bbox_inches='tight')
path_fig = os.path.join(d_paths['p_visu'], PREFIX_VISU_PIPELINE + n_img)
fig.savefig(path_fig, bbox_inches='tight')
plt.close(fig)
......
......@@ -20,7 +20,7 @@ import generate_dataset as gen_data
logger = logging.getLogger(__name__)
# DEFAULT_PATH_DATA = '/jirka/jirka/TEMP/APD_real_data'
# PATH_DATA = '/jirka/jirka/TEMP/APD_real_data'
DEFAULT_PATH_DATA = '/datagrid/Medical/microscopy/drosophila/'
# REAL_DATASET_NAME = '1000_ims'
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment