#ifndef IMF_H
#define IMF_H

#include "bilinear.h"

struct imf_prob_t { // {{{
	smat_t *Y;
	gmat_t *A;     // row features
	gmat_t *B;     // column features
	smat_t *La;    // graph regularizer for W
	smat_t *Lb;    // graph regularizer for H
	size_t m, n;   // dimension of Y
	size_t da, db; // #features of A and B
	size_t k;      // #topics
	imf_prob_t():Y(NULL),A(NULL),B(NULL){}
	imf_prob_t(smat_t *Y, gmat_t *A, gmat_t *B, size_t k, smat_t *La=NULL, smat_t *Lb=NULL):
		Y(Y), A(A), B(B), La(La), Lb(Lb), m(Y->rows), n(Y->cols), da(A->cols), db(B->cols), k(k){ // {{{
			if(La!=NULL) {
				assert(La->rows == da);
				assert(La->cols == da);
			}
			if(Lb!=NULL) {
				assert(Lb->rows == db);
				assert(Lb->cols == db);
			}
		} // }}}
};

// Training parameters for IMF: everything from bilinear_param_t plus top_p.
struct imf_param_t : public bilinear_param_t {
	// Defaults to 5. Presumably the ranking cut-off used at evaluation time
	// (cf. the top_p argument of imf_ranker_t::get_info) -- TODO confirm.
	int top_p;

	imf_param_t() : bilinear_param_t(), top_p(5) {}
}; // }}}

// Evaluator for IMF models built on (privately inherited) pmf_ranker_t.
// It keeps pointers to the problem's Y, A and B (never freed here) and caches
// the projected factors AW and BH, which preprocess() fills from W and H.
class imf_ranker_t : pmf_ranker_t { // {{{
	typedef pmf_ranker_t Base;
	private :
		smat_t *Y;              // matrix evaluated against (from imf_prob_t::Y)
		const smat_t *ignored;  // optional extra matrix forwarded to Base::eval; may be NULL
		gmat_t *A;              // row feature matrix (from imf_prob_t::A)
		gmat_t *B;              // column feature matrix (from imf_prob_t::B)
	public :
		dmat_t AW;  // projected row factors, produced by preprocess() from (A, W)
		dmat_t BH;  // projected column factors, produced by preprocess() from (B, H)

		// Creates an unusable placeholder whose pointers are NULL rather than
		// indeterminate. Build from an imf_prob_t before calling preprocess(),
		// get_info() or get_rmse().
		imf_ranker_t(): Y(NULL), ignored(NULL), A(NULL), B(NULL) {}

		// prob    : supplies Y, A and B (pointers stored, not copied).
		// param   : only param.k is read, to size AW/BH.
		// ignored : optional matrix forwarded to Base::eval by get_info(). Its shape
		//           is compared with Y, but a mismatch is only reported on stderr.
		// neutral_rel_, halflife_ : forwarded to the pmf_ranker_t constructor.
		// AW/BH are pre-allocated (A->rows x k / B->rows x k, ROWMAJOR) unless the
		// corresponding feature matrix is the identity; preprocess() then makes
		// them views of W/H instead.
		imf_ranker_t(const imf_prob_t &prob, const imf_param_t &param, smat_t* ignored = NULL,
				double neutral_rel_=0, double halflife_=5): Base(neutral_rel_, halflife_), Y(prob.Y), ignored(ignored), A(prob.A), B(prob.B) { // {{{
			if(!A->is_identity())
				AW = dmat_t(A->rows, param.k, ROWMAJOR);
			if(!B->is_identity())
				BH = dmat_t(B->rows, param.k, ROWMAJOR);
			if(ignored && ignored->nnz > 0) {
				if(ignored->rows != Y->rows)
					fprintf(stderr,"row dimension of ignored matrix is wrong\n");
				if(ignored->cols != Y->cols)
					fprintf(stderr,"col dimension of ignored matrix is wrong\n");
			}
		} // }}}

		// Fills AW from (A, W) and BH from (B, H), dispatching on whether each
		// feature matrix is dense, identity or sparse.
		// For an identity feature matrix no product is formed: AW/BH become
		// W.get_view()/H.get_view(). These are presumably non-owning views, so W and H
		// must outlive later use of AW/BH -- TODO confirm dmat_t view semantics.
		// A feature matrix of none of the three kinds leaves AW/BH unchanged.
		void preprocess(const dmat_t &W, const dmat_t &H) { // {{{
			if(A->is_dense()) {
				dmat_x_dmat(A->get_dense(), W, AW);
			} else if (A->is_identity()) {
				AW = W.get_view();
			} else if (A->is_sparse()) {
				smat_x_dmat(A->get_sparse(), W, AW);
			}
			if(B->is_dense()) {
				dmat_x_dmat(B->get_dense(), H, BH);
			} else if (B->is_identity()) {
				BH = H.get_view();
			} else if (B->is_sparse()) {
				smat_x_dmat(B->get_sparse(), H, BH);
			}
		} // }}}

		// Ranking evaluation at cut-off top_p via Base::eval on (Y, AW, BH).
		// The ignored matrix is forwarded when one was supplied.
		// Requires a prior preprocess() call.
		info_t get_info(int top_p) { // {{{
			if(ignored)
				return Base::eval<int,unsigned>(*Y,AW,BH,top_p,*ignored);
			else
				return Base::eval<int,unsigned>(*Y,AW,BH,top_p);
		} // }}}

		// RMSE on Y computed by pmf_compute_rmse from AW and BH.
		// Requires a prior preprocess() call.
		double get_rmse() { // {{{
			return pmf_compute_rmse(*Y, AW, BH);
		} // }}}
}; //}}}


#ifdef __cplusplus
extern "C" {
#endif

/*
 * IMF training entry points (declared here, defined elsewhere).
 * They get C linkage via the guard above. The header itself is C++-only,
 * though: it defines classes and uses default arguments.
 *
 * prob, param   : problem and parameters to train on.
 * ptr_W, ptr_H  : factor matrices; presumably filled in by the call -- confirm
 *                 against the definition.
 * test_prob, info, rmse : optional evaluation inputs/outputs (NULL = unused
 *                 for imf_train; presumably the same convention in ccdr1_imf).
 */
void imf_train(imf_prob_t *prob, imf_param_t *param,
		dmat_t *ptr_W, dmat_t *ptr_H,
		imf_prob_t *test_prob = NULL, info_t *info = NULL, double *rmse = NULL);

/* Same parameters as imf_train, but with no defaults. The name suggests a
 * CCD-based rank-one solver -- TODO confirm. */
void ccdr1_imf(imf_prob_t *prob, imf_param_t *param,
		dmat_t *ptr_W, dmat_t *ptr_H,
		imf_prob_t *test_prob, info_t *info, double *rmse);

#ifdef __cplusplus
}
#endif

#endif // IMF_H
