/*
	MEX function that evaluates many line integrals over an image strip
	using the multiscale recursion of Brandt and Dym.
	See A. Brandt and J. Dym, "Fast calculation of multiple line integrals", 1999.

	Written by Yi-Qing Wang, 2016
*/

#include <mex.h>
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <stdlib.h>
#include "csweep.h"

/*
   Persistent state shared by mexFunction and line_integral. The buffers are
   allocated once and kept between MEX calls to avoid reallocating on every call.
*/

/* memoization table of partial line integrals (entries carry a validity flag) */
static optional_double * calculated = NULL;
/* per scale (0-based index): number of stored row offsets for each left-hand row */
static size_t * num_offsets_at_scale = NULL;
/* per scale (0-based index): cumulative number of table entries within one substrip */
static size_t * num_data_per_ssubstrip = NULL;

/* positions of calculated[] filled since the last invalidation, and how many there are */
static size_t * valid_positions = NULL;
static size_t counter = 0;

void mexFunction(
    int nlhs,
    mxArray *plhs[],
    int nrhs,
    const mxArray *prhs[]
)
{
    /*a particular signal to end everything*/
    if (nrhs + nlhs == 0) {
        mxFree(calculated);
        mxFree(num_offsets_at_scale);
        mxFree(num_data_per_ssubstrip);
        mxFree(valid_positions);
        calculated = NULL;
        return;
    }

    int scale = (int) mxGetScalar(prhs[0]);
    int interp_scale = (int) mxGetScalar(prhs[1]);
    double max_slope = fabs(mxGetScalar(prhs[2]));

    size_t start_row = (size_t) mxGetScalar(prhs[3]);
    size_t num_rowls = (size_t) mxGetScalar(prhs[4]);
    size_t start_col = (size_t) mxGetScalar(prhs[5]);

    double * controls = mxGetPr(prhs[6]);
    size_t num_controls = mxGetM(prhs[6]);
    size_t num_lps = mxGetN(prhs[6]) - 1;

    double * input_strip = mxGetPr(prhs[7]);
    size_t n = mxGetM(prhs[7]);

    bool free_memory = mxGetScalar(prhs[8]) > 0.5;

    size_t max_num_lps = (size_t) mxGetScalar(prhs[9]);

    size_t num_responses = num_rowls * num_controls;
    plhs[0] = mxCreateDoubleMatrix(1, num_responses, mxREAL);
    double * responses = mxGetPr(plhs[0]);


    /*
    recursion memory allocation

    memory layout: the line integral at substrip id >=0
    of scale s >=1 spanning pow(2, s-1) + 1 columns
    connecting two rows (rowl >=0, rowr >=0) at sub-substrip recur_idx

    s = 1
    substrip_id * num_data_per_ssubstrip[interp_scale-1]
    + recur_idx * n * num_offsets_at_scale[s]
    + num_offsets_at_scale[s] * rowl
    + (num_offsets_at_scale[s] - 1)/2 + rowr - rowl

    s > 1
    above + num_data_per_ssubstrip[s-2]
    */

    int substrip_id, s;
    if (calculated == NULL) {
        num_offsets_at_scale = (size_t *) mxMalloc(interp_scale * sizeof(size_t));
        mexMakeMemoryPersistent(num_offsets_at_scale);

        num_data_per_ssubstrip = (size_t *) mxMalloc(interp_scale * sizeof(size_t));
        mexMakeMemoryPersistent(num_data_per_ssubstrip);

        for (s = 0; s < interp_scale; s++) {
            int num_offsets = (1 << s) * max_slope + 1;
            num_offsets_at_scale[s] = num_offsets * 2 + 1;
            int num_divisions = 1 << (interp_scale - s - 1);
            num_data_per_ssubstrip[s] = num_offsets_at_scale[s] * n * num_divisions;
            num_data_per_ssubstrip[s] += s > 0 ? num_data_per_ssubstrip[s-1] : 0;
        }

        calculated = (optional_double *) mxCalloc(num_data_per_ssubstrip[interp_scale-1]*max_num_lps, sizeof(optional_double));
        mexMakeMemoryPersistent(calculated);

        valid_positions = (size_t *) mxMalloc(num_data_per_ssubstrip[interp_scale-1] * max_num_lps * sizeof(size_t));
        mexMakeMemoryPersistent(valid_positions);
    }

    size_t k;
    /*brandt dym recursion*/
    for (substrip_id = 0; substrip_id < num_lps; substrip_id++)
        for (k = 0; k < num_responses; k++) {
            size_t contid = k / num_rowls + substrip_id * num_controls;
            size_t base_row = k % num_rowls + start_row;
            size_t rowl = base_row + controls[contid];
            size_t rowr = base_row + controls[contid + num_controls];
            responses[k] += line_integral(scale, interp_scale, max_slope, rowl, rowr, 0, substrip_id, input_strip, n, start_col);
        }


    if (free_memory) {
        for (k = 0; k < counter; k++)
            calculated[valid_positions[k]].is_valid = false;
        counter = 0;
    }

}

/*
   Memoized Brandt/Dym recursion for the line integral of one segment.

   The segment joins row rowl (left) to row rowr (right) across the
   2^(scale-1) + 1 columns of sub-substrip recur_idx at the given scale, inside
   substrip substrip_id. Each substrip spans 2^(interp_scale-1) + 1 columns and
   starts at column start_col + substrip_id * 2^(interp_scale-1), so adjacent
   substrips share a boundary column.

   At scale 1 the value is the mean of the two endpoint pixels. At higher
   scales the segment is split at the integer midpoint row (rounded down), and
   the two half-segments of scale-1 (recur_idx*2 and recur_idx*2+1) are summed.

   Results are cached in calculated[]. Each newly filled slot is appended to
   valid_positions so the caller can invalidate exactly those slots later.
   For each left-hand row, num_offsets_at_scale[scale-1] =
   2 * (floor(2^(scale-1) * max_slope) + 1) + 1 offsets rowr - rowl are stored.
   NOTE(review): assumes |rowr - rowl| fits that range and both rows are < n.
*/
double line_integral (
    int scale,					/* current scale ranging from 1 to interp_scale			*/
    int interp_scale,				/* interpolation scale						*/
    double max_slope,
    size_t rowl,				/* l.h.s. row coordinate of the segment				*/
    size_t rowr,				/* r.h.s. row coordinate of the segment				*/
    int recur_idx,				/* binary tree index						*/
    int substrip_id,				/* index of the substrips of the input_strip			*/
    double * input_strip,			/* actual input image strip					*/
    size_t n,					/* number of rows of input_strip				*/
    size_t start_col				/* starting column						*/
)
{
    size_t num_offsets = num_offsets_at_scale[scale-1];

    /* rowr - rowl may wrap for rowr < rowl; unsigned arithmetic wraps back when added */
    size_t position = substrip_id * num_data_per_ssubstrip[interp_scale-1]
                      + recur_idx * n * num_offsets
                      + num_offsets * rowl
                      + (num_offsets - 1)/2 + rowr - rowl;
    position += scale > 1 ? num_data_per_ssubstrip[scale-2] : 0;

    if (calculated[position].is_valid)
        return calculated[position].value;

    double value = 0;
    if (scale == 1) {
        /* computed in size_t: avoids int narrowing of start_col and int overflow of the product */
        size_t shift = start_col + ((size_t) 1 << (interp_scale-1)) * (size_t) substrip_id;
        size_t l_index = n * (shift + recur_idx) + rowl;
        size_t r_index = n * (shift + recur_idx + 1) + rowr;
        value = 0.5 * (input_strip[l_index] + input_strip[r_index]);
    } else {
        /* exact integer midpoint; equals the former (rowl + rowr) * 0.5 truncation */
        size_t mid = (rowl + rowr) / 2;
        value = line_integral(scale-1, interp_scale, max_slope, rowl, mid, recur_idx*2, substrip_id, input_strip, n, start_col)
                + line_integral(scale-1, interp_scale, max_slope, mid, rowr, recur_idx*2+1, substrip_id, input_strip, n, start_col);
    }

    calculated[position].value = value;
    calculated[position].is_valid = true;

    /* remember the slot so the caller can invalidate it cheaply */
    valid_positions[counter] = position;
    counter += 1;

    return value;
}
