Commit c5746c43 authored by Jirka's avatar Jirka

fix & expt STA update

parent c251b4f0
......@@ -3,6 +3,7 @@ import numpy as np
import ptn_disctionary as ptn_dict
import ptn_weights as ptn_weight
import similarity_metric as sim_metric
import generate_dataset as gen_data
import matplotlib.pyplot as plt
import logging
logger = logging.getLogger(__name__)
......@@ -225,8 +226,8 @@ def export_visualization_image(img, i, out_dir, prefix='debug', name='',
plt.imshow(img, interpolation='none', aspect=ration)
plt.xlabel(labels[0])
plt.ylabel(labels[1])
n_fig = 'APDL_{}_{}_iter_{:04d}.png'.format(prefix, name, i)
p_fig = os.path.join(out_dir, n_fig)
n_fig = 'APDL_{}_{}_iter_{:04d}'.format(prefix, name, i)
p_fig = os.path.join(out_dir, n_fig + '.png')
logger.debug('.. export Vusialization as "{}...{}"'.format(p_fig[:19], p_fig[-19:]))
fig.savefig(p_fig, bbox_inches='tight', pad_inches=0.05)
plt.close()
......@@ -247,11 +248,13 @@ def export_visual_atlas(i, out_dir, atlas=None, weights=None, prefix='debug'):
logger.debug('output path "{}" does not exist'.format(out_dir))
return None
if atlas is not None:
export_visualization_image(atlas, i, out_dir, prefix, 'atlas',
labels=['X', 'Y'])
if weights is not None:
export_visualization_image(weights, i, out_dir, prefix, 'weights',
'auto', ['patterns', 'images'])
# export_visualization_image(atlas, i, out_dir, prefix, 'atlas',
# labels=['X', 'Y'])
n_img = 'APDL_{}_atlas_iter_{:04d}'.format(prefix, i)
gen_data.export_image(out_dir, atlas, n_img)
# if weights is not None:
# export_visualization_image(weights, i, out_dir, prefix, 'weights',
# 'auto', ['patterns', 'images'])
return None
......
......@@ -17,13 +17,13 @@ import numpy as np
import pandas as pd
import matplotlib.pylab as plt
import matplotlib.gridspec as gridspec
import multiprocessing as mproc
import generate_dataset as gen_data
import dictionary_learning as dl
import ptn_disctionary as ptn_dict
import ptn_weights as ptn_weigth
import experiments_sta as exp_sta
import logging
import multiprocessing as mproc
import copy_reg
import types
logger = logging.getLogger(__name__)
......@@ -160,7 +160,7 @@ class ExperimentALPE(exp_sta.ExperimentAPD):
init_atlas = ptn_dict.initialise_atlas_random(im_size, nb_lbs)
return init_atlas
def _estimate_atlas(self, i):
def _estimate_atlas(self, v):
""" set all params and run the atlas estimation in try mode
:param i: int, index of try
......@@ -170,7 +170,8 @@ class ExperimentALPE(exp_sta.ExperimentAPD):
p = self.params.copy()
init_atlas = self._init_atlas()
# prefix = 'expt_{}'.format(p['init_tp'])
p_out = os.path.join(p['exp_path'], 'case_{:05d}'.format(i))
name = '{}_{:05d}'.format(self.iter_n_var, v)
p_out = os.path.join(p['exp_path'], name)
if not os.path.exists(p_out):
os.mkdir(p_out)
try:
......@@ -187,71 +188,15 @@ class ExperimentALPE(exp_sta.ExperimentAPD):
img_rct = ptn_dict.reconstruct_samples(atlas, w_bins)
return atlas, img_rct
def _perform(self, n_iter='nb_runs'):
self.l_stat = []
for i in range(self.params.get(n_iter)):
atlas, img_rct = self._estimate_atlas(i)
stat = self._compute_statistic_gt(atlas, img_rct)
stat.update({'idx': i})
self.l_stat.append(stat)
return
def _evaluate(self):
self.df_stat = pd.DataFrame()
for stat in self.l_stat:
self.df_stat = self.df_stat.append(stat, ignore_index=True)
if 'idx' in stat.keys():
self.df_stat = self.df_stat.set_index('idx')
p_csv = os.path.join(self.params.get('exp_path'), self.RESULTS_CSV)
logger.debug('save results: "{}"'.format(p_csv))
self.df_stat.to_csv(p_csv)
return
class ExperimentALPE_mp(ExperimentALPE):
def __init__(self, params, nb_jobs=1):
super(ExperimentALPE_mp, self).__init__(params)
self.nb_jobs = nb_jobs
def _perform_once(self, i):
atlas, img_rct = self._estimate_atlas(i)
def _perform_once(self, v):
atlas, img_rct = self._estimate_atlas(v)
stat = self._compute_statistic_gt(atlas, img_rct)
stat.update({'idx': i})
stat.update({self.iter_n_var: v})
return stat
def _perform(self, n_iter='nb_runs'):
# ISSUE with passing large date to processes so the images are saved
# and loaded in particular process again
# p_imgs = os.path.join(self.params.get('exp_path'), 'input_images.npz')
# np.savez(open(p_imgs, 'w'), imgs=self.imgs)
mproc_pool = mproc.Pool(self.nb_jobs)
self.l_stat = mproc_pool.map(self._perform_once,
range(self.params.get(n_iter)))
mproc_pool.close()
mproc_pool.join()
# remove temporary image file
# os.remove(p_imgs)
return
def extend_l_params(l_params, n_param, l_options):
""" extend the parameter list by all sub-datasets
:param l_params: [{str: ...}]
:param n_param: str
:param l_options: list
:return: [{str: ...}]
"""
l_params_new = []
for p in l_params:
for v in l_options:
p_new = p.copy()
p_new.update({n_param: v})
l_params_new.append(p_new)
return l_params_new
class ExperimentALPE_mp(ExperimentALPE, exp_sta.ExperimentAPD_mp):
pass
def experiments_synthetic(dataset=None):
......@@ -264,12 +209,12 @@ def experiments_synthetic(dataset=None):
params.update({'dataset': dataset})
l_params = [params]
l_params = extend_l_params(l_params, 'init_tp', ['msc', 'rnd'])
l_params = extend_l_params(l_params, 'ptn_split', [True, False])
# l_params = extend_l_params(l_params, 'overlap_mj', [True, False])
l_params = extend_l_params(l_params, 'sub_dataset', exp_sta.SYNTH_SUB_DATASETS)
l_params = exp_sta.extend_l_params(l_params, 'sub_dataset', exp_sta.SYNTH_SUB_DATASETS)
l_params = exp_sta.extend_l_params(l_params, 'init_tp', ['msc', 'rnd'])
l_params = exp_sta.extend_l_params(l_params, 'ptn_split', [True, False])
range_nb_lbs = exp_sta.SYNTH_PTN_RANGE[exp_sta.SYNTH_DATASET_VERSION]
l_params = extend_l_params(l_params, 'nb_lbs', range_nb_lbs)
l_params = exp_sta.extend_l_params(l_params, 'nb_lbs', range_nb_lbs)
l_params = exp_sta.extend_l_params(l_params, 'gc_regul', [0., 1e-3, 1e-1, 1e0])
logger.debug('list params: {}'.format(len(l_params)))
......@@ -292,11 +237,11 @@ def experiments_real(dataset=None):
params.update({'dataset': dataset})
l_params = [params]
l_params = extend_l_params(l_params, 'init_tp', ['msc', 'rnd'])
l_params = extend_l_params(l_params, 'ptn_split', [True, False])
l_params = extend_l_params(l_params, 'sub_dataset', exp_sta.REAL_SUB_DATASETS)
l_params = extend_l_params(l_params, 'nb_lbs', range(5, 12, 2) + range(15, 35, 4))
l_params = extend_l_params(l_params, 'gc_regul', [0., 1e-3, 1e-1, 1e0])
l_params = exp_sta.extend_l_params(l_params, 'sub_dataset', exp_sta.REAL_SUB_DATASETS)
l_params = exp_sta.extend_l_params(l_params, 'init_tp', ['msc', 'rnd'])
l_params = exp_sta.extend_l_params(l_params, 'ptn_split', [True, False])
l_params = exp_sta.extend_l_params(l_params, 'nb_lbs', range(5, 12, 2) + range(15, 35, 4))
l_params = exp_sta.extend_l_params(l_params, 'gc_regul', [0., 1e-3, 1e-1, 1e0])
logger.debug('list params: {}'.format(len(l_params)))
......@@ -313,16 +258,16 @@ def experiments_test():
""" simple test of the experiments
:return:
"""
experiment_pipeline_alpe_showcase()
# experiment_pipeline_alpe_showcase()
params = exp_sta.SYNTH_PARAMS.copy()
params['nb_runs'] = 3
expt = ExperimentALPE(params)
expt.run()
expt.run(n_var='case', v_range=range(params['nb_runs']))
expt_p = ExperimentALPE_mp(params)
expt_p.run()
expt_p.run(n_var='case', v_range=range(params['nb_runs']))
return
......@@ -333,9 +278,9 @@ if __name__ == "__main__":
# test_encoding(atlas, imgs, encoding)
# test_atlasLearning(atlas, imgs, encoding)
# experiments_test()
experiments_test()
experiments_synthetic()
# experiments_synthetic()
# experiments_real()
# experiments_real('1000_imgs_binary')
......
This diff is collapsed.
......@@ -391,7 +391,7 @@ def dataset_load_images(name=DEFAULT_DATASET, path_base=DEFAULT_PATH_APD,
mproc_pool.close()
mproc_pool.join()
else:
logger.debug('running in single threads...')
logger.debug('running in single thread...')
imgs = []
for i, p_im in enumerate(p_imgs):
imgs.append(load_image(p_im))
......@@ -463,7 +463,7 @@ def dataset_export_images(p_out, imgs, names=None, nb_jobs=1):
mproc_pool.close()
mproc_pool.join()
else:
logger.debug('running in single threads...')
logger.debug('running in single thread...')
for i, im in enumerate(imgs):
export_image(p_out, im, names[i])
p_npz = os.path.join(p_out, 'input_images.npz')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment