% a class which performs edge estimate extension once edges are detected
% the key method implemented here is full_track()

% Written by Yi-Qing Wang, 2016

classdef EdgeTracker<handle
	% EdgeTracker extends edge candidates produced by a detector, strip by
	% strip, towards the right-hand side of the image. Each step searches a
	% strip of 2^(track_scale-1) columns with quantized polynomial templates
	% and keeps the candidates whose normalized response exceeds
	% track_threshold. The entry point is full_track().

	properties

		noisy_input;		% noisy sub-image on the right hand side of
					% the detection strip
		pixel_responses;	% response image handed to SubStrip; same noise
					% level as noisy_input (see constructor)

		fire_stats;		% output of CurveDetector.multiscale_detector()

		image_size;		% normalizing length used by the coefficient bounds

		b_bounds;		% bounds used for coefficient quantization, ordered [b_r, ..., b_1]
		h_bound;		% bound for the derivative of order r+1

		markov_bounds;		% Markov brothers' constants T_r^(k)(1), ordered k = r, ..., 0

		max_num_lps;
		estimated_sigma;	% noise level (replaced by 1 when below 1e-4)

		quantizer;		% coefficients live on a grid of step 1/quantizer
		mask_width;

		chi_test;
		variance_test;
		num_margins;		% a parameter which may allow to relax a bit the full covering assumption

		defaults;		% ranges for the higher order coefs (see default_coef_range)
		detection_max_scale;

		interp_scale;		% scale at which template control values are interpolated
		track_scale;		% adjust according to fire_stats

		trackFalsePositive;
		track_threshold;	% a candidate survives a step only if its response exceeds this
		scale_threshold;	% below which only the constant component restrains
					% the search space of the adjacent tracking strip

		full_controls;		% unrestricted templates: control values, one row per template
		full_coefs;		% unrestricted templates: coefficients [a_r, ..., a_0]

	end

	methods

		% constructor
		%   noisy_input     - noisy sub-image to the right of the detection strip
		%   pixel_responses - response image handed to SubStrip (same noise level)
		%   fire_stats      - detector output; fields estimated_sigma and scale are read here
		%   params          - struct with fields mask_width, image_size, h_bound,
		%                     b_bounds, quantizer, interp_scale, trackFalsePositive,
		%                     scale_offset, variance_test, chi_test, margin_size,
		%                     max_scale and max_num_lps
		function obj = EdgeTracker(noisy_input, pixel_responses, fire_stats, params)

			if nargin ~= 4
				error('4 input arguments expected for EdgeTracker');
			end

			obj.fire_stats = fire_stats;

			obj.estimated_sigma = fire_stats.estimated_sigma;

			% debugging: fall back to unit noise for a degenerate estimate
			if obj.estimated_sigma < 1e-4
				obj.estimated_sigma = 1;
			end

			obj.mask_width = params.mask_width;
			% no noise normalizing step: pixel_responses and noisy_input have the same noise level

			obj.noisy_input = noisy_input;
			obj.pixel_responses = pixel_responses;

			obj.image_size = params.image_size;
			obj.h_bound = params.h_bound;
			obj.b_bounds = params.b_bounds;
			obj.quantizer = params.quantizer;
			obj.interp_scale = params.interp_scale;
			obj.trackFalsePositive = params.trackFalsePositive;

			% an expedient choice: lower bound the tracking scale by the interpolation scale
			obj.track_scale = max(fire_stats.scale + params.scale_offset, obj.interp_scale);
			obj.defaults = default_coef_range(obj, obj.track_scale);

			obj.variance_test = params.variance_test;
			obj.chi_test = params.chi_test;

			% unrestricted search space: every quantized higher order
			% coefficient within the default range
			ndegrees = numel(obj.b_bounds);
			coef_range = cell(1, ndegrees);
			for index = 1:ndegrees
				max_coef = obj.defaults.max_coef(index);
				coef_range{index} = -max_coef:max_coef;
			end
			[coefs, controls] = generate_templates(obj, coef_range);

			obj.full_controls = controls;
			obj.full_coefs = coefs;

			% enforce continuity at the risk of failing to extend the curve because of a negative shock
			obj.num_margins = params.margin_size;
			% threshold sigma*sqrt(2*log(N/trackFalsePositive)) where N counts the
			% (start row, unrestricted template) pairs tested for each candidate
			obj.track_threshold = sqrt(2*log((obj.num_margins*2+1)*size(controls, 1)/obj.trackFalsePositive));
			obj.track_threshold = obj.track_threshold * obj.estimated_sigma;

			% Markov brothers' constants for polynomials of degree r = ndegrees:
			% T_r^(k)(1) = prod_{j=0}^{k-1} (r^2 - j^2)/(2j+1), stored for
			% k = r, ..., 0 (this gives [4, 4, 1] for r = 2)
			chebyshev = ones(1, ndegrees+1);
			for k = 1:ndegrees
				chebyshev(k+1) = chebyshev(k) * (ndegrees^2 - (k-1)^2) / (2*k-1);
			end
			obj.markov_bounds = fliplr(chebyshev);

			obj.detection_max_scale = params.max_scale;

			obj.max_num_lps = params.max_num_lps;

		end

		% default setting for the higher order coefficients at a given scale
		%   test_scale - the strip covers 2^(test_scale-1) columns
		% returns a struct with
		%   max_coef  - largest admissible |quantized coefficient|, ordered [a_r, ..., a_1]
		%   max_reach - sum(max_coef)/quantizer
		%   max_slope - sum_k k*max_coef(k), divided by quantizer*strip length
		function defaults = default_coef_range(obj, test_scale)

			tracking_length = 2^(test_scale-1);
			relength = tracking_length/obj.image_size;

			ndegrees = numel(obj.b_bounds);
			max_coef = zeros(1, ndegrees);

			% obj.b_bounds = [b_r, ..., b_1]: entry index belongs to degree ndegrees-index+1
			for degree = 1:ndegrees
				index = ndegrees - degree + 1;
				multiplier = obj.b_bounds(index) * obj.quantizer * obj.image_size;
				max_coef(1, index) = floor(0.5 + (relength^degree)*multiplier);
			end

			defaults.max_coef = max_coef;
			defaults.max_reach = sum(max_coef)/obj.quantizer;
			defaults.max_slope = ((ndegrees:-1:1)*max_coef')/(obj.quantizer*tracking_length);

		end

		% evaluate polynomials over the nodes with Horner's scheme
		%   coefs - one polynomial per row, [a_r, ..., a_0]
		%   nodes - vector of evaluation points (row or column)
		% returns vals with vals(i, j) = p_i(nodes(j))
		function vals = eval_polynomials(~, coefs, nodes)

			nodes = reshape(nodes, 1, []);
			vals = coefs(:, 1) * ones(1, numel(nodes));
			for k = 2:size(coefs, 2)
				vals = bsxfun(@plus, coefs(:, k), bsxfun(@times, nodes, vals));
			end

		end

		% create tracking templates
		%   coef_range - cell array {range of a_r, ..., range of a_1} of integer
		%                (quantized) coefficient values
		% returns
		%   coefs    - rows [a_r, ..., a_1, a_0] on the 1/quantizer grid, the
		%              constant a_0 taking the values (0:quantizer-1)/quantizer
		%   controls - floor of each polynomial at the 2^(track_scale-interp_scale)+1
		%              equispaced nodes of [0, 1], duplicated rows removed
		function [coefs, controls] = generate_templates(obj, coef_range)

			% cartesian product of the higher order coefs, a faster alternative to combvec
			% http://stackoverflow.com/questions/21895335/generate-a-matrix-containing-all-combinations-of-elements-taken-from-n-vectors
			en = numel(coef_range);
			combs = cell(1, en);
			[combs{end:-1:1}] = ndgrid(coef_range{end:-1:1});
			combs = cat(en+1, combs{:});
			combs = reshape(combs, [], en)/obj.quantizer;

			% quantize the polynomial's constant
			constants = kron((0:(obj.quantizer-1))', ones(size(combs, 1), 1))/obj.quantizer;

			% each row represents a coef combination
			coefs = horzcat(repmat(combs, obj.quantizer, 1), constants);

			% evaluate the interpolated polynomial values over 0:spacing:1
			num_intervals = 2^(obj.track_scale - obj.interp_scale);
			nodes = linspace(0, 1, num_intervals + 1);
			controls = floor(eval_polynomials(obj, coefs, nodes));

			% remove the possible duplicates created by discretization
			[controls, uniq_rows] = unique(controls, 'rows');
			coefs = coefs(uniq_rows, :);

		end

		% assuming the detected candidate is full-covering, bound the quantized
		% coefficients of a tight polynomial approximator of the curve segment
		% at current_scale, using the Markov bounds
		% returns coef_lims, ordered [a_r, ..., a_0]
		function coef_lims = locate_unknown(obj, current_scale)

			num_intervals = 2^(max(current_scale-obj.interp_scale, 0));
			ndegrees = numel(obj.b_bounds);

			% bound on the error of linearly interpolating between the control
			% values: sum_k k(k-1)|a_k| / (8 * num_intervals^2)
			kkm1 = (ndegrees:-1:2)';
			kkm1 = kkm1.*(kkm1-1);
			cdfs = default_coef_range(obj, current_scale);
			interpolation_error = (cdfs.max_coef(1:ndegrees-1)*kkm1)/(8*obj.quantizer*(num_intervals^2));

			% polynomial approximation error plus coefficient quantization
			current_length = 2^(current_scale-1);
			relength = current_length/obj.image_size;
			poly_approx_error = obj.h_bound*current_length*(relength^ndegrees) + (ndegrees+1)/(2*obj.quantizer);

			% W bounds the approximator on [0, 1]; the Markov bounds turn it into
			% |a_k| <= W * 2^k * T_r^(k)(1) / k!
			W = obj.mask_width - 1 + interpolation_error + poly_approx_error;
			coef_lims = ceil(W*obj.quantizer * (2.^(ndegrees:-1:0)).*(obj.markov_bounds./factorial(ndegrees:-1:0)));

		end

		% extend the candidates over the next strip of 2^(track_scale-1) columns
		%   tracked      - current candidates (edge_sign, scale, coefs, ...)
		%   current_coef - quantized coefficients [a_r, ..., a_0], one row per
		%                  candidate (full_track shifts the constant by the row offset)
		%   start_col    - first column of the strip
		%   restrict     - if true, search only templates near the current
		%                  polynomials instead of the full template set
		%   backtrack    - if > 0, the strip is moved left by
		%                  2^(track_scale-1) - backtrack columns (image border)
		% returns tracked restricted to the candidates whose normalized response
		% exceeds track_threshold
		function tracked = extend(obj, tracked, current_coef, start_col, restrict, backtrack)

			ndegrees = numel(obj.b_bounds);
			L2 = 2^(obj.track_scale-1);		% length of the new strip

			if backtrack > 0
				start_col = start_col - L2 + backtrack;
			end

			oss = SubStrip(obj.pixel_responses, start_col, obj.interp_scale, ...
						obj.track_scale, obj.defaults.max_slope, ...
						obj.variance_test, obj.noisy_input, ...
						obj.mask_width, obj.max_num_lps, obj.estimated_sigma);

			% output
			tracked.start_col = start_col;
			tracked.scale = obj.track_scale;

			num_fires = numel(tracked.edge_sign);
			tracked.response = zeros(num_fires, 1);
			tracked.rowid = zeros(num_fires, 1);
			tracked.controls = zeros(num_fires, 2^(obj.track_scale - obj.interp_scale)+1);
			% preallocated so failed candidates never carry stale coefficients and
			% the field exists even when no candidate succeeds
			tracked.coefs = zeros(num_fires, ndegrees+1);

			num_rowls = 2*obj.num_margins + 1;

			if restrict
				% NOTE(review): tracked.scale was just overwritten with
				% obj.track_scale, so current_scale == track_scale and L1 == L2.
				% If the previous segment's scale was intended, it must be read
				% before the overwrite above -- confirm intent.
				current_scale = tracked.scale;
				coef_lims = locate_unknown(obj, current_scale);

				% length of the current interval (L2 is the next one)
				L1 = 2^(current_scale-1);

				% coef offset
				bounds = [obj.h_bound, obj.b_bounds];
				bounds = min([(bounds(1, 1:ndegrees).*(ndegrees+1:-1:2))*L1; bounds(1, 2:ndegrees+1)*2*obj.image_size]);
				bounds = (bounds .* ((L2/obj.image_size).^(ndegrees:-1:1))) * obj.quantizer;

				% higher order offsets of integer valued quantized coefficients
				coef_offset = floor(1 + (0.5+coef_lims(1:ndegrees)).*((L2/L1).^(ndegrees:-1:1)) + bounds);

				ccand = bsxfun(@times, ((L2/L1).^(ndegrees:-1:0)), current_coef);

				% search window of every higher order coef, clipped to the defaults
				minvals = bsxfun(@max, bsxfun(@minus, ccand(:, 1:ndegrees), coef_offset), -obj.defaults.max_coef);
				maxvals = bsxfun(@min, bsxfun(@plus, ccand(:, 1:ndegrees), coef_offset), obj.defaults.max_coef);
			end

			% under the full covering assumption, the start position has a well defined range
			if backtrack > 0
				% evaluate the current polynomials at 1 - backtrack/L2
				we = (1-backtrack/L2).^(ndegrees:-1:0);
				positions = (current_coef * we')/obj.quantizer;
			else
				% evaluate the current polynomials at 1
				positions = sum(current_coef, 2)/obj.quantizer;
			end

			start_rows = ceil(positions - obj.num_margins);

			for f = 1:num_fires

				start_row = start_rows(f);

				if restrict
					% a candidate row [a_r, ..., a_0]
					% enforce global constraint
					test_range = cell(1, ndegrees);
					for degree = 1:ndegrees
						test_range{degree} = minvals(f, degree):maxvals(f, degree);
					end
					[coefs, controls] = generate_templates(obj, test_range);
				else
					coefs = obj.full_coefs;
					controls = obj.full_controls;
				end

				[coefs, controls] = border_check(obj, coefs, controls, start_row, num_rowls);

				free_memory = f - num_fires + 1;  % equals 1 only when f == num_fires
				maxret = spline_integral(oss, tracked.edge_sign(f), controls, ...
							start_row, num_rowls, free_memory);

				% presumably maxret = [best response, row offset, template index] -- confirm
				if maxret(1) > 0
					tracked.response(f) = maxret(1)/sqrt(L2+1);
					tracked.rowid(f) = maxret(2) + obj.mask_width;
					tracked.controls(f, :) = controls(maxret(3), :);
					tracked.coefs(f, :) = coefs(maxret(3), :);
				end
			end

			idx = find(tracked.response > obj.track_threshold);
			tracked.response = tracked.response(idx);
			tracked.rowid = tracked.rowid(idx);
			tracked.controls = tracked.controls(idx, :);
			tracked.coefs = tracked.coefs(idx, :);
			tracked.edge_sign = tracked.edge_sign(idx);

		end

		% full tracking: extend the detected candidates strip by strip up to the
		% right border of noisy_input
		%   display_every - if negative, nothing is drawn; otherwise the curves
		%                   are drawn on background at run 1, every display_every
		%                   runs after that, and at the final run
		%   background    - RGB image to draw on (used when display_every >= 0)
		% returns a cell array holding one visualize() output per completed run;
		% trailing cells stay empty once every candidate is lost
		function recorded_curves = full_track(obj, display_every, background)

			% the last term is zero if obj.fire_stats.scale = 0
			start_col = 1 + 2^(obj.detection_max_scale-1) + floor(2^(obj.fire_stats.scale-1));
			step = 2^(obj.track_scale-1);

			tracked = obj.fire_stats;
			% prevent from switching sides when extending edges
			tracked.edge_sign = sign(tracked.signed_response);

			% take care of the image boundary: a last partial strip is handled by backtracking
			num_cols = size(obj.noisy_input, 2) - start_col;
			runtimes = floor(num_cols/step);
			remaining = num_cols - runtimes*step;
			if remaining > 0
				runtimes = runtimes + 1;
			end

			recorded_curves = cell(1, runtimes);
			for run = 1:runtimes

				detected_coefs = tracked.coefs * obj.quantizer;
				detected_coefs(:, end) = detected_coefs(:, end) + (tracked.rowid - obj.mask_width) * obj.quantizer;

				if remaining > 0 && run == runtimes
					backtrack = remaining;
				else
					backtrack = 0;
				end

				tracked = extend(obj, tracked, detected_coefs, start_col, false, backtrack);

				if isempty(tracked.rowid)
					break;
				end

				if display_every < 0
					recorded_curves{run} = visualize(obj, tracked, backtrack);
				else
					% mod(run-1, .) == 0 rather than mod(run, .) == 1 so that
					% display_every = 1 draws every run
					show = mod(run - 1, display_every) == 0 || run == runtimes;
					recorded_curves{run} = visualize(obj, tracked, backtrack, background, show);
				end
				start_col = start_col + step;
			end

		end

		% safety check: drop the templates whose values, offset by start_row,
		% would fall outside pixel_responses (a bit hacky and brutal)
		function [coefs, controls] = border_check(obj, coefs, controls, start_row, num_rowls)

			lower_ok = min(controls, [], 2) + start_row > 1;
			controls = controls(lower_ok, :);
			coefs = coefs(lower_ok, :);

			upper_ok = max(controls, [], 2) < size(obj.pixel_responses, 1) - start_row - num_rowls;
			controls = controls(upper_ok, :);
			coefs = coefs(upper_ok, :);

		end

		% compute (and optionally draw) the pixel trajectories of the tracked curves
		%   tracked    - output of extend()
		%   backtrack  - if > 0, keep only the last backtrack+1 columns
		%   background - RGB image to draw on; a 2-D one is replaced by
		%                noisy_input replicated over 3 channels
		%   display    - draw the curves in red and show the figure (default false)
		% returns courbes with fields rowids (one row per curve), colids and edge_sign
		function courbes = visualize(obj, tracked, backtrack, background, display)

			if nargin < 5
				display = false;
			end
			if nargin >= 4 && ndims(background) == 2
				fprintf('\nBackground image must be RGB. Discard the input background.\n');
				background = repmat(obj.noisy_input, [1, 1, 3]);
			end

			start_col = tracked.start_col + tracked.disp_col_offset;

			controls = tracked.controls;
			num_controls = size(controls, 1);

			% controls exclusively interpreted at scale obj.interp_scale:
			% interp_unitL columns are linearly interpolated between two controls
			interp_unitL = 2^(obj.interp_scale-1);
			increment = ones(1, interp_unitL)/interp_unitL;

			num_colids = (size(controls, 2)-1)*interp_unitL;
			colids = start_col : (start_col+num_colids);

			valdiff = diff(controls, 1, 2);
			rowids = floor(cumsum(horzcat(tracked.rowid, kron(valdiff, increment)), 2));

			% keep the trailing columns of every curve (index by column, not linearly)
			if backtrack > 0
				rowids = rowids(:, end-backtrack:end);
				colids = colids(end-backtrack:end);
			end

			courbes.rowids = rowids;
			courbes.colids = colids;
			courbes.edge_sign = tracked.edge_sign;

			if display
				[height, width, ~] = size(background);
				for f = 1:num_controls
					for k = 1:size(rowids, 2)
						col = colids(k);
						if col < 1 || col > width
							continue;
						end
						% clip the 5-pixel mark so curves near the border neither
						% index row <= 0 nor enlarge the image
						rows = max(rowids(f, k)-2, 1):min(rowids(f, k)+2, height);
						background(rows, col, 1) = 255;
						background(rows, col, 2) = 0;
						background(rows, col, 3) = 0;
					end
				end
				figure; imshow(uint8(background));
			end

		end

	end

end
