% a class whose purpose is to provide a method multiscale_detector()
% for detecting polynomial-shaped edges in a noisy input strip

% Written by Yi-Qing Wang, 2016

classdef CurveDetector<handle

	properties
		interp_scale;		% templates at scale s use 2^(max(s-interp_scale,0)+1)+1 control nodes
		max_scale;		% largest tested scale (scales 1..max_scale, plus scale 0)

		skipZero;		% when true, the scale-0 (single column) test is skipped
		estimated_sigma;	% noise level used to normalize pixel_responses

		mask_width;		% half-height of the vertical [+1,...,+1,-1,...,-1]' difference mask
		image_size;		% reference size used to scale the coefficient ranges

		b_bounds;		% coefficient bounds, ordered [b_r, ..., b_1]
		quantizer;		% number of quantization levels per unit of coefficient

		start_row;		% the first row to anchor a test spline
		num_rowls;		% the number of rows to go through (name kept for compatibility)

		falsePositive;		% false-positive level entering the detection threshold
		num_tests;		% the maximum number of tests
		threshold;		% detection threshold

		poly_coefs;		% polynomial coefficients, one cell per scale
		poly_controls;		% their interpolated nodes, one cell per scale

		noisy_strip;		% original image input
		pixel_responses;	% derived from noisy_strip
		fullstrip;		% a FullStrip object for testing scale > 0

	end

	methods

		% constructor
		%
		% noisy_strip : 2-D (grayscale) image strip
		% params      : struct with fields interp_scale, max_scale, mask_width,
		%               image_size, b_bounds, quantizer, falsePositive,
		%               skipScaleZero, suggested_sigma, max_num_lps, chi_test
		function obj = CurveDetector(noisy_strip, params)

			if nargin == 2

				obj.interp_scale = params.interp_scale;
				obj.max_scale = params.max_scale;

				obj.mask_width = params.mask_width;
				obj.image_size = params.image_size; % whether the asymptotic regime fits
				obj.b_bounds = params.b_bounds;
				obj.quantizer = params.quantizer;
				obj.falsePositive = params.falsePositive;

				obj.skipZero = params.skipScaleZero;

				[poly_coefs, poly_controls, max_slopes, max_reaches] = obj.generate_templates();
				obj.poly_coefs = poly_coefs;
				obj.poly_controls = poly_controls;

				% debugging
				% fprintf('\nmax_slopes should be increasing\n');
				% disp(max_slopes);
				% fprintf('max_reaches should be increasing\n');
				% disp(max_reaches);

				assert(length(size(noisy_strip)) == 2);  % only deal with grayscale images
				obj.noisy_strip = noisy_strip;

				% noise normalizing step 1: vertical difference of two stacked
				% blocks of mask_width rows, scaled to unit gain for white noise
				mask = [ones(obj.mask_width, 1); -ones(obj.mask_width, 1)];
				pixel_responses = conv2(double(noisy_strip), mask, 'valid')/sqrt(numel(mask));

				estimated_sigma = obj.estimate_noise(params.suggested_sigma, pixel_responses);
				% noise normalizing step 2, skipped when sigma is (near) zero or NaN
				if estimated_sigma > 1e-4
					pixel_responses = pixel_responses/estimated_sigma;
				end
				obj.fullstrip = FullStrip(pixel_responses, obj.interp_scale, obj.max_scale, ...
							max_slopes(end), params.max_num_lps, params.chi_test);
				obj.pixel_responses = pixel_responses;
				obj.estimated_sigma = estimated_sigma;

				% keep every candidate edge away from the strip's upper and lower borders
				obj.start_row = ceil(max_reaches(end)) + 1;
				obj.num_rowls = size(pixel_responses, 1) - 2*obj.start_row;
				if obj.num_rowls < 1
					error('image too small for scale %d detector. Try to reduce max_scale.\n', obj.max_scale);
				end

			else
				error('2 input arguments expected for CurveDetector');
			end

		end


		% multiscale edge detection
		%
		% Tests scale 0 (unless skipZero), then scales 1..max_scale, against a
		% threshold derived from the approximate total number of tests.
		%
		% Returns fire_stats_cell:
		%   - if scale 0 fires, detection stops there and fire_stats_cell is a
		%     1x1 cell holding the scale-0 stats (fields threshold, scale = 0,
		%     rowid, estimated_sigma);
		%   - otherwise a 1 x max_scale cell whose k-th entry holds the stats of
		%     scale k when that scale fired, and [] when it did not.
		function fire_stats_cell = multiscale_detector(obj)

			% set the detection threshold
			obj.num_tests = obj.num_rowls;	% scale 0
			if obj.skipZero
				obj.num_tests = 0;
			end

			% a rough estimate, asymptotically correct
			for scale = 1:obj.max_scale
				obj.num_tests = obj.num_tests + obj.num_rowls * size(obj.poly_controls{scale}, 1);
			end

			technical_offset = 1; % adjustable
			obj.threshold = technical_offset + sqrt(2*log(obj.num_tests/obj.falsePositive));
			fprintf('detection threshold is set to %.02f with technical offset = %.02f\n', obj.threshold, technical_offset);

			fire_stats.threshold = obj.threshold;

			% actual detection
			% scale 0: examine the strip's middle column
			if ~obj.skipZero

				rowids = obj.start_row:obj.start_row+obj.num_rowls-1;
				central = 2^(obj.max_scale - 1) + 1;
				colids = central*ones(size(rowids));
				[maxval, maxid] = max(abs(obj.pixel_responses(sub2ind(size(obj.pixel_responses), rowids, colids))));
				if maxval > obj.threshold
					fire_stats.scale = 0;
					fire_stats.rowid = maxid + obj.start_row - 1 + obj.mask_width;			% + obj.mask_width to restore the row position in the input
					fire_stats.estimated_sigma = obj.estimated_sigma;
					fprintf('scale 0 fires with maximal response = %.02f row = %d\n', maxval, fire_stats.rowid);
					% the output argument must be assigned before returning
					fire_stats_cell = {fire_stats};
					return;
				end
				fprintf('scale 0 fails to fire with maximal response = %.02f\n', maxval);

			end

			% other scales
			fire_stats_cell = cell(1, obj.max_scale);
			for scale = 1:obj.max_scale

				% the spline response sums num_nodes normalized responses,
				% hence the threshold is rescaled by sqrt(num_nodes)
				num_nodes = (2^scale)+1;
				seuil = obj.threshold * sqrt(num_nodes);
				maxret = spline_integral(obj.fullstrip, scale, obj.poly_controls{scale}, ...
									obj.start_row, obj.num_rowls, seuil);

				% a row in maxret.vals is formatted as [response, abs_response, rows, contid]
				if maxret.maxval > seuil
					fire_stats.scale = scale;
					fire_stats.signed_response = maxret.vals(:, 1);
					fire_stats.rowid = maxret.vals(:, 3) + obj.mask_width;				% + obj.mask_width to restore the row position in the input
					fire_stats.contid = maxret.vals(:, 4);
					fire_stats.controls = obj.poly_controls{scale}(fire_stats.contid, :);
					fire_stats.coefs = obj.poly_coefs{scale}(fire_stats.contid, :);
					fire_stats.estimated_sigma = obj.estimated_sigma;
					for f = 1:size(fire_stats.rowid, 1)
						fprintf('scale %d fires with maximal response = %.02f row = %d contid = %d\n', ...
							scale, maxret.maxval/sqrt(num_nodes), fire_stats.rowid(f), fire_stats.contid(f, :));
					end
					fire_stats_cell{scale} = fire_stats;
					% enable to return at the first firing scale
					% return;
				else
					fprintf('scale %d fails to fire with maximal response = %.02f\n', ...
						scale, maxret.maxval/sqrt(num_nodes));
				end
			end

		end

		% noise level of the pixel responses
		%
		% suggested_sigma > 0 is used as is. Otherwise sigma is estimated
		% robustly: responses larger than 4 x median(|response|) are discarded,
		% and the median absolute value of the rest is divided by the 0.75
		% quantile of N(0,1), which is the median of |N(0,1)|.
		function estimated_sigma = estimate_noise(~, suggested_sigma, pixel_responses)

			if suggested_sigma > 0

				fprintf('\nnoise sigma is set externally to %.02f\n', suggested_sigma);
				estimated_sigma = suggested_sigma;

			else
				fprintf('\nestimating the noise sigma from the observed pixel responses\n');
				Z = abs(pixel_responses(:));
				% sqrt(2)*erfinv(0.5) == norminv(0.75), without requiring the
				% Statistics Toolbox
				estimated_sigma = median(Z(Z<median(Z)*4))/(sqrt(2)*erfinv(0.5));
				fprintf('estimated noise sigma is %.02f\n\n', estimated_sigma);
			end

		end


		% max_scale polynomial templates, one per scale, for edge detection
		%
		% Returns, for each scale, the quantized coefficient rows [a_r, ..., a_0]
		% (poly_coefs), their floored values at the control nodes
		% (poly_controls, duplicates removed), plus the worst-case slope and
		% reach per scale, both divided by quantizer.
		function [poly_coefs, poly_controls, max_slopes, max_reaches] = generate_templates(obj)

			% a polynomial approximation for b-regular edges at various scales

			poly_coefs = cell(1, obj.max_scale);
			poly_controls = cell(1, obj.max_scale);
			max_slopes = zeros(1, obj.max_scale);

			% keep the candidate edges from the strip's upper and lower borders
			max_reaches = zeros(1, obj.max_scale);

			ndegrees = numel(obj.b_bounds);

			for scale = 1:obj.max_scale

				relength = (2^(scale-1))/obj.image_size;
				coef_range = cell(1, ndegrees);

				for degree = 1:ndegrees

					% obj.b_bounds = [b_r, ..., b_1]
					index = ndegrees - degree + 1;
					multiplier = obj.b_bounds(index) * obj.quantizer * obj.image_size;
					max_coef = floor(0.5 + (relength^degree)*multiplier);

					% the coefs are formatted as [a_r, ..., a_1]
					coef_range{index} = -max_coef:max_coef;

					% worst case slope and function values
					max_slopes(scale) = max_slopes(scale) + max_coef*degree * (2^(1-scale));
					max_reaches(scale) = max_reaches(scale) + max_coef;
				end

				% cartesian product of higher order coefs (same row order as
				% combvec(coef_range{:})', without the toolbox dependency)
				coefs = CurveDetector.cartesian_rows(coef_range)/obj.quantizer;

				% quantize the polynomial's constant
				constants = kron((0:(obj.quantizer-1))', ones(size(coefs, 1), 1))/obj.quantizer;
				coefs = horzcat(repmat(coefs, obj.quantizer, 1), constants);

				% each row represents a coef combination
				fprintf('originally %d templates at scale %d\n', size(coefs, 1), scale);

				% evaluate the interpolated polynomial values over -1:spacing:1
				num_intervals = 2^(max(scale-obj.interp_scale, 0) + 1);
				nodes = linspace(-1, 1, num_intervals + 1);
				controls = floor(obj.eval_polynomials(coefs, nodes));

				% remove the possible duplicates created by discretization;
				% the row order of coefs decides which duplicate is kept
				[controls, uniq_rows] = unique(controls, 'rows');
				coefs = coefs(uniq_rows, :);
				fprintf('shrunk to %d templates at scale %d\n', size(coefs, 1), scale);

				poly_coefs{scale} = coefs;
				poly_controls{scale} = controls;

			end

			max_slopes = max_slopes/obj.quantizer;
			max_reaches = max_reaches/obj.quantizer;

		end

		% evaluate the polynomials over the nodes (Horner's scheme)
		%
		% coefs : one polynomial per row, formatted as [a_r, ..., a_0]
		% nodes : row vector of evaluation points
		% vals  : size(coefs, 1) x numel(nodes), vals(i, j) = polyval(coefs(i, :), nodes(j))
		function vals = eval_polynomials(~, coefs, nodes)

			L = numel(nodes);
			vals = coefs(:, 1) * ones(1, L);

			for k = 2:size(coefs, 2)
				vals = bsxfun(@plus, coefs(:, k), bsxfun(@times, nodes, vals));
			end

			%%%% sanity check
			%%% result = zeros(size(coefs, 1), L);
			%%% for r = 1:size(coefs, 1)
			%%% 	result(r, :) = polyval(coefs(r, :), nodes);
			%%% end
			%%% assert(isequal(result, vals));

		end

		% visualize the result
		%
		% fire_stats : one firing record produced by multiscale_detector
		% background : optional RGB image; the noisy strip is used instead when
		%              it is omitted or 2-D
		% midcol     : optional centre column; defaults to 2^(max_scale-1)+1,
		%              the column examined at scale 0
		%
		% Returns courbes.rowids (one row per firing) and courbes.colids (the
		% image column of each entry of a rowids row).
		function courbes = visualize(obj, fire_stats, background, midcol)

			if nargin < 3
				fprintf('\nNO background provided.\n');
				background = repmat(obj.noisy_strip, [1, 1, 3]);
			elseif ndims(background) == 2
				fprintf('\nBackground image must be RGB. Discard the input background.\n');
				background = repmat(obj.noisy_strip, [1, 1, 3]);
			end

			% default centre column whenever midcol is omitted
			if nargin < 4
				midcol = 2^(obj.max_scale-1) + 1;
			end

			scale = fire_stats.scale;

			if scale == 0

				% one pixel per firing, all in the centre column
				rowids = fire_stats.rowid(:);
				colids = midcol;
				for f = 1:numel(rowids)
					background(rowids(f), midcol, 1) = 0;
					background(rowids(f), midcol, 2) = 250;
					background(rowids(f), midcol, 3) = 0;
				end

			else
				half_size = 2^(scale-1);
				colids = (-half_size:half_size) + midcol;

				% linearly interpolate between consecutive control nodes so that
				% every column in colids receives a row offset
				controls = fire_stats.controls;
				rowdif = diff(controls, 1, 2);
				interp_length = 2^(min(scale, obj.interp_scale) - 1);
				increment = ones(1, interp_length)/interp_length;

				rowids = floor(cumsum(horzcat(controls(:, 1), kron(rowdif, increment)), 2));
				rowids = bsxfun(@plus, fire_stats.rowid, rowids);

				% for f = 1:size(rowids, 1)
				% for k = 1:size(rowids, 2)
				% 	background(rowids(f, k), colids(k), 1) = 0;
				% 	background(rowids(f, k), colids(k), 2) = 255;
				% 	background(rowids(f, k), colids(k), 3) = 0;
				% end
				% end

			end

			courbes.rowids = rowids;
			courbes.colids = colids;

			% figure; imshow(uint8(background)); title(sprintf('detected (green) at scale %d', scale));
			% imwrite(uint8(background), 'detected.png');
		end

	end

	methods (Static, Hidden)

		% cartesian product of the vectors in the cell array ranges, one
		% combination per output row, with the first range varying fastest.
		% This is the same matrix as combvec(ranges{:})' for non-empty ranges,
		% but it needs only core MATLAB. An empty cell yields a 1x0 matrix.
		function combos = cartesian_rows(ranges)

			combos = zeros(1, 0);
			for k = 1:numel(ranges)
				values = ranges{k}(:);
				num_prev = size(combos, 1);
				% repeat the existing block once per new value; each new value
				% is held constant over one whole block
				combos = horzcat(repmat(combos, numel(values), 1), kron(values, ones(num_prev, 1)));
			end

		end

	end

end
