#include "imf.h"

static double norm(const dmat_t &W) { // {{{
	// Frobenius norm: square root of the sum of squares of all W.rows*W.cols
	// values reachable from W.data().
	return sqrt(do_dot_product(W.data(), W.data(), W.rows*W.cols));
} // }}}
static double norm(const dvec_t &v) { // {{{
	return norm(dmat_t(v));
} // }}}

// Trains W and H by alternately solving the H subproblem (given AW = A*W) and
// the W subproblem (given BH = B*H) with leml_solver. The CCDR1 / PU_CCDR1
// solver types are delegated to ccdr1_imf().
//   prob      : training problem (Y, feature matrices A/B, optional graph
//               regularizers La/Lb used as trace(W'*La*W), sizes m, n, k)
//   param     : solver parameters (solver_type, lambda, lambda_graph, maxiter, ...)
//   ptr_W     : initial W; overwritten with the trained W
//   ptr_H     : overwritten; zeroed before the first iteration
//   test_prob : optional (may be NULL) test problem used for progress reports
//   info/rmse : optional (may be NULL) outputs; *rmse is reset to 0 up front,
//               and both are only updated while verbose reporting is on
void imf_train(imf_prob_t *prob, imf_param_t *param, dmat_t *ptr_W, dmat_t *ptr_H, imf_prob_t *test_prob, info_t *info, double *rmse) { // {{{
	// rmse is optional (it is NULL-checked below as well), so guard the reset.
	if(rmse != NULL) *rmse = 0;

	if(param->solver_type == CCDR1 || param->solver_type == PU_CCDR1)
		return ccdr1_imf(prob, param, ptr_W, ptr_H, test_prob, info, rmse);

	smat_t &Y = *(prob->Y), Yt = Y.transpose();
	gmat_t &A = *(prob->A);
	gmat_t &B = *(prob->B);
	dmat_t &W = *ptr_W;
	dmat_t &H = *ptr_H;
	const size_t &m = prob->m, &n = prob->n;
	const size_t &k = prob->k;

	imf_ranker_t ranker;
	bool do_rmse = (param->do_predict & 1) != 0;
	bool do_ranking = (param->do_predict & 2) != 0;

	omp_set_num_threads(param->threads);

	// AW / BH: products of the feature matrices with the factors. For identity
	// features they are taken as views of W / H instead of separate buffers.
	dmat_t AW, BH;
	if(A.is_dense()) {
		AW = dmat_t(m, k, ROWMAJOR);
	} else if(A.is_identity()) {
		//AW = dmat_t(m, k, ROWMAJOR);
		AW = W.get_view();
	} else if(A.is_sparse()) {
		AW = dmat_t(m, k, ROWMAJOR);
	}

	if(B.is_dense()) {
		BH = dmat_t(n, k, ROWMAJOR);
	} else if(B.is_identity()) {
		BH = H.get_view();
		//BH = dmat_t(n, k, ROWMAJOR);
	} else if(B.is_sparse()) {
		BH = dmat_t(n, k, ROWMAJOR);
	}

	if(test_prob != NULL && test_prob->Y != NULL) {
		// ignored training entries for test evaluation when both A and B are identity
		smat_t *ignored = (A.is_identity() && B.is_identity() && test_prob->A->is_identity() && test_prob->B->is_identity())? &Y : NULL;
		ranker = imf_ranker_t(*test_prob, *param, ignored);
	}

	bilinear_prob_t subprob_w(&Y, &A, &BH, prob->La);
	bilinear_prob_t subprob_h(&Yt,&B, &AW, prob->Lb);
	leml_solver W_solver(&subprob_w, param);
	leml_solver H_solver(&subprob_h, param);

	memset(H.data(), 0, sizeof(val_type)*H.rows*H.cols);

	if(param->verbose != 0) {
		printf("|W0| (%ld %ld)= %.6f\n", W.rows, W.cols, norm(W));
		printf("|H0| (%ld %ld)= %.6f\n", H.rows, H.cols, norm(H));
	}

	double Wtime=0, Htime=0, start_time=0;
	for(int iter = 1; iter <= param->maxiter; iter++) {
		/*
		if(iter == 5) {
			W_solver.set_eps(param->eps);
			H_solver.set_eps(param->eps);
		} else if(iter == 1) {
			W_solver.set_eps(0.1);
			H_solver.set_eps(0.1);
		}
		*/

		// H step: refresh AW = A*W, then solve for H.
		start_time = omp_get_wtime();
		if(A.is_dense()) {
			dmat_x_dmat(A.get_dense(), W, AW);
		} else if(A.is_sparse()){
			smat_x_dmat(A.get_sparse(), W, AW);
		}
		H_solver.init_prob();
		H_solver.solve(H.data());
		Htime += omp_get_wtime()-start_time;

		// W step: refresh BH = B*H, then solve for W.
		start_time = omp_get_wtime();
		if(B.is_dense()) {
			dmat_x_dmat(B.get_dense(), H, BH);
		} else if(B.is_sparse()){
			smat_x_dmat(B.get_sparse(), H, BH);
		}
		//printf("W %.10g H %.10g BH %.10g\n", norm(W), norm(H), norm(BH));
		W_solver.init_prob();
		W_solver.solve(W.data());
		Wtime += omp_get_wtime()-start_time;

		if(param->verbose != 0) { // {{{
			printf("IMF-iter %d W %.5g H %.5g Time %.5g",
					//iter, Wtime, Htime, Wtime+Htime);
					iter, norm(W),norm(H), Wtime+Htime);
			double reg_h = 0, reg_w = 0, reg_graph_h = 0, reg_graph_w = 0;
			switch (param->solver_type) {
				case ALS:
				case PU_ALS:
				case LR_PU_ALS:
					reg_w = do_dot_product(W.data(),W.data(),W.rows*W.cols);
					reg_h = do_dot_product(H.data(),H.data(),H.rows*H.cols);
					if(param->lambda_graph != 0 && prob->La != NULL)
						reg_graph_w = trace_dmat_T_smat_dmat(W, *(prob->La), W);
					if(param->lambda_graph != 0 && prob->Lb != NULL)
						reg_graph_h = trace_dmat_T_smat_dmat(H, *(prob->Lb), H);
					break;
			}
			// Data loss = H-subproblem objective minus its H regularizers.
			// NOTE(review): for non-identity A, AW was computed before this
			// iteration's W update, so this loss may lag W by one half-step — confirm.
			double loss = (H_solver.fun(H.data())-0.5*param->lambda*reg_h-0.5*param->lambda_graph*reg_graph_h);
			//double loss = (W_solver.fun(W.data())-0.5*param->lambda*reg_w-0.5*param->lambda_graph*reg_graph_w);

			//W_solver.init_prob();
			//double loss = 2*(W_solver.fun(W)-0.5*param->lambda*reg_w);
			double obj = loss+0.5*param->lambda*(reg_w+reg_h)+0.5*param->lambda_graph*(reg_graph_w+reg_graph_h);
			printf(" loss %g reg %g obj %.5g", loss, reg_w+reg_h+reg_graph_w+reg_graph_h, obj);

			if(test_prob != NULL && test_prob->Y != NULL) { // {{{
				ranker.preprocess(W, H);
				if(do_rmse) {
					double tmp_rmse = ranker.get_rmse();
					if(rmse!=NULL) *rmse = tmp_rmse;
					printf(" rmse %.5g", tmp_rmse);
				}
				if(do_ranking) {
					info_t tmp_info = ranker.get_info(param->top_p);
					if(info!=NULL) *info= tmp_info;
					tmp_info.print();
				}
			} // }}}
			puts("");
			fflush(stdout);
		} // }}}
	}
} // }}}

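// Rank-one coordinate-descent solver: updates one column u of the factor multiplied by
// the features X (model Y ~ X*U*H'), keeping B = X*U in sync and using the per-feature
// terms u_plus, u_minus and u_denom for the coordinate steps.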
struct leml_rankone_solver { // {{{
	typedef bilinear_prob_t prob_t;
	typedef bilinear_param_t param_t;
	const prob_t *prob;
	const param_t *param;
	const smat_t &Y;
	const gmat_t &X;
	const dmat_t &H;
	smat_t &Dplus;
	dmat_t &B;
	dvec_t bc, k_buf, m_buf;
	dvec_t u_plus, u_minus, u_denom;
	dmat_t p;          // m*1 : could be param->rho*ones(m,1)
	dmat_t q;          // n*1 : could be ones(n,1)
	const size_t &m, &d, &n, &k;
	const double &rho, &abar;
	double uniform_p, uniform_q;
	double abar_qTv;
	dvec_t XTp;
	rng_t rng;
	std::vector<unsigned> perm;
	leml_rankone_solver(const prob_t *prob, const param_t *param, smat_t &Dplus, dmat_t &B): prob(prob), param(param), Y(*(prob->Y)), X(*(prob->X)), H(*(prob->H)), m(prob->m), d(prob->d), n(prob->n), k(prob->k), rho(param->rho), abar(param->abar), Dplus(Dplus), B(B) { // {{{
		bc = dvec_t(m);
		k_buf = dvec_t(k);
		m_buf = dvec_t(m);
		u_plus = dvec_t(d);
		u_minus = dvec_t(d);
		u_denom = dvec_t(d);
		if(rho != 0.0) { // {{{
			p = dmat_t(m,1,COLMAJOR);
			q = dmat_t(n,1,COLMAJOR);
			for(size_t i = 0; i < m; i++)
				p.at(i,0) = (val_type)rho;
			for(size_t j = 0; j < n; j++)
				q.at(j,0) = (val_type)1.0;
			// determine uniform_p/uniform_q{{{
			uniform_q = q.at(0,0);
			for(size_t j = 1; j < n; j++) {
				if(q.at(j,0) != uniform_q) {
					uniform_q = 0.0;
					break;
				}
			}
			if(uniform_q) {
				uniform_p = (p.at(0,0)*=uniform_q);
				for(size_t i = 1; i < m; i++) {
					p.at(i,0)*=uniform_q;
					if(p.at(i,0) != uniform_p) {
						uniform_p = 0.0;
					}
				}
				uniform_q = 1.0;
				for(size_t j = 0; j < n; j++)
					q.at(j,0) = 1.0;
			} else {
				uniform_p = p.at(0,0);
				for(int i = 1; i < m; i++) {
					if(p.at(i,0) != uniform_p) {
						uniform_p = 0.0;
						break;
					}
				}
			}
			// }}}
			printf("uniform_p %g uniform_q %g\n", uniform_p, uniform_q);
		} else {
			p = dmat_t(m,1,COLMAJOR);
			q = dmat_t(n,1,COLMAJOR);
			for(size_t i = 0; i < m; i++)
				p.at(i,0) = 0;
			for(size_t j = 0; j < n; j++)
				q.at(j,0) = 0;
		} // }}}
		abar_qTv = 0;
		XTp = dvec_t(d);
		if(X.is_identity()) // {{{
			for(size_t f = 0; f < d; f++)
				XTp[f] = p.at(f,0);
		else if(X.is_dense()) {
			const dmat_t &dX = X.get_dense();
#pragma omp parallel for schedule(static)
			for(size_t f = 0; f < d; f++) {
				double sum = 0;
				for(size_t i = 0; i < m; i++)
					sum += dX.at(i,f)*p.at(i,0);
				XTp[f] = sum;
			}
		} else if(X.is_sparse()) {
			const smat_t &spX = X.get_sparse();
#pragma omp parallel for schedule(static)
			for(size_t f = 0; f < d; f++) {
				double sum = 0;
				for(size_t idx = spX.col_ptr[f]; idx != spX.col_ptr[f+1]; idx++) {
					size_t i = spX.row_idx[idx];
					sum += spX.val[idx]*p.at(i,0);
				}
				XTp[f] = sum;
			}
		} // }}}

		perm.resize(d);
		for(size_t f = 0; f < d; f++)
			perm[f] = f;
	} // }}}

	double get_C(size_t i, size_t j) { // {{{
		return p.data()[i]*q.data()[j];
		if(uniform_p && uniform_q) return uniform_p*uniform_q;
		if(uniform_p) return uniform_p * q.data()[j];
		if(uniform_q) return p.data()[i] * uniform_q;
	} // }}}
	void init_prob(int cur_t) { // {{{
		const dvec_t &v = H[cur_t];
		const bool debug = false;
		if(debug) printf("v %g\n", norm(dmat_t(v)));
		// compute k = H'*v, where v = H[t], and abar_qTv
		if(rho != 0.0) { // {{{
			if(uniform_q) {
#pragma omp parallel for schedule(static)
				for(size_t t = 0; t < k; t++)
					k_buf[t] = uniform_q * do_dot_product(H[t], v);
			} else {
#pragma omp parallel for schedule(static)
				for(size_t t = 0; t < k; t++) {
					double sum = 0.0;
					for(size_t j = 0; j < n; j++)
						sum += q.at(j,0) * H.at(j,t) * v[j];
					k_buf[t] = sum;
				}
			}
			if(abar != 0.0) {
				abar_qTv = 0;
				for(size_t j = 0; j < n; j++)
					abar_qTv += v[j]*q.at(j,0);
				abar_qTv *= abar;
			}
		} // }}}
		if(debug) printf("k_buf %g\n", norm(dmat_t(k_buf)));
		// compute bc
		double c0 = k_buf[cur_t];
#pragma omp parallel for schedule(dynamic,64)
		for(size_t i = 0; i < m; i++) { // {{{
			double cihat = 0;
			if(uniform_q) {
				for(size_t idx = Y.row_ptr[i]; idx != Y.row_ptr[i+1]; idx++) {
					size_t j = Y.col_idx[idx];
					cihat += v[j]*v[j];
				}
				cihat *= (1-p.at(i,0)*uniform_q);
			} else {
				for(size_t idx = Y.row_ptr[i]; idx != Y.row_ptr[i+1]; idx++) {
					size_t j = Y.col_idx[idx];
					cihat += (1-get_C(i,j))*v[j]*v[j];
				}
			}
			bc[i] = (cihat+c0*p.at(i,0));
		} // }}}
		if(debug) printf("bc %g\n", norm(dmat_t(bc)));

		// compute u_denom // {{{
		if(X.is_identity()) { // d == m
#pragma omp parallel for schedule(static)
			for(size_t f = 0; f < d; f++)
				u_denom[f] = param->lambda + bc[f];
		} else if(X.is_dense()) {
			const dmat_t &dX = X.get_dense();
#pragma omp parallel for schedule(static)
			for(size_t f = 0; f < d; f++) {
				double sum = 0;
				for(size_t i = 0; i < m; i++)
					sum += bc[i]*dX.at(i,f)*dX.at(i,f);
				u_denom[f] = sum + param->lambda;
			}

		} else if(X.is_sparse()) {
			const smat_t &spX = X.get_sparse();
#pragma omp parallel for schedule(dynamic,64)
			for(size_t f = 0; f < d; f++) {
				double sum = 0;
				for(size_t idx = spX.col_ptr[f]; idx != spX.col_ptr[f+1]; idx++) {
					size_t i = spX.row_idx[idx];
					const val_type &Xif = spX.val[idx];
					sum += bc[i]*Xif*Xif;
				}
				u_denom[f] = sum + param->lambda;
			}
		} // }}}
		if(debug) printf("u_denom %g\n", norm(dmat_t(u_denom)));

		// compute u_plus = -X'*D^+*v
		//   a) m_buf = -D^+*V // {{{
		if(rho != 1.0) {
#pragma omp parallel for schedule(dynamic,64)
			for(size_t i = 0; i < m; i++) {
				double sum = 0;
				for(size_t idx = Dplus.row_ptr[i]; idx != Dplus.row_ptr[i+1]; idx++) {
					size_t j = Dplus.col_idx[idx];
					sum += Dplus.val_t[idx] * v[j];
				}
				m_buf[i] = -sum;
			}
		} else {
			// if rho == 0 then (-D^+)_ij = (Y - abar)_ij
#pragma omp parallel for schedule(dynamic,64)
			for(size_t i = 0; i < m; i++) {
				double sum = 0;
				for(size_t idx = Y.row_ptr[i]; idx != Y.row_ptr[i+1]; idx++) {
					size_t j = Y.col_idx[idx];
					sum += (Y.val_t[idx]-abar) * v[j];
				}
				m_buf[i] = -sum;
			}
		} // }}}
		//    b) u_plus = X'*m_buf // {{{
		if(X.is_identity()) {
			u_plus.assign(m_buf);
		} else if(X.is_dense()) {
			const dmat_t &dX = X.get_dense();
#pragma omp parallel for schedule(static)
			for(size_t f = 0; f < d; f++) {
				u_plus[f] = 0;
				for(size_t i = 0; i < m; i++)
					u_plus[f] += dX.at(i,f) * m_buf[i];
			}
		} else if(X.is_sparse()) {
			const smat_t &spX = X.get_sparse();
#pragma omp parallel for schedule(dynamic,64)
			for(size_t f = 0; f < d; f++) {
				u_plus[f] = 0;
				for(size_t idx = spX.col_ptr[f]; idx != spX.col_ptr[f+1]; idx++) {
					size_t i = spX.row_idx[idx];
					u_plus[f] += spX.val[idx] * m_buf[i];
				}
			}
		} // }}}
		if(debug) printf("uplus %g\n", norm(dmat_t(u_plus)));

		// compute u_minus
		if(rho != 0.0) { // {{{
#pragma omp parallel for schedule(static)
			for(size_t i = 0; i < m; i++) {
				double sum = 0;
				for(size_t t = 0; t < k; t++) if(t != cur_t)
					sum += B.at(i,t) * k_buf[t];
				m_buf[i] = sum;
			}
			if(uniform_p==0) {
#pragma omp parallel for schedule(static)
				for(size_t i = 0; i < m; i++)
					m_buf[i] *= p.at(i,0);
			}
			if(X.is_identity()) {
				if(uniform_p != 0) {
					for(size_t i = 0; i < m; i++)
						m_buf[i] *= uniform_p;
				}
				u_minus.assign(m_buf);
				if(abar != 0)
					do_axpy(-abar_qTv, XTp, u_minus);
			} else if(X.is_dense()) {
				const dmat_t &dX = X.get_dense();
#pragma omp parallel for schedule(static)
				for(size_t f = 0; f < d; f++) {
					double sum = 0;
					for(size_t i = 0; i < m; i++)
						sum += m_buf[i] * dX.at(i,f);
					if(uniform_p != 0)
						sum *= uniform_p;
					u_minus[f] = sum;
					if(abar != 0)
						u_minus[f] -= abar_qTv * XTp[f];
				}

			} else if(X.is_sparse()) {
				const smat_t &spX = X.get_sparse();
#pragma omp parallel for schedule(dynamic,64)
				for(size_t f = 0; f < d; f++) {
					double sum = 0;
					for(size_t idx = spX.col_ptr[f]; idx != spX.col_ptr[f+1]; idx++) {
						size_t i = spX.row_idx[idx];
						sum += m_buf[i] * spX.val[idx];
					}
					if(uniform_p != 0)
						sum *= uniform_p;
					u_minus[f] = sum;
					if(abar != 0)
						u_minus[f] -= abar_qTv * XTp[f];
				}
			}
		} else { // p_i, q_j = 0
#pragma omp parallel for schedule(static)
				for(size_t f = 0; f < d; f++)
					u_minus[f] = 0;
		} // }}}
		if(debug) printf("uminus %g\n", norm(dmat_t(u_minus)));
	} // }}}
	// return fun_dec
	double solve(dvec_t &u, int cur_t, int inner_iter=1) { // {{{
		double cur_fundec = 0;
		const double lambda = param->lambda;
		if(X.is_identity()) {
#pragma omp parallel for schedule(static) reduction(+:cur_fundec)
			for(size_t f = 0; f < d; f++) {
				double delta = -(lambda*u[f]+u_plus[f]+(rho!=0? u_minus[f] : 0.0)+u[f]*bc[f])/u_denom[f];
				u[f] += delta;
				cur_fundec += delta*delta*u_denom[f];
			}
		} else {
			for(int inner = 0; inner < inner_iter; inner++) {
				rng.shuffle(perm.begin(), perm.end());
				for(size_t ff = 0; ff < d; ff++) {
					size_t f = perm[ff];
					// compute delta
					double delta = lambda*u[f]+u_plus[f]+((rho!=0)?u_minus[f] : 0.0);
					if(X.is_dense()) {
						const dmat_t &dX = X.get_dense();
#pragma omp parallel for schedule(static) reduction(+:delta)
						for(size_t i = 0; i < m; i++)
							delta += B.at(i,cur_t) * bc[i] * dX.at(i,f);
					} else if(X.is_sparse()){
						const smat_t &spX = X.get_sparse();
#pragma omp parallel for schedule(static) reduction(+:delta)
						for(size_t idx = spX.col_ptr[f]; idx < spX.col_ptr[f+1]; idx++) {
							size_t i = spX.row_idx[idx];
							delta += B.at(i,cur_t) * bc[i] * spX.val[idx];
						}
					}
					delta = -delta/u_denom[f];
					u[f] += delta;
					cur_fundec += delta*delta*u_denom[f];

					// update B(:,cur_t)
					if(X.is_dense()) {
						const dmat_t &dX = X.get_dense();
#pragma omp parallel for schedule(static)
						for(size_t i = 0; i < m; i++)
							B.at(i,cur_t) += delta * dX.at(i,f);
					} else if(X.is_sparse()) {
						const smat_t &spX = X.get_sparse();
#pragma omp parallel for schedule(static)
						for(size_t idx = spX.col_ptr[f]; idx < spX.col_ptr[f+1]; idx++)
							B.at(spX.row_idx[idx],cur_t) += delta * spX.val[idx];
					}
				}
			}
		}
		return cur_fundec;
	} // }}}
}; // }}}

static void update_Dplus(smat_t &Dplus, double a, const dvec_t &Au, const dvec_t &Bv) { // {{{
	// Dplus.val_t += a * Au * Bv', restricted to the stored nonzeros of Dplus.
	// A zero scale would leave every value unchanged, so skip the pass.
	if(a == 0.0)
		return;
#pragma omp parallel for schedule(dynamic,64)
	for(size_t row = 0; row < Dplus.rows; row++) {
		const double a_Au = a * Au[row];
		const size_t begin = Dplus.row_ptr[row];
		const size_t end = Dplus.row_ptr[row+1];
		for(size_t idx = begin; idx != end; idx++)
			Dplus.val_t[idx] += a_Au * Bv[Dplus.col_idx[idx]];
	}
} // }}}

// 0.5 * sum_{ij \in \Omega+} [ (Yij - barYij)^2 - rho*(barYij - abar)^2 ], where barYij = <W'ai, H'bj> is recovered from Dplus
static double compute_loss_pos(const smat_t& Y, const smat_t &Dplus, double rho, double abar) { // {{{
	// The prediction barYij is not stored; it is recovered from Dij by
	// inverting the relation Dplus was built with for this rho.
	const bool residual_form = (rho == 0.0 || rho == 1.0);
	double total = 0;
#pragma omp parallel for schedule(static) reduction(+:total)
	for(size_t idx = 0; idx < Y.nnz; idx++) {
		const val_type Yij = Y.val_t[idx];
		const val_type Dij = Dplus.val_t[idx];
		const val_type barYij = residual_form
			? (val_type)(Yij - Dij)
			: (val_type)((1.0/(rho-1))*(Dij-Yij+rho*abar));
		total += (Yij-barYij)*(Yij-barYij) - rho*(barYij-abar)*(barYij-abar);
	}
	return 0.5*total;
} // }}}

static double compute_loss_pos(const smat_t &Y, const dmat_t &AW, const dmat_t &BH, double rho, double abar) { // {{{
	double loss_pos = 0;
#pragma omp parallel for schedule(dynamic,64) reduction(+:loss_pos)
	for(size_t i = 0; i < Y.rows; i++) {
		double local_loss_pos=0;
		for(size_t idx = Y.row_ptr[i]; idx != Y.row_ptr[i+1]; idx++) {
			size_t j = Y.col_idx[idx];
			const val_type &Yij = Y.val_t[idx];
			double barYij = 0;
			for(size_t t = 0; t < AW.cols; t++)
				barYij += AW.at(i,t)*BH.at(j,t);
			local_loss_pos += (Yij-barYij)*(Yij-barYij) - rho*(abar-barYij)*(abar-barYij);
		}
		loss_pos += local_loss_pos;
	}
	return 0.5*loss_pos;
} // }}}

// loss on all entries 0.5*(||diag(p)^1/2*(abar*ones(m,n) - AW*BH')*diag(q)^1/2||^2)
//  = 0.5* {abar^2*sum(p)*sum(q) + <AW, diag(p) AW*BH'*diag(q)*BH> - 2*abar* p'*AW*BH'*q}
static double compute_loss_neg(const dmat_t &AW, const dmat_t &BH, double rho, double abar) { // {{{
	size_t m = AW.rows, n = BH.rows, k = AW.cols;
	double loss_neg = 0;
	// compute abar^2*sum(p)*sum(q) - 2*abar* p'*AW*BH'*q
	if(abar) { // {{{
		double sump = rho*m, sumq = ((rho!=0)?(double)n:0);
		loss_neg += abar*abar*sump*sumq;
		for(size_t t = 0; t < k; t++) {
			double tmp1 = 0, tmp2 = 0;
			if(rho!=0) {
				for(size_t i = 0; i < m; i++)
					tmp1 += AW.at(i,t);
				for(size_t j = 0; j < n; j++)
					tmp2 += BH.at(j,t);
			}
			loss_neg -= 2.0*abar*rho*tmp1*tmp2;
		}
	} // }}}
	// compute <AW, diag(p) AW*BH'*diag(q)*BH
	if(rho) loss_neg += rho*do_dot_product(AW*(BH.transpose()*BH), AW);
	return 0.5*loss_neg;
} // }}}

void ccdr1_imf(imf_prob_t *prob, imf_param_t *param, dmat_t *ptr_W, dmat_t *ptr_H, imf_prob_t *test_prob, info_t *info, double *rmse) { // {{{
	smat_t &Y = *(prob->Y), Yt = Y.transpose();
	gmat_t &A = *(prob->A);
	gmat_t &B = *(prob->B);
	dmat_t &W = *ptr_W;
	dmat_t &H = *ptr_H;
	const size_t &m = prob->m, &n = prob->n, &da = prob->da, &db = prob->db;
	const size_t &k = prob->k;
	const double rho = param->solver_type == CCDR1? 0: param->rho;
	const double abar = param->abar;

	// CCD++ uses colmajor W and H
	if(W.is_rowmajor()) W.to_colmajor();
	if(H.is_rowmajor()) H.to_colmajor();

	if(A.is_dense() && A.get_dense().is_rowmajor())
		printf("Warning: Use colmajored feature matrix A to have better locality\n");
	if(B.is_dense() && B.get_dense().is_colmajor())
		printf("Warning: Use colmajored feature matrix B to have better locality\n");

	imf_ranker_t ranker;
	bool do_rmse = (param->do_predict&1)!=0;
	bool do_ranking = (param->do_predict&2)!=0;

	omp_set_num_threads(param->threads);

	dmat_t AW, BH;
	if(A.is_dense()) { // {{{
		AW = dmat_t(m, k, COLMAJOR);
	} else if(A.is_identity()) {
		//AW = dmat_t(m, k, COLMAJOR);
		AW = W.get_view();
	} else if(A.is_sparse()) {
		AW = dmat_t(m, k, COLMAJOR);
	}

	if(B.is_dense()) {
		BH = dmat_t(n, k, COLMAJOR);
	} else if(B.is_identity()) {
		BH = H.get_view();
		//BH = dmat_t(n, k, COLMAJOR);
	} else if(B.is_sparse()) {
		BH = dmat_t(n, k, COLMAJOR);
	} // }}}

	if(test_prob && test_prob->Y)
		ranker = imf_ranker_t(*test_prob, *param);

	dvec_t nnzY_buf1(Y.nnz), nnzY_buf2(Y.nnz);
	smat_t Dplus = Y.get_view();  // if rho==1, Dplus = W
	Dplus.val = nnzY_buf1.data();
	Dplus.val_t = nnzY_buf2.data();
	smat_t Dplust = Dplus.transpose();

	double Itime=0, W_init_time=0, Wtime=0, H_init_time=0, Htime=0, start_time=0;
	const double eps = param->eps;
	double reg=0, loss=0, oldobj = 0;

	// Initialize Dplus
#pragma omp parallel for schedule(static) reduction(+:oldobj)
	for(size_t idx = 0; idx < Y.nnz; idx++) {
		if(rho != 1.0) {
			Dplus.val[idx] = Y.val[idx] - rho*abar;
			Dplus.val_t[idx] = Y.val_t[idx] - rho*abar;
		} else {
			// if rho == 1, Dplus reduces to residual R = Y - AWH'B'
			Dplus.val[idx] = Y.val[idx];
			Dplus.val_t[idx] = Y.val_t[idx];
		}
		oldobj += Y.val[idx]*Y.val[idx] - rho*abar*abar;
	}
	oldobj += rho*m*n*abar*abar;
	oldobj *= 0.5;

	printf("norm-- %g %g\n", do_dot_product(nnzY_buf1,nnzY_buf1), do_dot_product(nnzY_buf2, nnzY_buf2));

	if(param->verbose != 0) {
		printf("|W0| (%ld %ld) = %.6f\n", W.rows, W.cols, norm(W));
		printf("|H0| (%ld %ld) = %.6f\n", H.rows, H.cols, norm(H));
	}
	const bool all_true = true;
	const double scale = ((rho!=1.0)? (1.0-rho): 1.0);
	if(all_true || param->warm_start) { // {{{
		gmat_x_dmat(B, H, BH);
		gmat_x_dmat(A, W, AW);
		reg += do_dot_product(W,W);
		reg += do_dot_product(H,H);
		for(size_t t = 0; t < k; t++) {
			update_Dplus(Dplus, -scale, AW[t], BH[t]);
			update_Dplus(Dplust, -scale, BH[t], AW[t]);
		}
	} else {
#pragma omp parallel for schedule(static)
		for(size_t t = 0; t < k; t++) {
			memset(H[t].data(), 0, sizeof(val_type)*H.rows);
			if(!B.is_identity())
				memset(BH[t].data(), 0, sizeof(val_type)*n);
		}
		gmat_x_dmat(A, W, AW);
		reg += do_dot_product(W,W);
		oldobj += 0.5*param->lambda*reg;
	} // }}}
	printf("|AW| (%ld %ld) = %.6f\n", AW.rows, AW.cols, norm(AW));
	printf("|BH| (%ld %ld) = %.6f\n", BH.rows, BH.cols, norm(BH));

	bilinear_prob_t subprob_w(&Y, &A, &BH);
	bilinear_prob_t subprob_h(&Yt,&B, &AW);
	leml_rankone_solver W_solver(&subprob_w, param, Dplus, AW);
	leml_rankone_solver H_solver(&subprob_h, param, Dplust, BH);
	printf("norm %g %g\n", do_dot_product(nnzY_buf1,nnzY_buf1), do_dot_product(nnzY_buf2, nnzY_buf2));
	for(int oiter = 1; oiter <= param->maxiter; oiter++) {
		double gnorm = 0, initgnorm=0;
		double rankfundec = 0;
		double fundec_max = 0;
		int early_stop = 0;
		for(size_t tt = 0; tt < k; tt++) {
			size_t t = tt;
			if(early_stop >= 5) break;
			dvec_t &u = W[t], &v = H[t];
			dvec_t &Au = AW[t], &Bv = BH[t];

			// Update Dplus
			start_time = omp_get_wtime();
			if (all_true || param->warm_start || oiter > 1) {
				update_Dplus(Dplus, scale, Au, Bv);
				update_Dplus(Dplust, scale, Bv, Au);
			}
			Itime += omp_get_wtime() - start_time;

			gnorm = 0, initgnorm = 0;
			double innerfundec_cur = 0, innerfundec_max = 0;
			int maxit = param->maxinneriter;
			for(int iter = 1; iter <= maxit; iter++) {
				// Update H[t]
				start_time = omp_get_wtime();
				H_solver.init_prob(t);
				H_init_time += omp_get_wtime() - start_time;
				start_time = omp_get_wtime();
				innerfundec_cur = H_solver.solve(v, t, param->max_cg_iter);
				Htime += omp_get_wtime() - start_time;

				//printf("vnorm %p %g\n", v.data(), do_dot_product(v,v));

				// Update W[t]
				start_time = omp_get_wtime();
				W_solver.init_prob(t);
				W_init_time += omp_get_wtime() - start_time;
				start_time = omp_get_wtime();
				innerfundec_cur += W_solver.solve(u, t, param->max_cg_iter);
				Wtime += omp_get_wtime() - start_time;
				//printf("unorm %g\n", do_dot_product(u,u));

				if(innerfundec_cur < (fundec_max*eps)) {
					if(iter == 1) early_stop+=1;
					break;
				}
				rankfundec += innerfundec_cur;
				innerfundec_max = std::max(innerfundec_max, innerfundec_cur);
				// the fundec of the first inner iter of the first rank of the first outer iteration could be too large!!
				if(!(oiter==1 && t == 0 && iter==1))
					fundec_max = std::max(fundec_max, innerfundec_cur);
			}

			start_time = omp_get_wtime();
			update_Dplus(Dplus, -scale, Au, Bv);
			update_Dplus(Dplust, -scale, Bv, Au);
			Itime += omp_get_wtime() - start_time;

			if(param->verbose == 2) { // {{{
				printf("IMF-iter %d tt %d W %.5g H %.5g Time %.5g",
						//iter, Wtime, Htime, Wtime+Htime);
					oiter, tt, norm(W),norm(H), Itime+Wtime+W_init_time+Htime+H_init_time);
				double reg_w = do_dot_product(W,W);
				double reg_h = do_dot_product(H,H);
				double loss_pos = compute_loss_pos(Y, Dplus, rho, abar);
				double loss_pos2= compute_loss_pos(Y, AW, BH, rho, abar);
				double loss_neg = compute_loss_neg(AW, BH, rho, abar);
				double obj = loss_pos+loss_neg+0.5*param->lambda*(reg_w+reg_h);
				printf(" loss %g (%g %g) reg %g obj %.5g dec %.5g", loss_pos+loss_neg, loss_pos, loss_pos2, reg_w+reg_h, obj, oldobj-obj);
				oldobj = obj;

				if(test_prob!=NULL && test_prob->Y!=NULL) { // {{{
					ranker.preprocess(W,H);
					if(do_rmse) {
						double tmp_rmse = ranker.get_rmse();
						if(rmse!=NULL) *rmse = tmp_rmse;
						printf(" rmse %.5g", tmp_rmse);
					}
					if(do_ranking) {
						info_t tmp_info = ranker.get_info(param->top_p);
						if(info!=NULL) *info= tmp_info;
						tmp_info.print();
					}
				} // }}}
				puts("");
				fflush(stdout);
			} // }}}
		}
		if(param->verbose == 1) { // {{{
			printf("IMF-iter %d W %.5g H %.5g Time %.5g",
					//iter, Wtime, Htime, Wtime+Htime);
				oiter, norm(W),norm(H), Itime+Wtime+W_init_time+Htime+H_init_time);
			double reg_w = do_dot_product(W,W);
			double reg_h = do_dot_product(H,H);
			double loss_pos = compute_loss_pos(Y, Dplus, rho, abar);
			double loss_pos2= compute_loss_pos(Y, AW, BH, rho, abar);
			double loss_neg = compute_loss_neg(AW, BH, rho, abar);
			double obj = loss_pos+loss_neg+0.5*param->lambda*(reg_w+reg_h);
			printf(" loss %g (%g %g) reg %g obj %.5g dec %.5g", loss_pos+loss_neg, loss_pos, loss_pos2, reg_w+reg_h, obj, oldobj-obj);
			oldobj = obj;

			if(test_prob!=NULL && test_prob->Y!=NULL) { // {{{
				ranker.preprocess(W,H);
				if(do_rmse) {
					double tmp_rmse = ranker.get_rmse();
					if(rmse!=NULL) *rmse = tmp_rmse;
					printf(" rmse %.5g", tmp_rmse);
				}
				if(do_ranking) {
					info_t tmp_info = ranker.get_info(param->top_p);
					if(info!=NULL) *info= tmp_info;
					tmp_info.print();
				}
			} // }}}
			puts("");
			fflush(stdout);
		} // }}}
	}
} // }}}
