function image = postprocessing(image, num_strips, mask_width, palette)


% image = postprocessing(image, num_strips, mask_width, palette)
%
% 1. paint the detected edges according to the display mode set at the beginning
% 2. locate endpoints so that the traced edges are more accurate
%
%
% INPUT:
% -------------------------------------------------------------------------
% image       		- a structure containing the input image and its detected edges;
%			  the fields read here are curve_coordinates (.left, .right,
%			  .detected cell arrays indexed by {strip, s}), sketch and
%			  background (pixel_responses is read by the helpers)
% num_strips      	- number of strips used in the detection stage
% mask_width		- mask width used for extracting pixel responses
% palette		- a structure describing how to represent the detected edges:
%			  .wo     (optional) half-width, in rows, of the painted band (default 0)
%			  .track  (optional) 3-element colour of the traced extensions;
%			          when absent, detections are not trimmed either
%			  .detect (optional) 3-element colour of the detected edges
%
% OUTPUT:
% -------------------------------------------------------------------------
% image     		- modified structure containing a ready-to-be-visualized output
%			  (sketch and background are updated)
%
%
% Written by Yi-Qing Wang, 2016

	if isfield(palette, 'wo')
		wo = palette.wo;
	else
		wo = 0;
	end

	% the display mode does not change inside the loops
	trim_detections = isfield(palette, 'track');

	for strip_idx = 1:num_strips
		for s = 1:size(image.curve_coordinates.left, 2)

			% extensions traced to the left of the detection
			lcurves = [];
			data = image.curve_coordinates.left{strip_idx, s};
			if ~isempty(data) && ~isempty(data{1})
				lcurves = aggregate(data, image, mask_width);
				image = paint_curves(image, lcurves, wo, palette, 'track');
			end

			% extensions traced to the right of the detection
			rcurves = [];
			data = image.curve_coordinates.right{strip_idx, s};
			if ~isempty(data) && ~isempty(data{1})
				rcurves = aggregate(data, image, mask_width);
				image = paint_curves(image, rcurves, wo, palette, 'track');
			end

			% the detected edges themselves, painted last so that pixels already
			% claimed by an extension are left untouched
			data = image.curve_coordinates.detected{strip_idx, s};
			if ~isempty(data)
				% perform endpoint location for the detected edges without a viable extension;
				% if asked to display only the detection result, then no trimming
				if trim_detections
					data = loc_detect(data, lcurves, rcurves, image);
				end
				image = paint_curves(image, data, wo, palette, 'detect');
			end
		end
	end
end

function image = paint_curves(image, curves, wo, palette, field)

	% Mark every valid pixel of CURVES in image.sketch and, when PALETTE has
	% the colour field FIELD, paint a band of half-width WO around it in
	% image.background.
	%
	% curves.rowids(r, c) is the row of curve r at column curves.colids(c);
	% entries outside 1..size(image.sketch, 1) are skipped. A pixel is painted
	% only if the sketch is still empty at that pixel and at its vertical
	% neighbours. Both the neighbourhood and the painted band are clipped to the
	% image so that curves touching the top or bottom border neither raise an
	% index error nor enlarge the background.

	paint = isfield(palette, field);   % allow to disable certain lines
	if paint
		colour = palette.(field);
	end

	sketch_rows = size(image.sketch, 1);
	background_rows = size(image.background, 1);

	for r = 1:size(curves.rowids, 1)
		for c = 1:size(curves.rowids, 2)
			row = curves.rowids(r, c);
			if row < 1 || row > sketch_rows
				continue;
			end
			col = curves.colids(c);

			neighbours = max(row-1, 1):min(row+1, sketch_rows);
			if sum(image.sketch(neighbours, col)) < 1
				image.sketch(row, col) = 255;
				if paint
					band = max(row-wo, 1):min(row+wo, background_rows);
					image.background(band, col, 1) = colour(1);
					image.background(band, col, 2) = colour(2);
					image.background(band, col, 3) = colour(3);
				end
			end
		end
	end
end

function data = loc_detect(data, lcurves, rcurves, image)

	% Trim the ends of detected edges that are not continued by a traced curve.
	%
	% data     - detected edges: rowids (one edge per row) and colids
	% lcurves  - curves traced on the left, or [] when there are none
	% rcurves  - curves traced on the right, or [] when there are none
	% image    - structure whose pixel_responses matrix is sampled along the edges
	%
	% The first-column rows of the traced curves and the end rows of the
	% detected edges are both sorted and walked in order: a detected end counts
	% as extended when it lies within one row of the next unused traced curve.
	% Every other end is trimmed by zeroing part of its half of data.rowids.

	% starting rows of the traced curves, in ascending order
	left_starts = [];
	if ~isempty(lcurves)
		left_starts = sort(lcurves.rowids(:, 1));
	end
	right_starts = [];
	if ~isempty(rcurves)
		right_starts = sort(rcurves.rowids(:, 1));
	end

	% end rows of the detected edges, with the edge each one belongs to
	[left_ends, left_owner] = sort(data.rowids(:, 1));
	[right_ends, right_owner] = sort(data.rowids(:, end));

	% middle column of a detected edge
	% NOTE(review): assumes numel(data.colids) is odd so that this is an integer index - confirm
	num_cols = numel(data.colids);
	half = (num_cols + 1) * 0.5;
	response_size = size(image.pixel_responses);

	% ---- left ends ----
	next = 1;
	for r = 1:numel(left_ends)
		if next <= numel(left_starts) && ~(abs(left_ends(r) - left_starts(next)) > 1)
			% extended by a traced curve: consume that curve, keep the edge as is
			next = next + 1;
			continue;
		end

		% endpoint location on the left half of the edge
		edge = left_owner(r);
		rows = data.rowids(edge, 1:half);
		cols = data.colids(1:half);
		noisy = image.pixel_responses(sub2ind(response_size, rows, cols));

		% squared cumulative response normalised by the number of pixels summed
		strength = (cumsum(noisy) .^ 2) ./ (1:half);
		[~, best] = max(strength);
		cut = half - best;
		data.rowids(edge, 1:cut) = 0;  % comment out this if you don't want to trim detections
	end

	% ---- right ends ----
	next = 1;
	for r = 1:numel(right_ends)
		if next <= numel(right_starts) && ~(abs(right_ends(r) - right_starts(next)) > 1)
			% extended by a traced curve: consume that curve, keep the edge as is
			next = next + 1;
			continue;
		end

		% endpoint location on the right half of the edge
		edge = right_owner(r);
		rows = data.rowids(edge, half:end);
		cols = data.colids(half:end);
		noisy = image.pixel_responses(sub2ind(response_size, rows, cols));

		% squared cumulative response normalised by the number of pixels summed
		strength = (cumsum(noisy) .^ 2) ./ (1:half);
		[~, best] = max(strength);
		cut = half - 1 + best;
		data.rowids(edge, cut:end) = 0; % comment out this if you don't want to trim detections
	end

end

function curves = aggregate(data, image, mask_width)

	% Merge the per-segment curve pieces in DATA into full-length curves and
	% cut each merged curve at its estimated endpoint.
	%
	% data        - cell array of segments; each non-empty segment has rowids
	%               (one curve per row), colids and edge_sign
	% image       - structure whose pixel_responses matrix is sampled along the curves
	% mask_width  - row offset subtracted before sampling pixel_responses
	%
	% curves      - structure with rowids (num_curves x unit*numel(data)) and
	%               colids (1 x unit*numel(data)); zero rows mark "no curve here"

	num_curves = size(data{1}.rowids, 1);
	unit = size(data{1}.rowids, 2);
	num_segments = numel(data);
	full_length = unit * num_segments;

	curves.rowids = zeros(num_curves, full_length);
	curves.colids = zeros(1, full_length);

	% the first segment seeds every curve
	curves.rowids(:, 1:unit) = data{1}.rowids;
	curves.colids(1:unit) = data{1}.colids;
	edge_sign = data{1}.edge_sign;

	% append the remaining segments one after another
	for k = 2:num_segments
		segment = data{k};
		if isempty(segment)
			continue;
		end

		offset = (k-1)*unit;   % last column filled by the previous segment
		curves.colids(offset+1:offset+numel(segment.colids)) = segment.colids;

		% a curve can be extended only if it is still present at the previous column
		tails = curves.rowids(:, offset);
		extendable = tails > 0;
		assert(sum(extendable) >= size(segment.rowids, 1));

		piece_length = size(segment.rowids, 2);
		for l = 1:size(segment.rowids, 1)
			% careful with what to extend next: among the curves not yet extended
			% and sharing the piece's edge sign, pick the one ending closest to
			% where the piece starts
			candidates = find(extendable & (edge_sign == segment.edge_sign(l)));
			[~, nearest] = min(abs(tails(candidates) - segment.rowids(l, 1)));
			target = candidates(nearest);

			curves.rowids(target, offset+1:offset+piece_length) = segment.rowids(l, :);

			% smooth the connection: the two columns before the junction both
			% take the floor of their mean
			junction = offset-1:offset;
			curves.rowids(target, junction) = floor(mean(curves.rowids(target, junction)));

			extendable(target) = false;
		end
	end


	% locate the endpoints using a maximum likelihood procedure
	response_size = size(image.pixel_responses);
	for k = 1:num_curves
		present = curves.rowids(k, :) > 0;
		rows = curves.rowids(k, present);
		cols = curves.colids(present);

		% only the last UNIT valid pixels of the curve are examined
		tail_rows = rows(end-unit+1:end);
		tail_cols = cols(end-unit+1:end);
		noisy = image.pixel_responses(sub2ind(response_size, tail_rows - mask_width, tail_cols));

		% squared cumulative response normalised by the number of pixels summed
		strength = (cumsum(noisy) .^ 2) ./ (1:unit);
		[~, best] = max(strength);
		cut = sum(present) - (unit - best);
		curves.rowids(k, cut:end) = 0;
	end

end
