#include "pmf_matlab.hpp"

// Fill the output slots with empty (0-by-0, real) matrices. Used on the
// error/help paths so every requested left-hand-side value is assigned.
static void fake_answer(int nlhs, mxArray *plhs[]) { // {{{
    // At least one slot is always written, even when nlhs is 0.
    const int slots = (nlhs > 1) ? nlhs : 1;
    int idx = 0;
    while(idx < slots) {
        plhs[idx] = mxCreateDoubleMatrix(0, 0, mxREAL);
        ++idx;
    }
} // }}}

template<typename T>
static void dmat_rand_init(T& X, size_t rows, size_t cols, major_t major_type=COLMAJOR, long seed = 0) { // {{{
	val_type scale = 1./sqrt(cols);
	rng_t rng(seed);
	X = T(rows, cols, major_type);
	for(size_t i = 0; i < rows; i++)
		for(size_t j = 0; j < cols; j++)
			X.at(i,j) = (val_type) rng.uniform((val_type)0.0, scale);
} // }}}

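/*
Supported call signatures. The trailing number on each line is the number of
right-hand-side arguments (nrhs) for that form:

imf_train(Y, A, B);  3
imf_train(Y, A, B, 'options'); 4
imf_train(Y, A, B, W, H); 5
imf_train(Y, A, B, W, H, 'options'); 6
imf_train(Y, A, B, testY, testA, testB, 'options'); 7
imf_train(Y, A, B, W, H, testY, testA, testB, 'options'); 9
*/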
// Print the usage/help text through mexPrintf.
// Despite its name this function does NOT exit or raise an error: it returns
// normally and the caller decides what happens next.
void exit_with_help() { // {{{
	static const char usage_text[] =
		"Usage: [W H rmse walltime] = imf_train(Y, A, B [, 'options'])\n"
		"       [W H rmse walltime] = imf_train(Y, A, B, W, H [, 'options'])\n"
		"       [W H rmse walltime] = imf_train(Y, A, B, testY, testA, testB , 'options')\n"
		"       [W H rmse walltime] = imf_train(Y, A, B, W, H, testY, testA, testB, 'options')\n"
		"  Y: m-by-n sparse matrix or an nnz-by-3 dense matrix [I J V], where [I J V] = find(Y)\n"
		"  A, B: feature matrix: sparse or dense or use [] to denote identity\n"
		"  testY,testA,testB: test tuple, 'options' is required to enable testing\n"
		"options:\n"
		"    -s type : set type of solver (default 10)\n"
		"        1 -- ALS (L2R-LS)\n"
		"        3 -- LR-ALS (L2R-ALS) 1.0 obs for postive loss, rho for obs negative loss\n"
		"       10 -- PU-CCD (L2R-LS-PU)\n"
		"       11 -- PU-ALS (L2R-LS-PU)\n"
		"       13 -- LR-PU-ALS (L2R-LR-PU)\n"
		"    -D type : set type of descent solver (default 1)\n"
		"        0 -- gradient descent with line search (GD_LS)\n"
		"        1 -- truncated newton method with trust regsion (TRON_TR)\n"
		"        2 -- truncated newton method with line search (TRON_LS)\n"
		"    -n threads : set the number of threads (default 4)\n"
		"    -k rank : set the rank (default 10)\n"
		"    -l lambda : set the regularization parameter (default 0.1)\n"
		"    -r rho : set the parameter rho for PU formulation (default 0.01)\n"
		"    -a abar: set the default value for the unlabled entries (default 1)\n"
		"    -e epsilon : set stopping criterion epsilon of tron (default 0.1)\n"
		"    -t max_iter: set the number of iterations (default 10)\n"
		"    -T max_tron_iter: set the number of iterations used in TRON (default 1)\n"
		"    -g max_cg_iter: set the number of iterations used in CG (default 10)\n"
		"    -f use_chol: use cholesky factorization whenever possible (default 0)\n"
		"    -q verbose: show information or not (default 1)\n"
		"    -p do_predict: (default 0)\n"
		"        1 -- RMSE \n"
		"        2 -- Ranking Evaluation\n"
		"        3 -- Both\n"
		"    -P top_P: top-P ranking evaluation (default 5)\n";

	// The text contains no conversion specifiers, so it is passed as the
	// format string itself.
	mexPrintf(usage_text);
} // }}}

// Build the training parameters from the MATLAB call arguments.
//
//   nrhs : number of right-hand-side arguments of the MEX call
//   prhs : the right-hand-side arguments; for nrhs in {4, 6, 7, 9} the last
//          one is expected to be the option string (e.g. '-k 20 -l 0.5')
//
// Returns an imf_param_t holding the constructor defaults (with verbose
// forced to 1) overridden by every option that could be parsed.
//
// Notes:
//  * For nrhs == 3 or 5 there is no option string and the function returns
//    immediately; the solver-specific adjustment and omp_set_num_threads()
//    at the end of this function are NOT executed on that path.
//  * Problems in the option string (unknown flag, missing value, unreadable
//    or over-long string) print the help text and parsing carries on or
//    stops; no error is raised.
imf_param_t parse_command_line(int nrhs, const mxArray *prhs[]) { // {{{
	imf_param_t param;   // default values have been set by the constructor
	param.verbose = 1;
	int i, argc = 1;
	int option_pos = -1;
	char cmd[CMD_LEN];
	char *argv[CMD_LEN/2];
	// Highest index that may hold a token; one slot is kept for the NULL terminator.
	const int max_argc = (int)(sizeof(argv)/sizeof(argv[0])) - 1;

	if(nrhs == 3 || nrhs == 5)
		return param;
	if(nrhs == 4 || nrhs == 6 || nrhs == 7 || nrhs == 9)
		option_pos = nrhs-1;

	// put options in argv[]
	if(option_pos>0) { // {{{
		// Pass the real capacity of cmd so a long option string cannot overrun
		// the buffer. mxGetString reports failure for non-character input and
		// for strings that do not fit; in both cases the options are ignored.
		cmd[0] = '\0';
		if(mxGetString(prhs[option_pos], cmd, sizeof(cmd)) != 0) {
			mexPrintf("Warning: 'options' must be a string of fewer than %d characters; options ignored\n", (int)sizeof(cmd));
			exit_with_help();
			cmd[0] = '\0';
		}
		// Tokens are stored from argv[1] on; argv[argc] is always NULL afterwards.
		for(char *tok = strtok(cmd, " "); tok != NULL && argc < max_argc; tok = strtok(NULL, " "))
			argv[argc++] = tok;
		argv[argc] = NULL;
	} // }}}

	// parse options
	for(i=1;i<argc;i++) { // {{{
		if(argv[i][0] != '-') break;
		if(++i>=argc) {
			// Flag without a value: stop here instead of reading argv[argc] (NULL).
			exit_with_help();
			break;
		}
		switch(argv[i-1][1]) {
			case 's':
				param.solver_type = atoi(argv[i]);
				break;

			case 'D':
				param.solver_descend_type = atoi(argv[i]);
				break;

			case 'k':
				param.k = atoi(argv[i]);
				break;

			case 'n':
				param.threads = atoi(argv[i]);
				break;

			case 'l':
				param.lambda = atof(argv[i]);
				break;

			case 'r':
				param.rho = atof(argv[i]);
				break;

			case 'a':
				param.abar = atof(argv[i]);
				break;

			case 't':
				param.maxiter = atoi(argv[i]);
				break;

			case 'T':
				// -T drives both the TRON iteration count and maxinneriter.
				param.max_tron_iter = atoi(argv[i]);
				param.maxinneriter = atoi(argv[i]);
				break;

			case 'g':
				param.max_cg_iter = atoi(argv[i]);
				break;

			case 'e':
				param.eps = atof(argv[i]);
				break;

			case 'q':
				param.verbose = atoi(argv[i]);
				break;

			case 'f':
				param.use_chol = atoi(argv[i]);
				break;

			case 'p':
				param.do_predict = atoi(argv[i]);
				break;

			case 'P':
				param.top_p = atoi(argv[i]);
				break;

			case 'u':
				param.pu_type = atoi(argv[i]);
				break;

			default:
				mexPrintf("unknown option: -%c\n", argv[i-1][1]);
				exit_with_help();
				break;
		}
	} // }}}

	// Prediction needs a test tuple, which only the 7- and 9-argument forms carry.
	if(param.do_predict && !(nrhs==7 || nrhs==9))
		param.do_predict = 0;
	// Prediction forces verbose output on.
	if(param.do_predict && param.verbose == 0)
		param.verbose = 1;

	// For Squared-L2 loss, use CG is enough
	if(param.solver_type == ALS || param.solver_type == PU_ALS) {
		param.max_cg_iter = param.max_tron_iter*param.max_cg_iter;
		param.max_tron_iter = 1;
	}
	omp_set_num_threads(param.threads);

	return param;
} // }}}

// Convert the MATLAB inputs, run imf_train() and store the results in plhs.
//
//   nlhs/plhs : requested outputs, in order: W, H, rmse, walltime
//   nrhs/prhs : inputs; prhs[0..2] are Y, A, B. For nrhs in {5, 6, 9}
//               prhs[3..4] are the initial W and H. For nrhs in {7, 9} the
//               three arguments before the option string are testY, testA,
//               testB (only used when param.do_predict is set).
//   param     : training parameters; param.k is overwritten when the supplied
//               initial W/H have a different number of columns.
//
// Returns 0 on success. On a dimension mismatch it prints an error, fills the
// outputs with empty matrices via fake_answer() and returns -1.
int run_imf_train(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[], imf_param_t &param) { // {{{
	const mxArray *mxY=prhs[0], *mxA=prhs[1], *mxB=prhs[2], *mxinitW=NULL, *mxinitH=NULL;
	const mxArray *mxtestY=NULL, *mxtestA=NULL, *mxtestB=NULL;
	smat_t Y, testY;

	// convert Y,A,B
	size_t tmp_m = mxGetM(mxA), tmp_n = mxGetM(mxB);
	mxArray_to_smat(mxY, Y, tmp_m, tmp_n);
	major_t major = ROWMAJOR; // use COLMAJOR if CCD-type solvers are used
	switch (param.solver_type) { // {{{
		case CCDR1:
		case CCDR1_SPEEDUP:
		case PU_CCDR1:
			major = COLMAJOR;
			break;
	} // }}}
	feature_matrix fA(mxA, Y.rows, major), fB(mxB, Y.cols, major), testfA, testfB;
	gmat_t &A = fA.get_gmat_ref();
	gmat_t &B = fB.get_gmat_ref();

	size_t m = Y.rows, n = Y.cols;
	size_t da = A.cols, db = B.cols;
	size_t k = param.k;

	// Initialize W and H
	// fix random seed to have same results for each run  (for random initialization)
	long seed = 0L;
	dmat_t W, H;
	if(nrhs == 5 || nrhs == 6 || nrhs == 9) { // {{{
		// Caller supplied the initial factors: W must be da-by-k', H db-by-k'.
		mxinitW = prhs[3];
		mxinitH = prhs[4];
		if(mxGetM(mxinitW) != da || mxGetM(mxinitH) != db || mxGetN(mxinitW) != mxGetN(mxinitH)) {
			mexPrintf("Error: dimensions of (A,B,W,H) do not match !\n");
			fake_answer(nlhs, plhs);
			return -1;
		}
		if(mxGetN(mxinitW) != k) {
			// Only a warning: the rank follows the supplied factors and training
			// continues, so the outputs must not be pre-filled with empty matrices.
			k = param.k = mxGetN(mxinitW);
			mexPrintf("Warning: Change param.k to %ld to match W0 and H0\n", (long)k);
		}
		mxDense_to_dmat(mxinitW, W, major);
		mxDense_to_dmat(mxinitH, H, major);
	} else {
		// Different seeds for W and H so the two factors are not identical draws.
		dmat_rand_init(W, da, k, major, seed);
		dmat_rand_init(H, db, k, major, seed+da);
	} // }}}

	imf_prob_t prob(&Y, &A, &B, k), test_prob;

	// handle cases with test data
	if(param.do_predict && (nrhs == 7 || nrhs == 9)) { // {{{
		// convert testY,testA,testB
		// With initial W,H present (nrhs == 9) the test tuple starts two slots later.
		int offset = (nrhs == 7)? 0: 2;
		mxtestY=prhs[3+offset];
		mxtestA=prhs[4+offset];
		mxtestB=prhs[5+offset];

		tmp_m = mxGetM(mxtestA); tmp_n = mxGetM(mxtestB);
		mxArray_to_smat(mxtestY, testY, tmp_m, tmp_n);
		testfA.convert_from(mxtestA, testY.rows, ROWMAJOR);
		testfB.convert_from(mxtestB, testY.cols, ROWMAJOR);
		if(testfA.cols != A.cols) {
			mexPrintf("Error: dimensions do not match size(A,2) != size(testA,2) %ld %ld!\n", (long)A.cols, (long)testfA.cols);
			fake_answer(nlhs, plhs);
			return -1;
		}
		if(testfB.cols != B.cols) {
			mexPrintf("Error: dimensions do not match size(B,2) != size(testB,2) %ld %ld!\n", (long)B.cols, (long)testfB.cols);
			fake_answer(nlhs, plhs);
			return -1;
		}
		test_prob = imf_prob_t(&testY, testfA.get_gmat_ptr(), testfB.get_gmat_ptr(), k);
	} //}}}

	double rmse = 0;
	double walltime = 0;

	info_t info;
	walltime = omp_get_wtime();
	// The test problem is passed only when it was actually set up above.
	imf_train(&prob, &param, &W, &H, (test_prob.Y != NULL)? &test_prob : NULL, &info, &rmse);
	walltime = omp_get_wtime() - walltime;

    // Only the outputs the caller asked for are created.
    if(nlhs >= 1) {
        dmat_to_mxDense(W, plhs[0]=mxCreateDoubleMatrix(da,k,mxREAL));
    }
    if(nlhs >= 2) {
        dmat_to_mxDense(H, plhs[1]=mxCreateDoubleMatrix(db,k,mxREAL));
    }
    if(nlhs >= 3) {
        *mxGetPr(plhs[2] = mxCreateDoubleMatrix(1,1,mxREAL)) = rmse;
    }
    if(nlhs >= 4) {
        *mxGetPr(plhs[3] = mxCreateDoubleMatrix(1,1,mxREAL)) = walltime;
    }

	return 0;
} // }}}

// Interface function of matlab
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[] ) { // {{{
	imf_param_t param;
	// fix random seed to have same results for each run
	// (for cross validation)
	srand(1);

	switch (nrhs) {
		case 3: case 4: case 5: case 6: case 7: case 9:
			{ // {{{
				param = parse_command_line(nrhs, prhs);
				switch (param.solver_type){
					case ALS:
					case PU_CCDR1:
					case PU_ALS:
					case LR_ALS:
					case LR_PU_ALS:
						run_imf_train(nlhs, plhs, nrhs, prhs, param);
						break;
					default:
						fprintf(stderr, "Error: wrong solver type (%d)!\n", param.solver_type);
						exit_with_help();
						fake_answer(nlhs, plhs);
						break;
				}
			} // }}}
			break;
		default :
			exit_with_help();
			fake_answer(nlhs, plhs);
	}
} // }}}
