#ifndef BILINEAR_H
#define BILINEAR_H

#include "tron.h"
#include "pmf.h"

/// Problem data for the bilinear model: the observed matrix Y, the feature
/// matrix X, the dense matrix H and an optional graph-regularization matrix L.
/// The matrix pointers are stored as given; this struct never frees them, so
/// the caller must keep them alive for as long as the problem is used.
struct bilinear_prob_t { // {{{
	smat_t *Y; // m*n sparse matrix
	gmat_t *X; // m*d feature matrix
	dmat_t *H; // n*k dense matrix
	smat_t *L; // d*d sparse matrix for graph regularization (may be NULL)
	const size_t m;  // #instances (Y->rows)
	const size_t d;  // #features  (X->cols)
	const size_t n;  // #labels    (Y->cols)
	const size_t k;  // rank dimension (H->cols)
	dmat_t p;          // m*1 : could be param->rho*ones(m,1)
	dmat_t q;          // n*1 : could be ones(n,1)
	// The initializers below are listed in member-declaration order
	// (m, d, n, k). C++ always initializes in declaration order, so writing
	// the list the same way avoids -Wreorder and keeps the code honest.
	bilinear_prob_t(smat_t* Y, gmat_t* X, dmat_t* H, smat_t* L=NULL):
		Y(Y), X(X), H(H), L(L), m(Y->rows), d(X->cols), n(Y->cols), k(H->cols){ // {{{
			// When a graph-regularization matrix is given it must be d*d.
			if(L!=NULL) {
				assert(L->rows == d);
				assert(L->cols == d);
			}
	} // }}}
}; // }}}

/// Parameters for the bilinear solver: the generic pmf_parameter_t settings
/// plus TRON/CG iteration limits and regularization options.
struct bilinear_param_t : public pmf_parameter_t { // {{{
	int solver_descend_type;        // descent variant passed to TRON::tron(); default TRON_TR
	double abar;                    // default value for missing entries; default 0.0
	double lambda_graph;            // graph-regularization setting (presumably its weight); default 0
	double eps_cg;                  // presumably the CG stopping tolerance; default 0.1
	int max_tron_iter, max_cg_iter; // iteration caps; defaults 1 and 10
	int weighted_reg;               // flag; default 0
	int use_chol;                   // flag; default 0

	// All defaults are set through the initializer list, in declaration order.
	bilinear_param_t():
		pmf_parameter_t(),
		solver_descend_type(TRON_TR),
		abar(0.0),
		lambda_graph(0),
		eps_cg(0.1),
		max_tron_iter(1),
		max_cg_iter(10),
		weighted_reg(0),
		use_chol(0) {
	}
}; // }}}

/// Abstract solver interface operating on a caller-provided vector w.
/// Derived classes are deleted through solver_t*, hence the virtual destructor.
struct solver_t { // {{{
	// Prepare internal state before the first solve.
	virtual void init_prob() = 0;
	// Run the solver using the buffer w.
	virtual void solve(val_type *w) = 0;
	// Objective value at w; this default implementation ignores w and yields 0.
	virtual double fun(val_type * /*w*/) { return 0; }
	virtual ~solver_t() {}
}; // }}}

struct leml_solver : public solver_t { // {{{
	bilinear_prob_t *prob;
	bilinear_param_t *param;
	function<val_type> *fun_obj;
	TRON<val_type> *tron_obj;
	solver_t *solver_obj;
	bool done_init;


	leml_solver(bilinear_prob_t *prob, bilinear_param_t *param);
	leml_solver(const leml_solver& other) {}

	void zero_init() { // {{{
		prob = NULL;
		param = NULL;
		fun_obj = NULL;
		tron_obj = NULL;
		solver_obj = NULL;
		done_init = false;
	} // }}}
	~leml_solver() { // {{{
		if(tron_obj) delete tron_obj;
		if(fun_obj) delete fun_obj;
		if(solver_obj) delete solver_obj;
		zero_init();
	} // }}}
	void init_prob() { // {{{
		if(fun_obj) fun_obj->init();
		else if(solver_obj) solver_obj->init_prob();
		done_init = true;
	} // }}}
	void set_eps(double eps) { tron_obj->set_eps(eps);}
	void solve(val_type *w) { // {{{
		if(!done_init) {init_prob();}
		if(tron_obj) {
			// tron_obj->tron(w, true);// zero-initization for w
			tron_obj->tron(w, false, param->solver_descend_type);
		} else if (solver_obj) {
			solver_obj->solve(w);
		}
	} // }}}
	double fun(val_type *w) { // {{{
		if(!done_init) { init_prob(); }
		if(fun_obj)
			return fun_obj->fun(w);
		else if(solver_obj)
			return solver_obj->fun(w);
		else
			return 0;
	} // }}}
}; // }}}


#endif // BILINEAR_H
