Commit f5b38f68 authored by Jiri Borovec's avatar Jiri Borovec

repair step & bug fix

parent 9d3929b1
import os
import numpy as np
import pattern_disctionary as ptn_dict
import pattern_weights as ptn_weight
import ptn_disctionary as ptn_dict
import ptn_weights as ptn_weight
import similarity_metric as sim_metric
import matplotlib.pyplot as plt
import logging
......@@ -269,10 +269,10 @@ def apd_update_weights(imgs, atlas):
for img in imgs]
# add once for patterns that are not used at all
# w_bins = ptn_weight.fill_empty_patterns(np.array(w_bins))
return w_bins
return np.array(w_bins)
def apd_repaire_atlas_weights(imgs, atlas, w_bins):
def apd_repaire_atlas_weights(imgs, atlas, w_bins, lb_max):
"""
:param imgs: [np.array<w, h>]
......@@ -281,13 +281,12 @@ def apd_repaire_atlas_weights(imgs, atlas, w_bins):
:return: np.array<w, h>, np.array<nb_imgs, nb_lbs>
"""
logger.debug('... perform repairing')
atlas = ptn_dict.reinit_atlas_likely_patterns(imgs, w_bins, atlas)
w_bins = [ptn_weight.weights_image_atlas_overlap_partial(img, atlas)
for img in imgs]
# reinit empty
atlas, w_bins = ptn_dict.reinit_atlas_likely_patterns(imgs, w_bins, atlas, lb_max)
return atlas, w_bins
def apd_update_atlas(imgs, atlas, w_bins, gc_coef, gc_reinit):
def apd_update_atlas(imgs, atlas, w_bins, lb_max, gc_coef, gc_reinit):
""" single iteration of the block coordinate descent algo
:param imgs: [np.array<w, h>]
......@@ -302,14 +301,16 @@ def apd_update_atlas(imgs, atlas, w_bins, gc_coef, gc_reinit):
w_bins = np.array(w_bins)
# update atlas
logger.debug('... perform Atlas estimation')
# atlasNew = estimate_atlas_graphcut_simple(imgs, w_bins)
# atlas_new = estimate_atlas_graphcut_simple(imgs, w_bins)
if gc_reinit:
atlasNew = estimate_atlas_graphcut_general(imgs, w_bins, gc_coef, atlas)
atlas_new = estimate_atlas_graphcut_general(imgs, w_bins, gc_coef, atlas)
else:
atlasNew = estimate_atlas_graphcut_general(imgs, w_bins, gc_coef)
atlas_new = estimate_atlas_graphcut_general(imgs, w_bins, gc_coef)
atlas_new = ptn_dict.atlas_split_indep_ptn(atlas_new, lb_max)
step_diff = sim_metric.compare_atlas_adjusted_rand(atlas, atlasNew)
return atlasNew, step_diff
step_diff = sim_metric.compare_atlas_adjusted_rand(atlas, atlas_new)
return atlas_new, step_diff
def apd_pipe_atlas_learning_ptn_weights(imgs, init_atlas=None, init_weights=None,
......@@ -335,6 +336,7 @@ def apd_pipe_atlas_learning_ptn_weights(imgs, init_atlas=None, init_weights=None
# initialise
atlas, w_bins = apd_initialisation(imgs, init_atlas, init_weights,
out_dir, out_prefix)
lb_max = np.max(atlas)
for i in range(max_iter):
if len(np.unique(atlas)) == 1:
......@@ -342,8 +344,13 @@ def apd_pipe_atlas_learning_ptn_weights(imgs, init_atlas=None, init_weights=None
'any label... {}'.format(np.unique(atlas)))
w_bins = apd_update_weights(imgs, atlas)
atlas, w_bins = apd_repaire_atlas_weights(imgs, atlas, w_bins)
atlas, step_diff = apd_update_atlas(imgs, atlas, w_bins, gc_coef, gc_reinit)
# plt.subplot(221), plt.imshow(atlas, interpolation='nearest')
# plt.subplot(222), plt.imshow(w_bins, aspect='auto')
atlas, w_bins = apd_repaire_atlas_weights(imgs, atlas, w_bins, lb_max)
# plt.subplot(223), plt.imshow(atlas, interpolation='nearest')
# plt.subplot(224), plt.imshow(w_bins, aspect='auto')
# plt.show()
atlas, step_diff = apd_update_atlas(imgs, atlas, w_bins, lb_max, gc_coef, gc_reinit)
logger.info('-> iter. #{} with Atlas diff {}'.format(i + 1, step_diff))
export_visual_atlas(i + 1, out_dir, atlas, w_bins, prefix=out_prefix)
......
......@@ -10,25 +10,25 @@ sys.path.append(os.path.abspath(os.path.join('..','..'))) # Add path to root
# import src.ownUtils.toolDataIO as tD
import generate_dataset as gen_data
import dictionary_learning as dl
import pattern_disctionary as ptn_dict
import pattern_weights as ptn_weigth
import ptn_disctionary as ptn_dict
import ptn_weights as ptn_weigth
import logging
logger = logging.getLogger(__name__)
def experiment_pipeline_alpe(atlas, imgs, encoding):
initAtlas_org = ptn_dict.initialise_atlas_deform_original(atlas)
initAtlas_rnd = ptn_dict.initialise_atlas_random(atlas.shape, np.max(atlas))
initAtlas_msc = ptn_dict.initialise_atlas_mosaic(atlas.shape, np.max(atlas))
initEncode_rnd = ptn_weigth.initialise_weights_random(len(imgs), np.max(atlas))
init_atlas_org = ptn_dict.initialise_atlas_deform_original(atlas)
init_atlas_rnd = ptn_dict.initialise_atlas_random(atlas.shape, np.max(atlas))
init_atlas_msc = ptn_dict.initialise_atlas_mosaic(atlas.shape, np.max(atlas))
init_encode_rnd = ptn_weigth.initialise_weights_random(len(imgs), np.max(atlas))
pOut = os.path.join('..','..','output')
# dl.apd_pipe_atlas_learning_ptn_weights(imgs, initAtlas=initAtlas_msc,
# dl.apd_pipe_atlas_learning_ptn_weights(imgs, initAtlas=init_atlas_msc,
# maxIter=9, reInit=False, outDir=pOut, outPrefix='mosaic')
dl.apd_pipe_atlas_learning_ptn_weights(imgs, init_atlas=initAtlas_rnd,
max_iter=9, out_dir=pOut, out_prefix='rnd')
dl.apd_pipe_atlas_learning_ptn_weights(imgs, init_atlas=init_atlas_msc,
max_iter=9, out_dir=pOut, out_prefix='msc')
return None
......
......@@ -117,11 +117,12 @@ def extract_image_largest_element(im):
:return: np.array<w, h> of values {0, 1}
"""
labeled, nbObjects = ndimage.label(im)
areas = [(j, np.sum(labeled==j)) for j in np.unique(labeled)]
areas = [(j, np.sum(labeled == j)) for j in np.unique(labeled)]
areas = sorted(areas, key=lambda x: x[1], reverse=True)
logger.debug('... elements area: {}'.format(areas))
if len(areas) > 1:
im[:] = 0
# skip largest, assuming to be background
im[labeled==areas[1][0]] = 1
return im
......
import numpy as np
import matplotlib.pyplot as plt
import generate_dataset as data
import ptn_weights as ptn_weight
from skimage import morphology
from scipy import ndimage
import logging
logger = logging.getLogger(__name__)
......@@ -14,7 +17,7 @@ def initialise_atlas_random(im_size, max_lb):
"""
logger.debug('initialise atlas {} as random labeling'.format(im_size))
nb_lbs = max_lb + 1
im = np.random.randint(0, nb_lbs, im_size)
im = np.random.randint(1, nb_lbs, im_size)
return np.array(im, dtype=np.int)
......@@ -27,11 +30,11 @@ def initialise_atlas_mosaic(im_size, max_lb):
:return: np.array<w, h>
"""
logger.debug('initialise atlas {} as grid labeling'.format(im_size))
nb_lbs = max_lb + 1
nb_lbs = max_lb
block = np.ones(np.ceil(im_size / np.array(nb_lbs, dtype=np.float)))
logger.debug('block size is {}'.format(block.shape))
for l in range(nb_lbs):
idx = np.random.permutation(range(nb_lbs))
idx = np.random.permutation(range(1, nb_lbs + 1))
for k in range(nb_lbs):
b = block.copy() * idx[k]
if k == 0:
......@@ -58,17 +61,17 @@ def initialise_atlas_deform_original(atlas):
return np.array(res, dtype=np.int)
def reconstruct_samples(atlas, weights):
def reconstruct_samples(atlas, w_bin):
""" create reconstruction of binary images according given atlas and weights
:param atlas: np.array<w, h> input atlas
:param weihts: np.array<nb_imgs, nb_lbs>
:return: [np.array<w, h>] * nb_imgs
"""
w_bin = np.array(weights)
weights_ext = np.append(np.zeros((w_bin.shape[0], 1)), weights, axis=1)
imgs = [None] * len(weights)
for i, w in enumerate(weights_ext):
# w_bin = np.array(weights)
w_bin_ext = np.append(np.zeros((w_bin.shape[0], 1)), w_bin, axis=1)
imgs = [None] * w_bin.shape[0]
for i, w in enumerate(w_bin_ext):
imgs[i] = np.asarray(w)[np.asarray(atlas)]
return imgs
......@@ -90,46 +93,112 @@ def get_prototype_new_pattern(imgs, imgs_rc, diffs, atlas):
im_diff = morphology.closing(im_diff, morphology.disk(1))
# find largesr connected component
ptn = data.extract_image_largest_element(im_diff)
atlas_diff = im_diff - (atlas > 0)
ptn = atlas_diff > 0
return ptn
def insert_new_pattern(imgs, w_bin, atlas, lb):
atlas_diff = np.logical_and(ptn == True, atlas == 0)
ptn_res = atlas_diff > 0
# plt.figure()
# plt.subplot(231), plt.imshow(im_diff), plt.colorbar()
# plt.title('im_diff {}'.format(np.unique(im_diff)))
# plt.subplot(232), plt.imshow(ptn), plt.colorbar()
# plt.title('ptn {}'.format(np.unique(ptn)))
# plt.subplot(233), plt.imshow(atlas), plt.colorbar()
# plt.title('atlas {}'.format(np.unique(atlas)))
# plt.subplot(234), plt.imshow((atlas > 0)), plt.colorbar()
# plt.title('atlas {}'.format(np.unique((atlas > 0))))
# plt.subplot(235), plt.imshow(atlas_diff), plt.colorbar()
# plt.title('atlas_diff {}'.format(np.unique(atlas_diff)))
# plt.subplot(236), plt.imshow(ptn_res), plt.colorbar()
# plt.title('ptn_res {}'.format(np.unique(ptn_res)))
# plt.show()
return ptn_res
def insert_new_pattern(imgs, imgs_rc, atlas, lb):
""" with respect to atlas empty spots inset new patterns
:param imgs: [np.array<w, h>] list of input images
:param w_bin: np.array<nb_imgs, nb_lbs>
:param imgs_rc: [np.array<w, h>]
:param atlas: np.array<w, h>
:param lb: int
:return: np.array<w, h> updated atlas
"""
imgs_rc = reconstruct_samples(atlas, w_bin)
diffs = []
# count just positive difference
for im, im_rc in zip(imgs, imgs_rc):
diff = np.sum((im - im_rc) > 0)
diffs.append(diff)
im_ptn = get_prototype_new_pattern(imgs, imgs_rc, diffs, atlas)
# logger.debug('new im_ptn: {}'.format(np.sum(im_ptn) / np.prod(im_ptn.shape)))
# plt.imshow(im_ptn), plt.title('im_ptn'), plt.show()
atlas[im_ptn == True] = lb
logger.debug('area of new pattern is {}'.format(np.sum(atlas == lb)))
return atlas
def reinit_atlas_likely_patterns(imgs, w_bin, atlas):
def reinit_atlas_likely_patterns(imgs, w_bins, atlas, lb_max=None):
""" walk and find all all free labels and try to reinit them by new patterns
:param imgs: [np.array<w, h>] list of input images
:param w_bin: np.array<nb_imgs, nb_lbs>
:param w_bins: np.array<nb_imgs, nb_lbs>
:param atlas: np.array<w, h>
:return:
:return: np.array<w, h>, np.array<nb_imgs, nb_lbs>
"""
# find empty patterns
sums = np.sum(w_bin, axis=0)
logger.debug('IN > sum over weights: {}'.format(sums))
for i, v in enumerate(sums):
if v == 0:
atlas = insert_new_pattern(imgs, w_bin, atlas, i)
return atlas
if lb_max is None:
lb_max = max(np.max(atlas), w_bins.shape[1])
else:
logger.debug('compare w_bin {} to max {}'.format(w_bins.shape, lb_max))
for i in range(w_bins.shape[1], lb_max):
logger.debug('adding disappeared weigh column {}'.format(i))
w_bins = np.append(w_bins, np.zeros((w_bins.shape[0], 1)), axis=1)
w_bin_ext = np.append(np.zeros((w_bins.shape[0], 1)), w_bins, axis=1)
logger.debug('IN > sum over weights: {}'.format(np.sum(w_bin_ext, axis=0)))
# add one while indexes does not cover 0 - bg
logger.debug('total nb labels: {}'.format(lb_max))
for l in range(1, lb_max + 1):
l_w = l - 1
w_sum = np.sum(w_bins[:, l_w])
logger.debug('reinit lb: {} with weight sum {}'.format(l, w_sum))
if w_sum > 0:
continue
imgs_rc = reconstruct_samples(atlas, w_bins)
atlas = insert_new_pattern(imgs, imgs_rc, atlas, l)
logger.debug('w_bins before: {}'.format(np.sum(w_bins[:, l_w])))
w_bins[:, l_w] = ptn_weight.weights_label_atlas_overlap_threshold(imgs,
atlas, l, 1e-6)
logger.debug('w_bins after: {}'.format(np.sum(w_bins[:, l_w])))
return atlas, w_bins
def atlas_split_indep_ptn(atlas, lb_max):
    """ Split spatially independent (disconnected) regions that share one
    label so that each connected component becomes its own pattern, keeping
    at most the ``lb_max`` largest components.

    :param atlas: np.array<w, h> label image; 0 is treated as background
    :param lb_max: int, maximal number of patterns (labels) to keep
    :return: np.array<w, h> relabelled atlas with labels 1..lb_max,
        ordered by decreasing component area
    """
    l_ptns = []
    for l in np.unique(atlas):
        labeled, nb_objects = ndimage.label(atlas == l)
        logger.debug('for lb {} detected #{}'.format(l, nb_objects))
        ptn = [(labeled == j) for j in np.unique(labeled)]
        # skip the largest one assuming to be background
        l_ptns += sorted(ptn, key=lambda x: np.sum(x), reverse=True)[1:]
    l_ptns = sorted(l_ptns, key=lambda x: np.sum(x), reverse=True)
    logger.debug('list of all areas {}'.format([np.sum(p) for p in l_ptns]))
    # FIX: np.int was a deprecated alias of the builtin int and was removed
    # in NumPy 1.24 -- use int directly (identical behaviour)
    atlas_new = np.zeros(atlas.shape, dtype=int)
    # take just lb_max largest elements
    for i, ptn in enumerate(l_ptns[:lb_max]):
        lb = i + 1
        logger.debug('pattern #{} area {}'.format(lb, np.sum(ptn)))
        atlas_new[ptn] = lb
    logger.debug('atlas unique {}'.format(np.unique(atlas_new)))
    return atlas_new
if __name__ == "__main__":
......
import numpy as np
import generate_dataset as data
from skimage import morphology
import logging
logger = logging.getLogger(__name__)
def initialise_weights_random(nb_imgs, nb_lbs, ratio_sel=0.2):
    """ Draw a random binary weight matrix of shape <nb_imgs, nb_lbs>.

    :param nb_imgs: int, number of all images
    :param nb_lbs: int, number of all available labels
    :param ratio_sel: float<0, 1> defining how many should be set on,
        1 means all and 0 means none
    :return: np.array<nb_imgs, nb_lbs> of values {0., 1.}
    """
    logger.debug('initialise weights for {} images and {} labels '
                 'as random selection'.format(nb_imgs, nb_lbs))
    # every cell is switched on independently with probability ratio_sel
    rand_vals = np.random.random((nb_imgs, nb_lbs))
    return (rand_vals <= ratio_sel).astype(float)
def convert_weights_binary2indexes(weights):
    """ Convert a binary weight matrix into, per image, the array of indexes
    of activated patterns (1-based, since label 0 is the atlas background).

    :param weights: np.array<nb_imgs, nb_lbs>
    :return: [np.array<int>] * nb_imgs
    """
    logger.debug('convert binary weights {} '
                 'to list of indexes with True'.format(weights.shape))
    w_idx = []
    for row in weights:
        # columns holding 1, shifted by one so they match atlas labels
        w_idx.append(np.where(row == 1)[0] + 1)
    return w_idx
def weights_image_atlas_overlap_major(img, atlas):
    """ Activate the patterns whose area is covered by the image on a
    majority (>= 50%) of pixels.

    :param img: np.array<w, h>
    :param atlas: np.array<w, h>
    :return: [int] * nb_lbs of values {0, 1}
    """
    # delegate to the generic threshold variant with a fixed majority cut
    return weights_image_atlas_overlap_threshold(img, atlas, 0.5)
def weights_image_atlas_overlap_partial(img, atlas):
    """ Activate patterns using a threshold inversely proportional to the
    number of distinct labels present in the atlas.

    :param img: np.array<w, h>
    :param atlas: np.array<w, h>
    :return: [int] * nb_lbs of values {0, 1}
    """
    # the more labels in the atlas, the smaller the required overlap ratio
    nb_lbs = len(np.unique(atlas))
    return weights_image_atlas_overlap_threshold(img, atlas, 1. / nb_lbs)
def weights_image_atlas_overlap_threshold(img, atlas, thr=0.5):
    """ Estimate which patterns are activated by the given image: a pattern
    is switched on when the image foreground covers at least ``thr`` of the
    pattern's area in the atlas.

    :param img: np.array<w, h> binary input image
    :param atlas: np.array<w, h> label image, 0 is background
    :param thr: float, ratio of overlapping vs all pattern pixels needed
        to activate a pattern
    :return: [int] * nb_lbs of values {0, 1} (index l-1 belongs to label l)
    """
    lbs = np.unique(atlas).tolist()
    # background (label 0) is not a pattern
    if 0 in lbs:
        lbs.remove(0)
    if not lbs:
        # FIX: atlas contains background only -> no pattern can be active
        # (previously np.max([]) raised ValueError here)
        return []
    weight = [0] * int(np.max(lbs))
    for l in lbs:
        total = np.sum(atlas == l)
        # fraction of the pattern area covered by the image foreground
        score = np.sum(img[atlas == l]) / float(total)
        if score >= thr:
            weight[l - 1] = 1
    return weight
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
import numpy as np
import generate_dataset as data
from skimage import morphology
import logging
logger = logging.getLogger(__name__)
def initialise_weights_random(nb_imgs: int, nb_lbs: int, ratio_sel: float = 0.2):
    """ Initialise a random binary weight matrix (images x labels).

    :param nb_imgs: int, number of all images
    :param nb_lbs: int, number of all available labels
    :param ratio_sel: float<0, 1> defining how many should be set on,
        1 means all and 0 means none
    :return: np.array<nb_imgs, nb_lbs> of values {0., 1.}
    """
    logger.debug('initialise weights for {} images and {} labels '
                 'as random selection'.format(nb_imgs, nb_lbs))
    # uniform draws in [0, 1); a cell is activated when its draw falls
    # below the selection ratio
    prob = np.random.random((nb_imgs, nb_lbs))
    weights = np.zeros_like(prob)
    weights[prob <= ratio_sel] = 1
    return weights
def convert_weights_binary2indexes(weights):
    """ Convert a binary matrix of weights to, per image, the list of indexes
    of activated patterns (1-based; label 0 is the atlas background).

    :param weights: np.array<nb_imgs, nb_lbs>
    :return: [np.array<int>] * nb_imgs
    """
    logger.debug('convert binary weights {} '
                 'to list of indexes with True'.format(weights.shape))
    # if type(weights)==np.ndarray: weights = weights.tolist()
    w_idx = [None] * weights.shape[0]
    for i in range(weights.shape[0]):
        # find positions equal 1
        # vec = [j for j in range(weights.shape[1]) if weights[i,j]==1]
        vec = np.where(weights[i, :] == 1)[0]
        # shift by one so the indexes match atlas labels (0 is background)
        w_idx[i] = vec + 1
    # idxs = np.where(weights == 1)
    # for i in range(weights.shape[0]):
    #     w_idx[i] = idxs[1][idxs[0]==i] +1
    return w_idx
def weights_image_atlas_overlap_major(img, atlas):
    """ Activate the patterns whose area is covered by the image on a
    majority (>= 50%) of pixels.

    :param img: np.array<w, h>
    :param atlas: np.array<w, h>
    :return: [int] * nb_lbs of values {0, 1}
    """
    # logger.debug('weights input image according given atlas')
    # delegate to the generic threshold variant with a fixed 0.5 cut
    weights = weights_image_atlas_overlap_threshold(img, atlas, 0.5)
    return weights
def weights_image_atlas_overlap_partial(img, atlas):
    """ Activate patterns using a threshold inversely proportional to the
    number of distinct labels present in the atlas.

    :param img: np.array<w, h>
    :param atlas: np.array<w, h>
    :return: [int] * nb_lbs of values {0, 1}
    """
    # logger.debug('weights input image according given atlas')
    # NOTE(review): len(lbs) counts the background label 0 as well, which
    # lowers the threshold slightly -- presumably intended; verify.
    lbs = np.unique(atlas).tolist()
    weights = weights_image_atlas_overlap_threshold(img, atlas, (1. / len(lbs)))
    return weights
def weights_image_atlas_overlap_threshold(img, atlas, thr=0.5):
    """ Estimate which patterns are activated by the given image: a pattern
    is on when the image covers at least ``thr`` of the pattern's area.

    :param img: np.array<w, h>
    :param atlas: np.array<w, h> label image, 0 is background
    :param thr: float, represent the ration between overlapping and non pixels
    :return: [int] * nb_lbs of values {0, 1}
    """
    # background (label 0) never forms a pattern
    labels = [lb for lb in np.unique(atlas).tolist() if lb != 0]
    weight = [0] * np.max(labels)
    for lb in labels:
        mask = atlas == lb
        # fraction of the pattern area covered by the image foreground
        score = np.sum(img[mask]) / float(np.sum(mask))
        if score >= thr:
            weight[lb - 1] = 1
    return weight
def weights_label_atlas_overlap_threshold(imgs, atlas, lb, thr=1e-3):
    """ For a single atlas label, estimate in which images the pattern is
    activated: image i is on when its foreground covers at least ``thr``
    of the label's area.

    :param imgs: [np.array<w, h>] binary input images
    :param atlas: np.array<w, h> label image
    :param lb: int, label to evaluate
    :param thr: float, ratio of overlapping vs all pattern pixels
    :return: np.array<nb_imgs> of values {0, 1}
    """
    weight = [0] * len(imgs)
    # FIX + hoist: the label area does not depend on the image, compute it
    # once; a label absent from the atlas previously caused ZeroDivisionError
    total = np.sum(atlas == lb)
    if total == 0:
        return np.array(weight)
    for i, img in enumerate(imgs):
        # fraction of the label area covered by this image's foreground
        score = np.sum(img[atlas == lb]) / float(total)
        if score >= thr:
            weight[i] = 1
    return np.array(weight)
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logger.info('DONE')
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment