Commit b7ad72cf authored by Jiri Borovec's avatar Jiri Borovec

repare step & speed-up

parent f5b38f68
......@@ -22,33 +22,64 @@ DEFAULT_UNARY_BACKGROUND = 1
# kladne hodnoty.
def compute_relative_penaly_images_weights(imgs, weights):
# def compute_relative_penaly_images_weights(l_imgs, weights):
# """ compute the relative penalty for all pixel and cjsing each label
# on that particular position
#
# :param l_imgs: [np.array<w, h>]
# :param weights: np.array<nb_imgs, nb_lbs>
# :return: np.array<w, h, nb_lbs>
# """
# logger.info('compute unary cost from images and related weights')
# # weightsIdx = ptn_weight.convert_weights_binary2indexes(weights)
# nb_lbs = weights.shape[1] + 1
# assert len(l_imgs) == weights.shape[0]
# pott_sum = np.zeros(l_imgs[0].shape + (nb_lbs,))
# # extenf the weights by background value 0
# weights_ext = np.append(np.zeros((weights.shape[0], 1)), weights, axis=1)
# # logger.debug(weights_ext)
# imgs = np.array(l_imgs)
# logger.debug('DIMS pott: {}, l_imgs {}, w_bin: {}'.format(pott_sum.shape,
# imgs.shape, weights_ext.shape))
# logger.debug('... walk over all pixels in each image')
# for i in range(pott_sum.shape[0]):
# for j in range(pott_sum.shape[1]):
# # per all images in list
# for k in range(len(l_imgs)):
# cost = abs(weights_ext[k] - l_imgs[k][i,j])
# # logger.debug(cost)
# pott_sum[i, j] += cost
# pott_sum_norm = pott_sum / float(len(l_imgs))
# return pott_sum_norm
def compute_relative_penaly_images_weights(l_imgs, weights):
""" compute the relative penalty for all pixel and cjsing each label
on that particular position
:param imgs: [np.array<w, h>]
:param l_imgs: [np.array<w, h>]
:param weights: np.array<nb_imgs, nb_lbs>
:return: np.array<w, h, nb_lbs>
"""
logger.info('compute unary cost from images and related weights')
# weightsIdx = ptn_weight.convert_weights_binary2indexes(weights)
nb_lbs = weights.shape[1] + 1
assert len(imgs) == weights.shape[0]
pott_sum = np.zeros(imgs[0].shape + (nb_lbs,))
assert len(l_imgs) == weights.shape[0]
pott_sum = np.zeros(l_imgs[0].shape + (nb_lbs,))
# extenf the weights by background value 0
weights_ext = np.append(np.zeros((weights.shape[0], 1)), weights, axis=1)
# logger.debug(weights_ext)
# walkmover all pixels in image
imgs = np.array(l_imgs)
logger.debug('DIMS pott: {}, l_imgs {}, w_bin: {}'.format(pott_sum.shape,
imgs.shape, weights_ext.shape))
logger.debug('... walk over all pixels in each image')
# TODO: make it as matrix ops
for i in range(pott_sum.shape[0]):
for j in range(pott_sum.shape[1]):
# per all images in list
for k in range(len(imgs)):
cost = abs(weights_ext[k] - imgs[k][i,j])
# logger.debug(cost)
pott_sum[i, j] += cost
pott_sum_norm = pott_sum / float(len(imgs))
# make it as matrix ops
img_vals = np.repeat(imgs[:, i, j, np.newaxis],
weights_ext.shape[1], axis=1)
pott_sum[i, j] = np.sum(abs(weights_ext - img_vals), axis=0)
pott_sum_norm = pott_sum / float(len(l_imgs))
return pott_sum_norm
......@@ -224,7 +255,7 @@ def export_visual_atlas(i, out_dir, atlas=None, weights=None, prefix='debug'):
return None
def apd_initialisation(imgs, init_atlas, init_weights, out_dir, out_prefix):
def alpe_initialisation(imgs, init_atlas, init_weights, out_dir, out_prefix):
""" more complex initialisation depending on inputs
:param imgs: [np.array<w, h>]
......@@ -254,7 +285,7 @@ def apd_initialisation(imgs, init_atlas, init_weights, out_dir, out_prefix):
return atlas, w_bins
def apd_update_weights(imgs, atlas):
def alpe_update_weights(imgs, atlas):
""" single iteration of the block coordinate descent algo
:param imgs: [np.array<w, h>]
......@@ -272,7 +303,7 @@ def apd_update_weights(imgs, atlas):
return np.array(w_bins)
def apd_repaire_atlas_weights(imgs, atlas, w_bins, lb_max):
def alpe_repaire_atlas_weights(imgs, atlas, w_bins, lb_max):
"""
:param imgs: [np.array<w, h>]
......@@ -286,7 +317,7 @@ def apd_repaire_atlas_weights(imgs, atlas, w_bins, lb_max):
return atlas, w_bins
def apd_update_atlas(imgs, atlas, w_bins, lb_max, gc_coef, gc_reinit):
def alpe_update_atlas(imgs, atlas, w_bins, lb_max, gc_coef, gc_reinit):
""" single iteration of the block coordinate descent algo
:param imgs: [np.array<w, h>]
......@@ -313,9 +344,9 @@ def apd_update_atlas(imgs, atlas, w_bins, lb_max, gc_coef, gc_reinit):
return atlas_new, step_diff
def apd_pipe_atlas_learning_ptn_weights(imgs, init_atlas=None, init_weights=None,
gc_coef=0.0, thr_step_diff=0.0, max_iter=99,
gc_reinit=True, out_prefix='debug', out_dir=''):
def alpe_pipe_atlas_learning_ptn_weights(imgs, init_atlas=None, init_weights=None,
gc_coef=0.0, thr_step_diff=0.0, max_iter=99,
gc_reinit=True, out_prefix='debug', out_dir=''):
""" the main pipeline for block coordinate descent algo with graphcut
:param imgs: [np.array<w, h>]
......@@ -334,8 +365,8 @@ def apd_pipe_atlas_learning_ptn_weights(imgs, init_atlas=None, init_weights=None
assert len(imgs) >= 0
# assert initAtlas is not None or type(max_nb_lbs)==int
# initialise
atlas, w_bins = apd_initialisation(imgs, init_atlas, init_weights,
out_dir, out_prefix)
atlas, w_bins = alpe_initialisation(imgs, init_atlas, init_weights,
out_dir, out_prefix)
lb_max = np.max(atlas)
for i in range(max_iter):
......@@ -343,14 +374,14 @@ def apd_pipe_atlas_learning_ptn_weights(imgs, init_atlas=None, init_weights=None
logger.info('ERROR: the atlas does not contain '
'any label... {}'.format(np.unique(atlas)))
w_bins = apd_update_weights(imgs, atlas)
w_bins = alpe_update_weights(imgs, atlas)
# plt.subplot(221), plt.imshow(atlas, interpolation='nearest')
# plt.subplot(222), plt.imshow(w_bins, aspect='auto')
atlas, w_bins = apd_repaire_atlas_weights(imgs, atlas, w_bins, lb_max)
atlas, w_bins = alpe_repaire_atlas_weights(imgs, atlas, w_bins, lb_max)
# plt.subplot(223), plt.imshow(atlas, interpolation='nearest')
# plt.subplot(224), plt.imshow(w_bins, aspect='auto')
# plt.show()
atlas, step_diff = apd_update_atlas(imgs, atlas, w_bins, lb_max, gc_coef, gc_reinit)
atlas, step_diff = alpe_update_atlas(imgs, atlas, w_bins, lb_max, gc_coef, gc_reinit)
logger.info('-> iter. #{} with Atlas diff {}'.format(i + 1, step_diff))
export_visual_atlas(i + 1, out_dir, atlas, w_bins, prefix=out_prefix)
......
__author__ = 'Jiri Borovec'
import os, sys, glob
import os
import sys
import time
import numpy as np
import pandas
from skimage import io
import matplotlib.pylab as plt
import matplotlib.gridspec as gridspec
sys.path.append(os.path.abspath(os.path.join('..','..'))) # Add path to root
# import src.ownUtils.toolDataIO as tD
import generate_dataset as gen_data
import dictionary_learning as dl
......@@ -16,19 +13,17 @@ import ptn_weights as ptn_weigth
import logging
logger = logging.getLogger(__name__)
DEFAULT_PATH_OUTPUT = os.path.join('..','..','output')
def experiment_pipeline_alpe(atlas, imgs, encoding):
def experiment_pipeline_alpe(atlas, imgs, encoding, p_out=DEFAULT_PATH_OUTPUT):
init_atlas_org = ptn_dict.initialise_atlas_deform_original(atlas)
init_atlas_rnd = ptn_dict.initialise_atlas_random(atlas.shape, np.max(atlas))
# init_atlas_rnd = ptn_dict.initialise_atlas_random(atlas.shape, np.max(atlas))
init_atlas_msc = ptn_dict.initialise_atlas_mosaic(atlas.shape, np.max(atlas))
init_encode_rnd = ptn_weigth.initialise_weights_random(len(imgs), np.max(atlas))
pOut = os.path.join('..','..','output')
# dl.apd_pipe_atlas_learning_ptn_weights(imgs, initAtlas=init_atlas_msc,
# maxIter=9, reInit=False, outDir=pOut, outPrefix='mosaic')
dl.apd_pipe_atlas_learning_ptn_weights(imgs, init_atlas=init_atlas_msc,
max_iter=9, out_dir=pOut, out_prefix='msc')
dl.alpe_pipe_atlas_learning_ptn_weights(imgs, init_atlas=init_atlas_msc,
max_iter=9, out_dir=p_out, out_prefix='defm_msc')
return None
......@@ -50,7 +45,9 @@ def test_simple_case():
for i, (img, w) in enumerate(zip(imgs, ws)):
plt.subplot(gs[0, i + 1]), plt.title('w:{}'.format(w))
plt.imshow(img, cmap='gray', interpolation='nearest')
t = time.time()
uc = dl.compute_relative_penaly_images_weights(imgs, np.array(ws))
logger.debug('elapsed TIME: {}'.format(time.time() - t))
res = dl.estimate_atlas_graphcut_general(imgs, np.array(ws), 0.)
plt.subplot(gs[0, -1]), plt.title('result')
plt.imshow(res, cmap=cm, interpolation='nearest'), plt.colorbar()
......@@ -67,8 +64,11 @@ def test_simple_case():
def main():
atlas = gen_data.dataset_create_atlas()
# plt.imshow(atlas)
imgs = gen_data.dataset_load_images()
# imgs = gen_data.dataset_load_images()
imgs = gen_data.dataset_load_images('datasetBinary_deform')
# plt.imshow(imgs[0])
encoding = gen_data.dataset_load_weights()
# logger.info('encoding: {}'.format(encoding))
......
......@@ -88,7 +88,6 @@ def draw_ellipse(ratio=0.1, img=None, clr=255, im_size=DEFAULT_IM_SIZE):
img[x, y] = clr
# img = transform.rotate(img, angle=random.randint(0, 180), center=c, order=0, cval=img[0,0])
# img = transform.rotate(img, angle=random.randint(0, 180), center=np.array(imSize)/2, order=0, cval=img[0,0])
# TODO: add rotation
return img
......
......@@ -164,7 +164,7 @@ def reinit_atlas_likely_patterns(imgs, w_bins, atlas, lb_max=None):
atlas = insert_new_pattern(imgs, imgs_rc, atlas, l)
logger.debug('w_bins before: {}'.format(np.sum(w_bins[:, l_w])))
w_bins[:, l_w] = ptn_weight.weights_label_atlas_overlap_threshold(imgs,
atlas, l, 1e-6)
atlas, l, 1e-3)
logger.debug('w_bins after: {}'.format(np.sum(w_bins[:, l_w])))
return atlas, w_bins
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment