#include "bilinear.h"


// print function {{{
// Write `s` to stdout and flush right away so the text is not held back
// in the stdio buffer.
static void print_string_stdout(const char *s) {
	fputs(s, stdout);
	fflush(stdout);
}

// Sink that discards its argument; used when printing is switched off.
static void print_null(const char *) {
}

// Common signature of the two printing callbacks above.
typedef void (*print_fun_ptr)(const char *);

// Choose a printing callback from `param->verbose`:
// levels 0 and 1 are silent, every other level prints to stdout.
template<typename T>
static print_fun_ptr get_print_fun(T *param) { // {{{
	const bool silent = (param->verbose == 0) || (param->verbose == 1);
	return silent ? print_null : print_string_stdout;
} // }}}
// }}}

static double norm(const dmat_t &W) { // {{{
	val_type *w = const_cast<val_type*>(W.data());
	double ret = do_dot_product(w,w, W.rows*W.cols);
	return sqrt(ret);
} // }}}

// Add up the v.len entries of v, accumulating in double.
static double sum(const dvec_t &v) { // {{{
	double total = 0;
	size_t pos = 0;
	while(pos < v.len) {
		total += v[pos];
		++pos;
	}
	return total;
} // }}}

// Evaluate (A*B)(i,j) only at the stored positions (i,j) of Y and write each
// value into z at the index of that stored entry.
//   major == ROWMAJOR: z follows Y's row_ptr/col_idx ordering.
//   major == COLMAJOR: the call is redirected to the transposed problem
//                      (Y', B', A'), so z follows the ordering of Y.transpose().
// Requires z.len == Y.nnz and conforming shapes (asserted).
static void partial_dmat_x_dmat(const smat_t &Y, major_t major, const dmat_t &A, const dmat_t &B,  dvec_t &z) { // {{{
	assert(z.len == Y.nnz);
	assert(A.rows == Y.rows && B.cols == Y.cols && A.cols == B.rows);
	if(major==COLMAJOR) {
		partial_dmat_x_dmat(Y.transpose(), ROWMAJOR, B.transpose(), A.transpose(), z);
		return;
	}
	// Each row writes a disjoint slice of z, so rows can run in parallel.
#pragma omp parallel for schedule(dynamic,32)
	for(size_t row = 0; row < Y.rows; row++) {
		for(size_t pos = Y.row_ptr[row]; pos != Y.row_ptr[row+1]; pos++) {
			const size_t col = Y.col_idx[pos];
			double acc = 0;
			for(size_t t = 0; t < A.cols; t++)
				acc += A.at(row,t)*B.at(t,col);
			z[pos] = acc;
		}
	}
} // }}}

// l2r + squared-L2 loss + full observation + general X
class l2r_ls_fY_gX : public function<val_type> {  // {{{
	protected:
		// {{{
		typedef bilinear_prob_t prob_t;
		typedef bilinear_param_t param_t;
		const prob_t* prob;
		const param_t* param;
		const smat_t &Y;
		double trYTY;
		dmat_t HTH;        // k*k
		dmat_t YH;         // m*k
		dmat_t XTYH;       // d*k
		dmat_t XTX;        // d*d
		dmat_t mk_buf;     // m*k
		dmat_t dk_buf;     // d*k
		dmat_t kk_buf;     // k*k
		dmat_t W, G, S, HS;// d*k view
		const size_t &m, &d, &n, &k;
		bool maintain_XTX;
		// }}}
	public:
		// Bind the problem/parameter objects and allocate the work buffers that
		// the representation of X (dense / identity / sparse) requires.
		// W, G, S and HS start as views of dk_buf; their buf pointers are
		// re-targeted by fun()/grad()/Hv(). trYTY is do_dot_product of Y's
		// stored values with themselves.
		l2r_ls_fY_gX(const prob_t* prob, const param_t* param): prob(prob), param(param), Y(*(prob->Y)), m(prob->m), d(prob->d), n(prob->n), k(prob->k) { // {{{
			maintain_XTX = false;    // XXX Check here
			HTH = dmat_t(k, k, ROWMAJOR);
			YH = dmat_t(m, k, ROWMAJOR);
			if(prob->X->is_dense()) {
				XTYH = dmat_t(d, k, ROWMAJOR);
				dk_buf = dmat_t(d, k, ROWMAJOR);
				if(!maintain_XTX) {
					mk_buf = dmat_t(m, k, ROWMAJOR);
				} else {
					// cache X'X once so later products avoid touching X
					XTX = dmat_t(d, d, ROWMAJOR);
					dmat_t& X = prob->X->get_dense();
					dmat_x_dmat(X.transpose(), X, XTX);
					kk_buf = dmat_t(k, k, ROWMAJOR);
				}
			} else if(prob->X->is_identity()){
				// XTYH is not allocated here; init() makes it a view of YH
				dk_buf = dmat_t(d, k, ROWMAJOR);
				kk_buf = dmat_t(k, k, ROWMAJOR);
			} else if(prob->X->is_sparse()) {
				XTYH = dmat_t(d, k, ROWMAJOR);
				dk_buf = dmat_t(d, k, ROWMAJOR);
				mk_buf = dmat_t(m, k, ROWMAJOR);
			}
			W = G = S = HS = dk_buf.get_view();
			trYTY = do_dot_product(Y.val, Y.val, Y.nnz);
		} // }}}
		void init() { // {{{
			const dmat_t &H = *(prob->H);
			assert(H.is_rowmajor());
			dmat_x_dmat(H.transpose(), H, HTH);
			//printf(" BH=> %p buf %p %g\n", &H, H.buf, norm(HTH));
			smat_x_dmat(Y, H, YH);
			if(prob->X->is_dense()) {
				const dmat_t &X = prob->X->get_dense();
				dmat_x_dmat(X.transpose(), YH, XTYH);
			} else if(prob->X->is_identity()) {
				XTYH = YH.get_view();
			} else if(prob->X->is_sparse()) {
				const smat_t &X = prob->X->get_sparse();
				smat_x_dmat(X.transpose(), YH, XTYH);
			}
		} // }}}
		// Number of optimisation variables: one per entry of the d*k matrix W.
		int get_nr_variable(void) {
			const size_t nr_entries = d*k;
			return (int)nr_entries;
		}
		// Objective value at w (w is viewed as the d*k matrix W).
		// Reading the code, every X variant evaluates
		//   0.5*( trYTY - 2<W, XTYH> + <W, X'X W HTH> + lambda*<W, W> )
		//   + 0.5*lambda_graph*trace(W' L W)      (graph term only when enabled)
		// assuming dmat_x_dmat/smat_x_dmat(A, B, C) store A*B in C.
		// Side effect relied on by grad(): in the dense (maintain_XTX == false)
		// and sparse branches dk_buf is left holding X'XW*HTH + lambda*W; in the
		// dense maintain_XTX branch dk_buf is left holding XTX*W.
		double fun(val_type* w) { // {{{
			W.buf = w;
			double obj = trYTY;
			dmat_t &XTXW = dk_buf, &WTXTXW = kk_buf;
			// re-use dk_buf to compute XTXWHTH + lambda*W
			if(prob->X->is_dense()) {
				if(maintain_XTX) {
					// uses the cached XTX: <W'XTX W, HTH> - 2<XTYH, W> + lambda<w, w>
					dmat_x_dmat(XTX, W, XTXW);
					dmat_x_dmat(W.transpose(), XTXW, WTXTXW);
					obj += do_dot_product(WTXTXW.data(), HTH.data(), k*k);
					obj -= 2.0*do_dot_product(XTYH.data(), W.data(), d*k);
					obj += param->lambda*do_dot_product(w, w, d*k);
					obj *= 0.5;
				} else {
					obj -= 2.0*do_dot_product(W.data(), XTYH.data(), d*k);
					// re-use dk_buf to compute XTXWHTH + lambda*W
					dmat_t& X = prob->X->get_dense();
					// dk_buf <- W*HTH ; mk_buf <- X*dk_buf ; dk_buf <- X'*mk_buf
					dmat_x_dmat(W, HTH, dk_buf);
					dmat_x_dmat(X, dk_buf, mk_buf);
					dmat_x_dmat(X.transpose(), mk_buf, dk_buf);
					do_axpy((val_type)(param->lambda), W.data(), dk_buf.data(), d*k);
					// compute <W,XTXWHTH + lambda*W>
					obj += do_dot_product(W.data(), dk_buf.data(), d*k);
					obj *= 0.5;
				}
			} else if(prob->X->is_identity()) {
				// X = I: X'X W reduces to W, so only W'W is formed (in kk_buf)
				dmat_x_dmat(W.transpose(), W, WTXTXW);
				obj += do_dot_product(WTXTXW.data(), HTH.data(), k*k);
				obj -= 2.0*do_dot_product(XTYH.data(), W.data(), d*k);
				obj += param->lambda*do_dot_product(w, w, d*k);
				obj *= 0.5;
			} else if(prob->X->is_sparse()) {
				// same sequence as the dense/non-cached branch, with sparse products
				obj -= 2.0*do_dot_product(W.data(), XTYH.data(), d*k);
				// re-use dk_buf to compute XTXWHTH + lambda*W
				smat_t& X = prob->X->get_sparse();
				dmat_x_dmat(W, HTH, dk_buf);
				smat_x_dmat(X, dk_buf, mk_buf);
				smat_x_dmat(X.transpose(), mk_buf, dk_buf);
				do_axpy((val_type)(param->lambda), W.data(), dk_buf.data(), d*k);
				// compute <W,XTXWHTH + lambda*W>
				obj += do_dot_product(W.data(), dk_buf.data(), d*k);
				obj *= 0.5;
			}

			// optional graph regulariser, added after the 0.5 scaling above
			double reg_graph = 0;
			if(param->lambda_graph != 0 && prob->L != NULL) {
				const smat_t &L = *(prob->L);
				reg_graph = 0.5*trace_dmat_T_smat_dmat(W, L, W);
			}
			return obj+reg_graph*param->lambda_graph;
		} // }}}
		// Gradient at w, written to g (viewed as the d*k matrix G).
		// PRECONDITION: fun(w) was the last evaluation -- the dense and sparse
		// branches read the contents fun() left in dk_buf rather than recomputing.
		// Result (reading the code): X'XW*HTH - XTYH + lambda*W, plus
		// lambda_graph*L*W when the graph term is enabled.
		void grad(val_type *w, val_type *g) { // {{{
			W.buf = w; G.buf = g;
			// assume fun(w) just being called
			if(prob->X->is_dense()) {
				if(maintain_XTX) {
					// dk_buf holds XTX*W from fun()
					dmat_t &XTXW = dk_buf;
					// g <- XTYH - lambda*w, then G <- XTXW*HTH - G
					do_copy(XTYH.data(), g, d*k);
					do_axpy((val_type)(-param->lambda), w, g, d*k);
					dmat_x_dmat((val_type)1.0, XTXW, HTH, (val_type)(-1.0), G);
				} else {
					// dk_buf holds X'XW*HTH + lambda*W from fun()
					dmat_t &XTXWHTH_lambdaW = dk_buf;
					do_copy(XTXWHTH_lambdaW.data(), g, d*k);
					do_axpy((val_type)(-1.0), XTYH.data(), g, d*k);
				}
			} else if(prob->X->is_identity()) {
				// X = I, so X'XW is W itself
				dmat_t &XTXW = W;
				do_copy(XTYH.data(), g, d*k);
				do_axpy((val_type)(-param->lambda), w, g, d*k);
				dmat_x_dmat((val_type)1.0, XTXW, HTH, (val_type)(-1.0), G);
			} else if(prob->X->is_sparse()) {
				// dk_buf holds X'XW*HTH + lambda*W from fun()
				dmat_t &XTXWHTH_lambdaW = dk_buf;
				do_copy(XTXWHTH_lambdaW.data(), g, d*k);
				do_axpy((val_type)(-1.0), XTYH.data(), g, d*k);
			}
			if(param->lambda_graph != 0 && prob->L != NULL) {
				const smat_t &L = *(prob->L);
				// presumably G <- lambda_graph*L*W + 1.0*G -- confirm against smat_x_dmat's 6-argument form
				smat_x_dmat(param->lambda_graph, L, W, 1.0, G, G);
			}
		} // }}}
		// Hessian-vector product: Hs <- X'X*S*HTH + lambda*S, plus
		// lambda_graph*L*S when the graph term is enabled. s and Hs are viewed
		// as d*k matrices S and HS. Each branch first copies s into Hs so the
		// scaled-accumulate form of dmat_x_dmat can add lambda*S.
		void Hv(val_type *s, val_type *Hs) { // {{{
			S.buf = s; HS.buf = Hs;
			const val_type one = (val_type)1.0;
			const val_type lambda = (val_type)(param->lambda);
			if(prob->X->is_dense()) {
				if(maintain_XTX) {
					// dk_buf <- XTX*S
					dmat_x_dmat(XTX, S, dk_buf);
				} else {
					// dk_buf <- X'*(X*S), through mk_buf
					dmat_t &Xdense = prob->X->get_dense();
					dmat_x_dmat(Xdense, S, mk_buf);
					dmat_x_dmat(Xdense.transpose(), mk_buf, dk_buf);
				}
				do_copy(s, Hs, d*k);
				dmat_x_dmat(one, dk_buf, HTH, lambda, HS);
			} else if (prob->X->is_identity()) {
				// X = I: X'X*S is S itself
				do_copy(s, Hs, d*k);
				dmat_x_dmat(one, S, HTH, lambda, HS);
			} else if (prob->X->is_sparse()) {
				// dk_buf <- X'*(X*S), through mk_buf
				smat_t &Xsparse = prob->X->get_sparse();
				smat_x_dmat(Xsparse, S, mk_buf);
				smat_x_dmat(Xsparse.transpose(), mk_buf, dk_buf);
				do_copy(s, Hs, d*k);
				dmat_x_dmat(one, dk_buf, HTH, lambda, HS);
			}
			if(param->lambda_graph != 0 && prob->L != NULL) {
				const smat_t &L = *(prob->L);
				smat_x_dmat(param->lambda_graph, L, S, 1.0, HS, HS);
			}
		} // }}}
}; // }}}

// Y with Missing Values
// routines for Y with missing values
void barXv_withXV(const smat_t& Y, const dmat_t& XV, const dmat_t& H, dvec_t& barXv) { // {{{
	partial_dmat_x_dmat(Y, ROWMAJOR, XV, H.transpose(), barXv);
} // }}}

// gr/l2r + squared-L2 loss + partial observation + general X
class l2r_ls_mY_gX : public function<val_type> {  // {{{
	protected:
		// {{{
		typedef bilinear_prob_t prob_t;
		typedef bilinear_param_t param_t;
		const prob_t* prob;
		const param_t* param;
		const smat_t &Y;
		smat_t U;          // view
		dmat_t &H;
		dmat_t mk_buf;     // m*k
		dvec_t z;          // nnz_Y
		dmat_t mk_view;    // m*k view
		dmat_t dk_view;    // d*k view
		dmat_t W, G, S, HS;// d*k view
		const size_t &m, &d, &n, &k;
		// }}}
	public:
		// Bind the problem/parameter objects and set up the work buffers:
		//   mk_buf : m*k scratch, z : one slot per stored entry of Y.
		// dk_view, W, G, S and HS are created as d*k matrices over mk_buf's
		// storage; fun()/grad()/Hv() re-point W/G/S/HS at caller memory through
		// .buf before use.
		// NOTE(review): whether dmat_t assignment copies data or only the view
		// is not visible here, so the chained assignment is left untouched.
		l2r_ls_mY_gX(const prob_t* prob, const param_t* param): prob(prob), param(param), Y(*(prob->Y)), H(*(prob->H)), m(prob->m), d(prob->d), n(prob->n), k(prob->k) { // {{{
			mk_buf = dmat_t(m, k, ROWMAJOR);
			z = dvec_t(Y.nnz);
			dk_view = W = G = S = HS = dmat_t(d, k, mk_buf.data(), ROWMAJOR);
			mk_view = mk_buf.get_view();
			// U is a view of Y; barXTu() swaps in its own value array (U.val_t)
			U = Y.get_view();
		} // }}}
		// Number of optimisation variables: one per entry of the d*k matrix W.
		int get_nr_variable(void) {
			const size_t nr_entries = d*k;
			return (int)nr_entries;
		}
		// barXv = barX * v; Xv = X*v;
		// Two outputs from one input V:
		//   XV    <- X*V  (a plain copy of V when X is the identity)
		//   barXv <- entries of XV*H' at the stored positions of Y
		void barXv(dmat_t &V, dvec_t &barXv, dmat_t &XV) { // {{{
			if(prob->X->is_dense()) {
				const dmat_t &Xdense = prob->X->get_dense();
				dmat_x_dmat(Xdense, V, XV);
			} else if(prob->X->is_identity()) {
				const size_t nr_entries = m*k;
				do_copy(V.data(), XV.data(), nr_entries);
			} else if(prob->X->is_sparse()) {
				const smat_t &Xsparse = prob->X->get_sparse();
				smat_x_dmat(Xsparse, V, XV);
			}
			barXv_withXV(Y, XV, H, barXv);
		} // }}}

		// barXTu = barXt * u + b*barXTu; UH = U * H
		// barXTU <- X'*(U*H) + b*barXTU, where U is the sparse matrix with Y's
		// pattern and the values of u. UH receives the intermediate U*H.
		void barXTu(dvec_t& u, val_type b, dmat_t& barXTU, dmat_t& UH) { // {{{
			// re-use Y's sparsity pattern with u as the value array
			U.val_t = u.data();
			smat_x_dmat(U, H, UH);
			const val_type one = (val_type)1.0;
			if(prob->X->is_dense()) {
				const dmat_t &Xdense = prob->X->get_dense();
				dmat_x_dmat(one, Xdense.transpose(), UH, (val_type)b, barXTU);
			} else if(prob->X->is_identity()) {
				// X = I: element-wise barXTU = b*barXTU + UH
#pragma omp parallel for schedule(static,32)
				for(size_t pos = 0; pos < d*k; pos++)
					barXTU.buf[pos] = b*barXTU.buf[pos] + UH.buf[pos];
			} else if(prob->X->is_sparse()) {
				const smat_t &Xsparse = prob->X->get_sparse();
				smat_x_dmat(one, Xsparse.transpose(), UH, (val_type)b, barXTU, barXTU);
			}
		} // }}}

		// Objective at w:
		//   0.5*( ||z||^2 + lambda*<w,w> + lambda_graph*trace(W' L W) )
		// with z = (entries of X*W*H' on Y's pattern) - Y's values.
		// Leaves the residual in z (read by grad()) and X*W in mk_buf.
		double fun(val_type* w) { // {{{
			W.buf = w;
			// z <- barX*w, then z <- z - y
			barXv(W, z, mk_buf);
			do_axpy((val_type)(-1.0), Y.val_t, z.data(), Y.nnz);
			const double sq_residual = do_dot_product(z.data(), z.data(), Y.nnz);
			const double sq_norm_w = do_dot_product(w, w, d*k);
			double graph_term = 0;
			if(param->lambda_graph != 0 && prob->L != NULL) {
				const smat_t &L = *(prob->L);
				graph_term = trace_dmat_T_smat_dmat(W, L, W);
			}
			return 0.5*(sq_residual + param->lambda*sq_norm_w + param->lambda_graph*graph_term);
		} // }}}
		// Gradient at w, written to g: barX'*z + lambda*w, plus lambda_graph*L*W
		// when the graph term is enabled.
		// PRECONDITION: fun(w) was just called, so z holds barX*w - y.
		void grad(val_type *w, val_type *g) { // {{{
			W.buf = w; G.buf = g;
			const val_type lambda = (val_type)(param->lambda);
			// seed g with w so barXTu's "b*barXTU" term contributes lambda*w
			do_copy(w, g, d*k);
			barXTu(z, lambda, G, mk_buf);
			if(param->lambda_graph != 0 && prob->L != NULL) {
				const smat_t &L = *(prob->L);
				smat_x_dmat(param->lambda_graph, L, W, 1.0, G, G);
			}
		} // }}}
		// Hessian-vector product: Hs <- barX'*(barX*s) + lambda*s, plus
		// lambda_graph*L*S when the graph term is enabled.
		// Overwrites z and mk_buf, so the residual left by fun() is lost.
		void Hv(val_type *s, val_type *Hs) { // {{{
			S.buf = s; HS.buf = Hs;
			const val_type lambda = (val_type)(param->lambda);
			// z <- barX*s
			barXv(S, z, mk_buf);
			// seed Hs with s so barXTu's "b*barXTU" term contributes lambda*s
			do_copy(s, Hs, d*k);
			barXTu(z, lambda, HS, mk_buf);
			if(param->lambda_graph != 0 && prob->L != NULL) {
				const smat_t &L = *(prob->L);
				smat_x_dmat(param->lambda_graph, L, S, 1.0, HS, HS);
			}
		} // }}}
}; // }}}

// gr/l2r + general loss + one-class observation + general X
class l2r_erm_oY_gX : public function<val_type> {  // {{{
	protected:
		// {{{ variable definitions
		typedef bilinear_prob_t prob_t;
		typedef bilinear_param_t param_t;
		const prob_t* prob;
		const param_t* param;
		const smat_t &Y;
		const dmat_t &H;
		smat_t Dplus;      // view of Y
		smat_t Uplus;      // view of Y
		dvec_t nnzY_buf1;  // Y.nnz: sparse (XWH')^+ or D^+ or U^+
		dvec_t nnzY_buf2;  // Y.nnz: sparse dot{D}^+ // required for other loss
		dmat_t p;          // m*1 : could be param->rho*ones(m,1)
		dmat_t q;          // n*1 : could be ones(n,1)
		dmat_t bd;         // d*1 : 2*abar*X'*p
		dmat_t bk;         // k*1 : 2*H'*q
		dmat_t HTH;        // k*k : M = H'*diag(q)*H
		dmat_t mk_buf0;    // m*k : B = X*W
		dmat_t mk_buf1;    // m*k : X*W*M
		dmat_t mk_buf2;    // m*k : dB = XdW or X*S
		dmat_t mk_buf3;    // m*k : dMbar = diag(p)*dB*M or X*S*M
		dmat_t mk_buf4;    // m*k : Bnew
		dmat_t dk_buf;     // d*k : W*M or S*M
		dmat_t dk_buf2;    // d*k : Lbar = L*W
		dmat_t dk_buf3;    // d*k : dLbar = L*dW
		//dmat_t dd_buf;     // d*d store X'*diag(p)*X
		dmat_t nk_buf;     // n*k : store diag(q)*H
		dmat_t W, G, S, HS;// d*k view
		const size_t &m, &d, &n, &k;
		const double &rho, &abar;
		double uniform_p, uniform_q;
		double delta, cur_obj, loss_pos;
		bool maintain_XTX;
		// }}}
		// Weight of entry (i,j): C_ij = p_i * q_j.
		// The constructor always fills p and q (with zeros when rho == 0), so
		// the product is valid in every configuration. The former uniform_p /
		// uniform_q shortcuts sat after an unconditional return and were never
		// executed (and had no final return), so they are removed.
		double get_C(size_t i, size_t j) { // {{{
			return p.data()[i]*q.data()[j];
		} // }}}
		// loss on observed entries loss_{ij}(Y_{ij},  xi' W hj)
		// pred_val holds one prediction per stored entry of Y. fun() and
		// line_search() pass nnzY_buf1 and store the returned value in loss_pos.
		// The implementations in this file also write first/second derivative
		// information into nnzY_buf1 / nnzY_buf2, which grad() and Hv() read.
		virtual double compute_loss_pos(const dvec_t &pred_val) = 0;
		// Reports whether line_search() may be used; the default says yes.
		virtual bool line_search_supported() {return true;}
	public:
		l2r_erm_oY_gX(const prob_t* prob, const param_t* param): prob(prob), param(param), Y(*(prob->Y)), H(*(prob->H)), m(prob->m), d(prob->d), n(prob->n), k(prob->k), rho(param->rho), abar(param->abar) { // {{{
			assert(H.is_rowmajor());
			assert(!prob->X->is_dense() || prob->X->get_dense().is_rowmajor());
			maintain_XTX = false; // Check here
			Dplus = Y.get_view();
			Uplus = Y.get_view();
			nnzY_buf1 = dvec_t(Y.nnz);
			nnzY_buf2 = dvec_t(Y.nnz);
			if(rho != 0.0) {
				p = dmat_t(m,1,COLMAJOR);
				q = dmat_t(n,1,COLMAJOR);
				for(size_t i = 0; i < m; i++)
					p.at(i,0) = (val_type)rho;
				for(size_t j = 0; j < n; j++)
					q.at(j,0) = (val_type)1.0;
				if(abar != 0.0) {
					bd = dmat_t(d, 1, COLMAJOR);
					bk = dmat_t(k, 1, COLMAJOR);
				}
				// determine uniform_p/uniform_q{{{
				uniform_q = q.at(0,0);
				for(size_t j = 1; j < n; j++) {
					if(q.at(j,0) != uniform_q) {
						uniform_q = 0.0;
						break;
					}
				}
				if(uniform_q) {
					uniform_p = (p.at(0,0)*=uniform_q);
					for(size_t i = 1; i < m; i++) {
						p.at(i,0)*=uniform_q;
						if(p.at(i,0) != uniform_p) {
							uniform_p = 0.0;
						}
					}
					uniform_q = 1.0;
					for(size_t j = 0; j < n; j++)
						q.at(j,0) = 1.0;
				} else {
					uniform_p = p.at(0,0);
					for(int i = 1; i < m; i++) {
						if(p.at(i,0) != uniform_p) {
							uniform_p = 0.0;
							break;
						}
					}
				}
				// }}}
				printf("uniform_p %g uniform_q %g\n", uniform_p, uniform_q);
			} else {
				p = dmat_t(m,1,COLMAJOR);
				q = dmat_t(n,1,COLMAJOR);
				for(size_t i = 0; i < m; i++)
					p.at(i,0) = 0;
				for(size_t j = 0; j < n; j++)
					q.at(j,0) = 0;
			}
			HTH = dmat_t(k, k, ROWMAJOR);
			mk_buf0 = dmat_t(m, k, ROWMAJOR);
			mk_buf1 = dmat_t(m, k, ROWMAJOR);
			mk_buf2 = dmat_t(m, k, ROWMAJOR);
			mk_buf3 = dmat_t(m, k, ROWMAJOR);
			mk_buf4 = dmat_t(m, k, ROWMAJOR);
			dk_buf = dmat_t(d, k, ROWMAJOR);
			if(param->lambda_graph != 0 && prob->L != NULL) {
				dk_buf2 = dmat_t(d, k, ROWMAJOR);
				dk_buf3 = dmat_t(d, k, ROWMAJOR);
			}
			nk_buf = dmat_t(n, k, ROWMAJOR);
			W = G = S = HS = dk_buf.get_view();
		} // }}}
		void init() { // {{{
			if(rho!=0) {  // otherwise HTH = zeros(k,k)
				if(uniform_q) {
					dmat_x_dmat(H.transpose(), H, HTH);
					if(uniform_q != 1.0)
						do_scale(uniform_q, HTH);
	//				printf(" BH=> %p buf %p %g\n", &H, H.buf, norm(HTH));
				} else {
					for(size_t j = 0; j < n; j++) {
						double qj = q.at(j,0);
						for(size_t t = 0; t < k; t++)
							nk_buf.at(j,t) = qj * H.at(j,t);
					}
					dmat_x_dmat(H.transpose(), nk_buf, HTH);
				}
				if(abar != 0.0) {
					dmat_x_dmat(H.transpose(), q, bk);
					if(prob->X->is_dense()) {
						dmat_t &X = prob->X->get_dense();
						dmat_x_dmat(X.transpose(), p, bd);
					} else if(prob->X->is_identity()) {
						bd.assign(p);
					} else if(prob->X->is_sparse()) {
						smat_t &X = prob->X->get_sparse();
						smat_x_dmat(X.transpose(), p, bd);
					}
					dmat_x_dmat(bd, bk.transpose(), dk_buf);
					do_scale(abar, dk_buf);
				}
			}
			delta = cur_obj = loss_pos = 0;
		} // }}}
		// Number of optimisation variables: one per entry of the d*k matrix W.
		int get_nr_variable(void) {
			const size_t nr_entries = d*k;
			return (int)nr_entries;
		}
		// post conditions after fun() and line_search():
		//     mk_buf0 = B = X*W
		//     mk_buf1 = Mbar = diag(p)*XWHTH = diag(p)*XWM
		//     dk_buf2 = Lbar = L*W
		//     nnzY_buf1 = d1loss
		//     nnzY_buf2 = d2loss
		//     loss_pos = sum_{(i,j)\in \Omega^+} \ell^+_{ij}
		//     cur_obj = current objective function value
		// Objective at w:
		//   loss_pos + loss_neg + lambda*0.5*<W,W> + lambda_graph*0.5*<W, L*W>
		// where loss_pos comes from compute_loss_pos() on the predictions at Y's
		// stored positions and loss_neg is the weighted all-entries term below.
		// Caches the result in cur_obj and fills the buffers that grad(),
		// line_search() and Hv() read afterwards.
		// NOTE(review): the locals `d2loss` and `y` are declared but never used
		// in this function.
		// NOTE(review): when rho == 0 and abar != 0 this reads dk_buf, which
		// init() fills only for rho != 0 -- confirm dk_buf's initial contents.
		double fun(val_type *w) { // {{{
			dmat_t &M=HTH, &XW = mk_buf0, &XWM = mk_buf1;
			dvec_t &d1loss = nnzY_buf1, &d2loss = nnzY_buf2;
			const dvec_t y(Y.nnz, Y.val_t);
			W.buf = w;
			// XW <- X*W and XWM <- X*W*M for the representation of X in use
			if(prob->X->is_dense()) {
				dmat_t &X = prob->X->get_dense();
				dmat_x_dmat(X, W, XW);
				dmat_x_dmat(XW, M, XWM);
			} else if(prob->X->is_identity()) {
				dmat_x_dmat(W, M, XWM);
				XW.assign(W);
			} else if(prob->X->is_sparse()) {
				smat_t &X = prob->X->get_sparse();
				smat_x_dmat(X, W, XW);
				dmat_x_dmat(XW, M, XWM);
			}

			// loss on observed entries loss_{ij}(Y_{ij},  xi' W hj)
			// d1loss first receives the predictions; the compute_loss_pos()
			// implementations in this file then overwrite it with derivative info
			partial_dmat_x_dmat(Y, ROWMAJOR, XW, H.transpose(), d1loss);
			loss_pos = compute_loss_pos(d1loss);

			// loss on all entries (||diag(p)^1/2*(abar*ones(m,n) - XWH')*diag(q)^1/2||^2)
			//  = abar^2*sum(p)*sum(q) + <XW, diag(p) XWH'H> - 2*abar* p'*XWH'*q
			double loss_neg = 0;
			// compute abar^2*sum(p)*sum(q) - 2*abar* p'*XWH'*q
			if(abar) { // {{{
				double sump = 0;
				if(uniform_p)
					sump = uniform_p*m;
				else
					for(size_t i = 0; i < m; i++)
						sump += p.data()[i];
				double sumq = 0;
				if(uniform_q)
					sumq = uniform_q*n;
				else
					for(size_t j = 0; j < n; j++)
						sumq += q.data()[j];
				loss_neg += abar * abar * sump * sumq;
				//loss_neg -= 2.0*abar*do_dot_product(dvec_t(p.transpose()*XW), dvec_t(bk));
				//loss_neg -= 2.0*abar*do_dot_product(bd, W*bk);
				// dk_buf holds abar*bd*bk' from init(), so delta = abar*p'*XWH'*q
				delta = do_dot_product(W, dk_buf);
				loss_neg -= 2.0*delta;
			} // }}}
			// compute <XW, diag(p) XWH'H>
			// both branches also leave XWM scaled by p, i.e. diag(p)*X*W*M
			if(uniform_p) {
				if(prob->X->is_dense() || prob->X->is_sparse())
					loss_neg += uniform_p*do_dot_product(XW, XWM);
				else
					loss_neg += uniform_p*do_dot_product(W,XWM);
				do_scale(uniform_p, XWM);
			} else {
				for(size_t i = 0; i < m; i ++) {
					val_type &pi = p.data()[i];
					for(size_t t = 0; t < k; t++)
						XWM.at(i,t) *= pi;
				}
				if(prob->X->is_dense() || prob->X->is_sparse())
					loss_neg += do_dot_product(XW, XWM);
				else
					loss_neg += do_dot_product(W,XWM);
			}
			loss_neg *= 0.5;

			// compute regularization
			double reg = 0.5*do_dot_product(W,W);
			double reg_graph = 0;
			if(param->lambda_graph != 0 && prob->L != NULL) {
				const smat_t &L = *(prob->L);
				// dk_buf2 keeps L*W for grad() and line_search()
				smat_x_dmat(L, W, dk_buf2);
				reg_graph = 0.5*do_dot_product(W, dk_buf2);
				//reg_graph = 0.5*trace_dmat_T_smat_dmat(W, L, W);
			}

			double obj = loss_pos + loss_neg + reg*param->lambda + reg_graph*param->lambda_graph;
			cur_obj = obj;
			return obj;
		} // }}}
		// Backtracking line search along s from w.
		// Only loss_pos is re-evaluated for each trial step; the remaining terms
		// are quadratic in the step and are updated through val1 (linear
		// coefficient) and val2 (quadratic coefficient).
		// Parameters:
		//   s, w, g   : direction, current point and current gradient (d*k each)
		//   step_size : initial step; halved (beta) until the sufficient-decrease
		//               test with eta passes, at most max_num_linesearch times
		//   new_obj   : receives the objective at the accepted step, or cur_obj
		//               when no step is accepted
		//   do_update : when true and a step is accepted, w and the cached
		//               buffers/scalars are advanced to the new point
		// Returns the accepted step size, or 0 when none is accepted.
		// PRECONDITION: the state cached by fun(w) (B, Mbar, Lbar, cur_obj,
		// loss_pos, delta) is current.
		// NOTE(review): d1loss (nnzY_buf1) is overwritten on every trial, even
		// when do_update is false.
		// NOTE(review): the local `w_size` is never used.
		virtual double line_search(val_type *s, val_type *w, val_type *g, double step_size, double *new_obj, bool do_update) { // {{{
			int w_size = get_nr_variable();
			int max_num_linesearch = 100;
			dmat_t &M=HTH, &B = mk_buf0, &Mbar = mk_buf1, &Lbar = dk_buf2;
			dmat_t &dB = mk_buf2, &dMbar = mk_buf3, &Bnew = mk_buf4, &dLbar = dk_buf3;
			dvec_t &d1loss = nnzY_buf1;
			W.buf = w; S.buf = s; G.buf = g;

			// backward line search parameters
			const double eta = 0.01;
			const double beta = 0.5;

			// compute dB and dMbar
			if(prob->X->is_dense()) {
				dmat_t &X = prob->X->get_dense();
				dmat_x_dmat(X, S, dB);
				dmat_x_dmat(dB, M, dMbar);
			} else if(prob->X->is_identity()) {
				dmat_x_dmat(S, M, dMbar);
				dB.assign(S);
			} else if(prob->X->is_sparse()) {
				smat_t &X = prob->X->get_sparse();
				smat_x_dmat(X, S, dB);
				dmat_x_dmat(dB, M, dMbar);
			}
			// scale dMbar by p so that it equals diag(p)*X*S*M
			if(uniform_p)
				do_scale(uniform_p, dMbar);
			else {
				for(size_t i = 0; i < m; i++) {
					val_type &pi = p.data()[i];
					for(size_t t = 0; t < k; t++)
						dMbar.at(i,t) *= pi;
				}
			}
			// change of the abar cross term per unit step
			double ddelta = 0;
			if(abar) ddelta = do_dot_product(S, dk_buf);
			if(param->lambda_graph != 0 && prob->L != NULL) {
				const smat_t &L = *(prob->L);
				smat_x_dmat(L, S, dLbar);
			}
			// val1: coefficient of step_size, val2: coefficient of step_size^2
			// (both halved below) for everything except loss_pos
		   	double val1 = 2*do_dot_product(B, dMbar) - 2.0*ddelta + 2*param->lambda*do_dot_product(W, S);
			double val2 = do_dot_product(dB, dMbar) + param->lambda*do_dot_product(S, S);
			if(param->lambda_graph != 0 && prob->L != NULL) {
				val1 += 2*param->lambda_graph*do_dot_product(W, dLbar);
				val2 += param->lambda_graph*do_dot_product(S, dLbar);
			}
			val1 *= 0.5;
			val2 *= 0.5;
			// directional derivative used by the sufficient-decrease test
			double c = do_dot_product(G, S);

			int num_linesearch = 0;
			for(num_linesearch=0; num_linesearch < max_num_linesearch; num_linesearch++) {

				// Bnew = B + step_size*dB, then re-evaluate only loss_pos
				Bnew.assign(B);
				do_axpy(step_size, dB, Bnew);
				partial_dmat_x_dmat(Y, ROWMAJOR, Bnew, H.transpose(), d1loss);
				double loss_pos_new = compute_loss_pos(d1loss);

				*new_obj = cur_obj - loss_pos + loss_pos_new + step_size*val1 + step_size*step_size*val2;
				if(*new_obj < (cur_obj+step_size*eta*c)) {
					if(do_update) {
						// advance w and every cached quantity to the accepted point
						do_axpy(step_size, S, W);
						B.assign(Bnew);
						do_axpy(step_size, dMbar, Mbar);
						delta = delta + step_size*ddelta;
						loss_pos = loss_pos_new;
						cur_obj = *new_obj;
						if(param->lambda_graph && prob->L != NULL)
							do_axpy(step_size, dLbar, Lbar);
					}
					break;
				} else
					step_size *= beta;
			}

			// no acceptable step found: report a zero step and the old objective
			if (num_linesearch >= max_num_linesearch) {
				step_size = 0;
				*new_obj = cur_obj;
			}

			return step_size;
		} // }}}
		// Reference ("slow") backtracking line search: every trial point
		// w + step_size*s is evaluated with a full call to fun(), and progress
		// is printed to stdout.
		// Parameters:
		//   s, w, g   : direction, current point and current gradient (d*k each)
		//   step_size : initial step; halved (beta) until the sufficient-decrease
		//               test with eta passes, at most max_num_linesearch times
		//   new_obj   : receives the objective at the accepted step, or the
		//               objective at w when no step is accepted
		//   do_update : currently ignored (w is never modified here)
		// Returns the accepted step size, or 0 when none is accepted.
		//
		// Fixes relative to the previous version:
		//   * fun() re-points W.buf at its argument, so after the first trial W
		//     aliased the scratch buffer and later trials accumulated steps
		//     instead of restarting from w. W is now re-anchored at w before
		//     every trial and before returning.
		//   * on failure *new_obj was set to cur_obj, which fun() had already
		//     overwritten with the last trial value; the objective saved on
		//     entry is reported instead.
		// NOTE(review): the state cached by fun() (mk_buf0, mk_buf1, cur_obj,
		// loss_pos, ...) is left at the last trial point, not at w -- callers
		// must call fun() again before grad()/line_search().
		// NOTE(review): the trial point lives in mk_buf4, allocated as m*k,
		// while W is d*k -- confirm that d*k never exceeds m*k.
		virtual double line_search_slow(val_type *s, val_type *w, val_type *g, double step_size, double *new_obj, bool do_update) { // {{{
			(void)do_update;
			const int max_num_linesearch = 1000;
			dmat_t &Wnew = mk_buf4;
			W.buf = w; S.buf = s; G.buf = g;

			// backward line search parameters
			const double eta = 0.01;
			const double beta = 0.5;

			double c = do_dot_product(G, S);

			printf("rofu>>>c %g\n", c);
			// fun() overwrites loss_pos and cur_obj; keep the values at w
			const double tmp_loss = loss_pos;
			const double tmp_obj = cur_obj;

			int num_linesearch = 0;
			for(num_linesearch=0; num_linesearch < max_num_linesearch; num_linesearch++) {

				// restart from the original point: the previous fun() call left
				// W pointing at Wnew's storage
				W.buf = w;
				Wnew.assign(W);
				do_axpy(step_size, S, Wnew);

				*new_obj = fun(Wnew.data());
				printf("line-search %d loss_pos %g loss_pos_new %g obj %g new-obj %g\n", num_linesearch, tmp_loss, loss_pos, tmp_obj,  *new_obj);
				if(*new_obj < (tmp_obj+step_size*eta*c)) {
					break;
				} else
					step_size *= beta;
			}

			// do not leave the W view pointing at scratch memory
			W.buf = w;

			if (num_linesearch >= max_num_linesearch) {
				step_size = 0;
				*new_obj = tmp_obj;
			}

			return step_size;
		} // }}}
		// Gradient at w, written to g:
		//   X'*(D^+*H + diag(p)*X*W*M) - abar*bd*bk' + lambda*W + lambda_graph*L*W
		// PRECONDITION: the buffers cached by fun()/line_search() are current:
		//   nnzY_buf1 (values of D^+), mk_buf1 (diag(p)*X*W*M), dk_buf2 (L*W).
		// Uses mk_buf2 as scratch.
		void grad(val_type *w, val_type *g) { // {{{
			W.buf = w; G.buf = g;
			const bool use_graph = (param->lambda_graph != 0 && prob->L != NULL);

			// inner <- D^+*H + diag(p)*X*W*M, with D^+ sharing Y's pattern
			dmat_t &inner = mk_buf2;
			Dplus.val_t = nnzY_buf1.data();
			smat_x_dmat(Dplus, H, inner);
			do_axpy(1.0, mk_buf1, inner);

			// G <- X'*inner
			if(prob->X->is_dense()) {
				const dmat_t &Xdense = prob->X->get_dense();
				dmat_x_dmat(Xdense.transpose(), inner, G);
			} else if(prob->X->is_identity()) {
				memcpy(G.data(), inner.data(), sizeof(val_type)*m*k);
			} else if(prob->X->is_sparse()) {
				const smat_t &Xsparse = prob->X->get_sparse();
				smat_x_dmat(Xsparse.transpose(), inner, G);
			}

			// remaining terms: abar cross term, L2 and graph regularisers
			if(abar != 0.0)
				do_axpy(-1.0, dk_buf, G);
			do_axpy(param->lambda, W, G);
			if(use_graph)
				do_axpy(param->lambda_graph, dk_buf2, G);
		} // }}}
		// Hessian-vector product, written to Hs:
		//   X'*( U*H + diag(p)*X*S*HTH ) + lambda*S + lambda_graph*L*S
		// where U shares Y's pattern and U_ij = d2loss_ij * (X*S*H')_ij.
		// PRECONDITION: nnzY_buf2 holds the second-derivative values written by
		// compute_loss_pos().
		// Side effects: overwrites nnzY_buf1 (so the first-derivative values
		// read by grad() are lost), mk_buf2 and mk_buf3.
		void Hv(val_type *s, val_type *Hs) { // {{{
			S.buf = s; HS.buf = Hs;
			dvec_t &uplus = nnzY_buf1;
			dvec_t &d2loss = nnzY_buf2;
			Uplus.val_t = uplus.data();

			// compute: XS stored in mk_buf2
			// (for identity X, XS is simply S and mk_buf2 is not written here)
			dmat_t &XS = prob->X->is_identity()? S : mk_buf2;
			if(prob->X->is_dense()) {
				const dmat_t &X = prob->X->get_dense();
				dmat_x_dmat(X, S, XS);
			} else if(prob->X->is_sparse()) {
				const smat_t &X = prob->X->get_sparse();
				smat_x_dmat(X, S, XS);
			}

			// compute: XSHTH = X*S*H'*diag(q)*H stored in mk_buf3
			dmat_t &XSHTH = mk_buf3;
			dmat_x_dmat(XS, HTH, XSHTH);

			// compute: Uij = d2loss_ij * XS(i,:) * H'(:,j)
			partial_dmat_x_dmat(Y,ROWMAJOR,XS,H.transpose(), uplus);
#pragma omp parallel for schedule(static)
			for(size_t idx = 0; idx < Y.nnz; idx++)
				uplus[idx] *= d2loss[idx];
			//}

			// compute: mk_buf2 = U*H+diag(p)*X*S*H'*H
			dmat_t &UH = mk_buf2; // utilizing the space as XS is no longer required.
			//if(uniform_p == 0 || uniform_q == 0)
			smat_x_dmat(Uplus, H, UH);
			//else {
			//smat_x_dmat(Uplus, H, UH); do_scale(1.0-uniform_p*uniform_q, UH);
			//smat_x_dmat(1.0-uniform_p*uniform_q, Uplus, H, 0.0, UH, UH);
			//}
			if(uniform_p != 0.0) {
				do_axpy(uniform_p, XSHTH, UH);
			} else {
				// non-uniform p: scale row i of XSHTH by p_i before adding
#pragma omp parallel for schedule(static)
				for(size_t i = 0; i < m; i ++) {
					val_type &pi = p.data()[i];
					for(size_t t = 0; t < k; t++)
						XSHTH.at(i,t) *= pi;
				}
				do_axpy(1.0, XSHTH, UH);
			}

			// compute: X'*mk_buf2
			if(prob->X->is_dense()) {
				const dmat_t &X = prob->X->get_dense();
				dmat_x_dmat(X.transpose(), mk_buf2, HS);
			} else if(prob->X->is_identity()){
				memcpy(Hs, mk_buf2.data(), sizeof(val_type)*m*k);
			} else if(prob->X->is_sparse()) {
				const smat_t &X = prob->X->get_sparse();
				smat_x_dmat(X.transpose(), mk_buf2, HS);
			}
			// add the L2 and (optional) graph regulariser contributions
			do_axpy(param->lambda, S, HS);
			if(param->lambda_graph != 0 && prob->L != NULL) {
				const smat_t &L = *(prob->L);
				smat_x_dmat(param->lambda_graph, L, S, 1.0, HS, HS);
			}
		} // }}}
}; // }}}

// gr/l2r + squared-L2 loss + one-class observation + general X
class l2r_ls_oY_gX : public l2r_erm_oY_gX {  // {{{
	protected:
		dvec_t nnzY_buf3;
		// loss on observed entries loss_{ij)(Y_{ij},  xi' W hj)
		double compute_loss_pos(const dvec_t &pred_val) { // {{{
			dvec_t &d1loss = nnzY_buf1, &d2loss = nnzY_buf2;
			double loss_pos_tmp = 0;
#pragma omp parallel for schedule(dynamic,64) reduction(+:loss_pos_tmp)
			for(size_t i = 0; i < Y.rows; i++) {
				double loss_pos_local = 0.0;
				for(size_t idx = Y.row_ptr[i]; idx != Y.row_ptr[i+1]; idx++) {
					size_t j = Y.col_idx[idx];
					const val_type& Yij = Y.val_t[idx];
					val_type &xiTWhj = d1loss[idx];
					double tmp1 = (xiTWhj-Yij), tmp2 = (xiTWhj-abar);
					loss_pos_local += tmp1*tmp1 - get_C(i,j)*tmp2*tmp2;
					// compute 1st/2nd order differential info
					d1loss[idx] = tmp1 - get_C(i,j)*tmp2;
					d2loss[idx] = 1.0 - get_C(i,j);
				}
				loss_pos_tmp += loss_pos_local;
			}
			return 0.5*loss_pos_tmp;
		} // }}}
	public:
		// Forwards prob/param to the base class and sizes nnzY_buf3 to the
		// number of stored entries of Y. In the visible code nnzY_buf3 is only
		// referenced by the commented-out line_search() override below.
		l2r_ls_oY_gX(const prob_t* prob, const param_t* param): l2r_erm_oY_gX(prob, param), nnzY_buf3(prob->Y->nnz) {}
		/*
		virtual double line_search(val_type *s, val_type *w, val_type *g, double step_size, double *new_obj, bool do_update) { // {{{
			int w_size = get_nr_variable();
			int max_num_linesearch = 100;
			dmat_t &M=HTH, &B = mk_buf0, &Mbar = mk_buf1, &Lbar = dk_buf2;
			dmat_t &dB = mk_buf2, &dMbar = mk_buf3, &Bnew = mk_buf4, &dLbar = dk_buf3;
			dvec_t &d1loss = nnzY_buf1, &d2loss = nnzY_buf2;
			W.buf = w; S.buf = s; G.buf = g;

			// backward line search parameters
			const double eta = 0.01;
			const double beta = 0.5;

			// compute dB and dMbar
			if(prob->X->is_dense()) {
				dmat_t &X = prob->X->get_dense();
				dmat_x_dmat(X, S, dB);
				dmat_x_dmat(dB, M, dMbar);
			} else if(prob->X->is_identity()) {
				dmat_x_dmat(S, M, dMbar);
				dB.assign(S);
			} else if(prob->X->is_sparse()) {
				smat_t &X = prob->X->get_sparse();
				smat_x_dmat(X, S, dB);
				dmat_x_dmat(dB, M, dMbar);
			}
			if(uniform_p)
				do_scale(uniform_p, dMbar);
			else {
				for(size_t i = 0; i < m; i++) {
					val_type &pi = p.data()[i];
					for(size_t t = 0; t < k; t++)
						dMbar.at(i,t) *= pi;
				}
			}
			double ddelta = 0;
			if(abar) ddelta = do_dot_product(S, dk_buf);
			if(param->lambda_graph != 0 && prob->L != NULL) {
				const smat_t &L = *(prob->L);
				smat_x_dmat(L, S, dLbar);
			}
		   	double val1 = 2*do_dot_product(B, dMbar) - 2.0*ddelta + 2*param->lambda*do_dot_product(W, S);
			double val2 = do_dot_product(dB, dMbar) + param->lambda*do_dot_product(S, S);
			if(param->lambda_graph != 0 && prob->L != NULL) {
				val1 += 2*param->lambda_graph*do_dot_product(W, dLbar);
				val2 += param->lambda_graph*do_dot_product(S, dLbar);
			}
			// for loss_pos part
			partial_dmat_x_dmat(Y, ROWMAJOR, dB, H.transpose(), nnzY_buf3);
			val1 += 2.0*do_dot_product(nnzY_buf3, d1loss);
#pragma omp parallel for schedule(static) reduction(+:val2)
			for(size_t idx = 0; idx < Y.nnz; idx++)
				val2 += nnzY_buf3[idx]*nnzY_buf3[idx]*d2loss[idx];
			val1 *= 0.5;
			val2 *= 0.5;
			double c = do_dot_product(G, S);

			int num_linesearch = 0;
			for(num_linesearch=0; num_linesearch < max_num_linesearch; num_linesearch++) {

				*new_obj = cur_obj + step_size*val1 + step_size*step_size*val2;
				if(*new_obj < (cur_obj+step_size*eta*c)) {
					if(do_update) {
						do_axpy(step_size, S, W);
						do_axpy(step_size, dB, B);
						do_axpy(step_size, dMbar, Mbar);
						delta = delta + step_size*ddelta;
						cur_obj = *new_obj;
						if(param->lambda_graph && prob->L != NULL)
							do_axpy(step_size, dLbar, Lbar);
#pragma omp parallel for schedule(static)
						for(size_t idx = 0; idx < Y.nnz; idx++)
							d1loss[idx] += d2loss[idx]*step_size*nnzY_buf3[idx];
					}
					break;
				} else
					step_size *= beta;
			}

			if (num_linesearch >= max_num_linesearch) {
				step_size = 0;
				*new_obj = cur_obj;
			}

			return step_size;
		} // }}}
*/
}; // }}}

// gr/l2r + logistic loss + one-class observation + general X
class l2r_lr_oY_gX : public l2r_erm_oY_gX {  // {{{
	protected:
		// loss on observed entries loss_{ij)(Y_{ij},  xi' W hj)
		double compute_loss_pos(const dvec_t &pred_val) { // {{{
			dvec_t &d1loss = nnzY_buf1, &d2loss = nnzY_buf2;
			double loss_pos_tmp = 0;
#pragma omp parallel for schedule(dynamic,64) reduction(+:loss_pos_tmp)
			for(size_t i = 0; i < Y.rows; i++) {
				double loss_pos_local = 0.0;
				for(size_t idx = Y.row_ptr[i]; idx != Y.row_ptr[i+1]; idx++) {
					size_t j = Y.col_idx[idx];
					const val_type& Yij = Y.val_t[idx];
					const val_type &xiTWhj = pred_val[idx];
					double tmp1 = 1./(1+exp(-Yij*xiTWhj)), tmp2 = (xiTWhj-abar);
					if(Yij*xiTWhj >= 0)
						loss_pos_local += log(1 + exp(-Yij*xiTWhj));
					else
						loss_pos_local += -Yij*xiTWhj + log(1 + exp(Yij*xiTWhj));
					loss_pos_local -= 0.5*get_C(i,j)*tmp2*tmp2;
					// compute 1st/2nd order differential info
					d1loss[idx] = (tmp1-1)*Yij - get_C(i,j)*tmp2;
					d2loss[idx] = tmp1*(1-tmp1)- get_C(i,j);
				}
				loss_pos_tmp += loss_pos_local;
			}
			return loss_pos_tmp;
		} // }}}
	public:
		l2r_lr_oY_gX(const prob_t* prob, const param_t* param): l2r_erm_oY_gX(prob, param) {}
}; // }}}

// gr/l2r + logistic loss + partial observation + general X
//
// Objective in W (d*k) as evaluated by fun():
//     sum over stored (i,j) of rho_ij * log(1 + exp(-Y_ij * xi' W hj))
//   + 0.5*lambda*<W,W> + 0.5*lambda_graph*trace_dmat_T_smat_dmat(W, L, W)
// with rho_ij = 1 when Y_ij > 0 and rho otherwise.
// The weighted "all entries" term built from p, q, abar and HTH is currently
// disabled; the code that implemented it is kept in block comments.
class l2r_lr_pY_gX : public function<val_type> {  // {{{
	protected:
		// {{{
		typedef bilinear_prob_t prob_t;
		typedef bilinear_param_t param_t;
		const prob_t* prob;
		const param_t* param;
		const smat_t &Y;
		const dmat_t &H;
		smat_t Dplus;      // view of Y
		smat_t Uplus;      // view of Y
		dvec_t nnzY_buf1;  // Y.nnz: sparse (XWH')^+ or D^+ or U^+
		dvec_t nnzY_buf2;  // Y.nnz: sparse dot{D}^+ // required for other loss
		dmat_t p;          // m*1 : could be param->rho*ones(m,1)
		dmat_t q;          // n*1 : could be ones(n,1)
		dmat_t bd;         // d*1 : 2*abar*X'*p
		dmat_t bk;         // k*1 : 2*H'*q
		dmat_t HTH;        // k*k : M = H' * diag(q) * H
		dmat_t mk_buf1;    // m*k : X*W or X*S
		dmat_t mk_buf2;    // m*k : X*W*M or X*S*M
		dmat_t dk_buf;     // d*k : W*M or S*M
		//dmat_t dd_buf;     // d*d store X'*diag(p)*X
		dmat_t nk_buf;     // n*k : store diag(q) * H
		dmat_t W, G, S, HS;// d*k view
		const size_t &m, &d, &n, &k;
		const double &rho, &abar;
		double uniform_p, uniform_q; // common value of p / q; 0.0 when not uniform
		bool maintain_XTX;
		// }}}
		// Weight p_i * q_j of entry (i,j); only referenced by the disabled
		// all-entries term.
		double get_C(size_t i, size_t j) { // {{{
			return p.data()[i]*q.data()[j];
		} // }}}
	public:
		// Binds to the problem / parameter objects and allocates every work buffer.
		// uniform_p / uniform_q are always initialized (0.0 when rho == 0).
		l2r_lr_pY_gX(const prob_t* prob, const param_t* param): prob(prob), param(param), Y(*(prob->Y)), H(*(prob->H)), m(prob->m), d(prob->d), n(prob->n), k(prob->k), rho(param->rho), abar(param->abar), uniform_p(0.0), uniform_q(0.0), maintain_XTX(false /* Check here */) { // {{{
			assert(H.is_rowmajor());
			assert(!prob->X->is_dense() || prob->X->get_dense().is_rowmajor());
			Dplus = Y.get_view();
			Uplus = Y.get_view();
			nnzY_buf1 = dvec_t(Y.nnz);
			nnzY_buf2 = dvec_t(Y.nnz);
			p = dmat_t(m,1,COLMAJOR);
			q = dmat_t(n,1,COLMAJOR);
			if(rho != 0.0) {
				for(size_t i = 0; i < m; i++)
					p.at(i,0) = (val_type)rho;
				for(size_t j = 0; j < n; j++)
					q.at(j,0) = (val_type)1.0;
				if(abar != 0.0) {
					bd = dmat_t(d, 1, COLMAJOR);
					bk = dmat_t(k, 1, COLMAJOR);
				}
				// determine uniform_p/uniform_q {{{
				uniform_q = q.at(0,0);
				for(size_t j = 1; j < n; j++) {
					if(q.at(j,0) != uniform_q) {
						uniform_q = 0.0;
						break;
					}
				}
				if(uniform_q) {
					// fold the common q value into p so that q becomes all ones
					for(size_t i = 0; i < m; i++)
						p.at(i,0) *= uniform_q;
					for(size_t j = 0; j < n; j++)
						q.at(j,0) = 1.0;
					uniform_q = 1.0;
				}
				uniform_p = p.at(0,0);
				for(size_t i = 1; i < m; i++) {
					if(p.at(i,0) != uniform_p) {
						uniform_p = 0.0;
						break;
					}
				}
				// }}}
				printf("uniform_p %g uniform_q %g\n", uniform_p, uniform_q);
			} else {
				for(size_t i = 0; i < m; i++)
					p.at(i,0) = 0;
				for(size_t j = 0; j < n; j++)
					q.at(j,0) = 0;
			}
			HTH = dmat_t(k, k, ROWMAJOR);
			// the m*k buffers have the same shape for every supported kind of X
			if(prob->X->is_dense() || prob->X->is_identity() || prob->X->is_sparse()) {
				mk_buf1 = dmat_t(m, k, ROWMAJOR);
				mk_buf2 = dmat_t(m, k, ROWMAJOR);
			}
			dk_buf = dmat_t(d, k, ROWMAJOR);
			nk_buf = dmat_t(n, k, ROWMAJOR);
			W = G = S = HS = dk_buf.get_view();
		} // }}}
		// Currently a no-op: the code that filled HTH, bk, bd and dk_buf for the
		// all-entries term is disabled.
		void init() { // {{{
			if(rho!=0) {  // otherwise HTH = zeros(k,k)
				/*
				if(uniform_q) {
					dmat_x_dmat(H.transpose(), H, HTH);
					if(uniform_q != 1.0)
						do_scale(uniform_q, HTH);
					//printf(" BH=> %p buf %p %g\n", &H, H.buf, norm(HTH));
				} else {
					for(size_t j = 0; j < n; j++) {
						double qj = q.at(j,0);
						for(size_t t = 0; t < k; t++)
							nk_buf.at(j,t) = qj * H.at(j,t);
					}
					dmat_x_dmat(H.transpose(), nk_buf, HTH);
				}
				if(abar != 0.0) {
					dmat_x_dmat(H.transpose(), q, bk);
					if(prob->X->is_dense()) {
						dmat_t &X = prob->X->get_dense();
						dmat_x_dmat(X.transpose(), p, bd);
					} else if(prob->X->is_identity()) {
						bd = p;
					} else if(prob->X->is_sparse()) {
						smat_t &X = prob->X->get_sparse();
						smat_x_dmat(X.transpose(), p, bd);
					}
					dmat_x_dmat(bd, bk.transpose(), dk_buf);
					do_scale(abar, dk_buf);
				}
				*/
			}
		} // }}}
		// Number of optimization variables: the d*k entries of W.
		int get_nr_variable(void) {return (int)(d*k);}
		// Objective value at w (interpreted as the d*k matrix W).
		// post conditions after fun(w):
		//     mk_buf2 = XW * HTH
		//               NOTE(review): HTH is never filled while init() is disabled
		//               and nothing active reads mk_buf2 - confirm before relying on it
		//     nnzY_buf1 = d1loss
		//     nnzY_buf2 = d2loss
		double fun(val_type *w) { // {{{
			dmat_t &M=HTH, &XW = mk_buf1, &XWM = mk_buf2;
			dvec_t &d1loss = nnzY_buf1, &d2loss = nnzY_buf2;
			W.buf = w;
			// d1loss first receives the predictions xi' W hj on the stored entries
			if(prob->X->is_dense()) {
				dmat_t &X = prob->X->get_dense();
				dmat_x_dmat(X, W, XW);
				dmat_x_dmat(XW, M, XWM);
				partial_dmat_x_dmat(Y, ROWMAJOR, XW, H.transpose(), d1loss);
			} else if(prob->X->is_identity()) {
				dmat_x_dmat(W, M, XWM);
				XW.assign(W);
				partial_dmat_x_dmat(Y, ROWMAJOR, W, H.transpose(), d1loss);
			} else if(prob->X->is_sparse()) {
				smat_t &X = prob->X->get_sparse();
				smat_x_dmat(X, W, XW);
				dmat_x_dmat(XW, M, XWM);
				partial_dmat_x_dmat(Y, ROWMAJOR, XW, H.transpose(), d1loss);
			}

			// loss on observed entries loss_{ij}(Y_{ij},  xi' W hj)
			double loss_pos = 0;
#pragma omp parallel for schedule(dynamic,64) reduction(+:loss_pos)
			for(size_t i = 0; i < Y.rows; i++) {
				double loss_pos_local = 0.0;
				for(size_t idx = Y.row_ptr[i]; idx != Y.row_ptr[i+1]; idx++) {
					const val_type& Yij = Y.val_t[idx];
					val_type rho_local = Yij > 0? 1.0 : rho;
					// copy the prediction: d1loss[idx] is overwritten below
					const val_type xiTWhj = d1loss[idx];
					double tmp1 = 1./(1+exp(-Yij*xiTWhj));// tmp2 = (xiTWhj-abar);
					// log(1 + exp(-Yij*xiTWhj)) without a large positive exp() argument
					if(Yij*xiTWhj >= 0)
						loss_pos_local += rho_local*(log(1 + exp(-Yij*xiTWhj)));
					else
						loss_pos_local += rho_local*(-Yij*xiTWhj + log(1 + exp(Yij*xiTWhj)));
					// loss_pos_local -= 0.5*get_C(i,j)*tmp2*tmp2;
					// compute 1st/2nd order differential info
					d1loss[idx] = rho_local*(tmp1-1)*Yij; // - get_C(i,j)*tmp2;
					d2loss[idx] = rho_local*tmp1*(1-tmp1); //- get_C(i,j);
				}
				loss_pos += loss_pos_local;
			}

			// loss on all entries (||diag(p)^1/2*(abar*ones(m,n) - XWH')*diag(q)^1/2||^2)
			//  = abar^2*sum(p)*sum(q) + <XW, diag(p) XWH'H> - 2*abar* p'*XWH'*q
			// (disabled: loss_neg stays 0)
			double loss_neg = 0;
			/*
			// compute abar^2*sum(p)*sum(q) - 2*abar* p'*XWH'*q
			if(abar) { // {{{
				double sump = 0;
				if(uniform_p)
					sump = uniform_p*m;
				else
					for(size_t i = 0; i < m; i++)
						sump += p.data()[i];
				double sumq = 0;
				if(uniform_q)
					sumq = uniform_q*n;
				else
					for(size_t j = 0; j < n; j++)
						sumq += q.data()[j];
				loss_neg += abar * abar * sump * sumq;
				//loss_neg -= 2.0*abar*do_dot_product(dvec_t(p.transpose()*XW), dvec_t(bk));
				loss_neg -= 2.0*abar*do_dot_product(bd, W*bk);
			} // }}}
			// compute <XW, diag(p) XWH'H>
			if(uniform_p) {
				if(prob->X->is_dense() || prob->X->is_sparse())
					loss_neg += uniform_p*do_dot_product(XW, XWM);
				else
					loss_neg += uniform_p*do_dot_product(W,XWM);
			} else {
				for(size_t i = 0; i < m; i ++) {
					val_type &pi = p.data()[i];
					for(size_t t = 0; t < k; t++)
						XWM.at(i,t) *= pi;
				}
				if(prob->X->is_dense() || prob->X->is_sparse())
					loss_neg += do_dot_product(XW, XWM);
				else
					loss_neg += do_dot_product(W,XWM);
			}
			loss_neg *= 0.5;
			*/

			// compute regularization
			double reg = 0.5*do_dot_product(W,W);
			double reg_graph = 0;
			if(param->lambda_graph != 0 && prob->L != NULL) {
				const smat_t &L = *(prob->L);
				reg_graph = 0.5*trace_dmat_T_smat_dmat(W, L, W);
			}

			//printf("> loss_pos %.10g loss_neg %.10g sum %.10g", loss_pos, loss_neg, loss_pos + loss_neg);
			double obj = loss_pos + loss_neg + reg*param->lambda + reg_graph*param->lambda_graph;
			return obj;
		} // }}}
		// Gradient at w written to g.  Reads nnzY_buf1 (d1loss), so fun(w) must
		// have been evaluated at the same w beforehand.
		void grad(val_type *w, val_type *g) { // {{{
			// X'D^+H + X'diag(p)XWH'H - \bd \bk' + \lambda_w W
			// (only the first and last terms are active)
			W.buf = w; G.buf = g;
			Dplus.val_t = nnzY_buf1.data();
			dmat_t &DH = mk_buf1;
			smat_x_dmat(Dplus, H, DH);
			/*
			dmat_t &diag_p_XWM = mk_buf2;
			if(uniform_p)
				do_axpy(uniform_p, diag_p_XWM, DH); // D^+*H + diag(p)*X*W*H'*diag(q)*H in mk_buf1
			else
				do_axpy(1.0, diag_p_XWM, DH); // D^+*H + diag(p)*X*W*H'*diag(q)*H in mk_buf1
			*/

			if(prob->X->is_dense()) {
				const dmat_t &X = prob->X->get_dense();
				dmat_x_dmat(X.transpose(), mk_buf1, G);
			} else if(prob->X->is_identity()) {
				memcpy(G.data(), mk_buf1.data(), sizeof(val_type)*m*k);
			} else if(prob->X->is_sparse()) {
				const smat_t &X = prob->X->get_sparse();
				smat_x_dmat(X.transpose(), mk_buf1, G);
			}
			/*
			if(abar != 0.0)
				do_axpy(-1.0, dk_buf, G);
			*/
			do_axpy(param->lambda, W, G);
			if(param->lambda_graph != 0 && prob->L != NULL) {
				const smat_t &L = *(prob->L);
				smat_x_dmat(param->lambda_graph, L, W, 1.0, G, G);
			}
		} // }}}
		// Hessian-vector product: Hs = Hessian * s.  Reads nnzY_buf2 (d2loss) and
		// overwrites nnzY_buf1, so the d1loss values left by fun() are gone afterwards.
		void Hv(val_type *s, val_type *Hs) { // {{{
			S.buf = s; HS.buf = Hs;
			dvec_t &uplus = nnzY_buf1;
			dvec_t &d2loss = nnzY_buf2;
			Uplus.val_t = uplus.data();

			// compute: XS stored in mk_buf1 (S itself when X is the identity)
			dmat_t &XS = prob->X->is_identity()? S : mk_buf1;
			if(prob->X->is_dense()) {
				const dmat_t &X = prob->X->get_dense();
				dmat_x_dmat(X, S, XS);
			} else if(prob->X->is_sparse()) {
				const smat_t &X = prob->X->get_sparse();
				smat_x_dmat(X, S, XS);
			}

			// compute: XSHTH = X*S*H'*diag(q)*H stored in mk_buf2
			//dmat_t &XSHTH = mk_buf2;
			//dmat_x_dmat(XS, HTH, XSHTH);

			// compute: Uij = d2loss_ij * XS(i,:) * H'(:,j)
			partial_dmat_x_dmat(Y,ROWMAJOR,XS,H.transpose(), uplus);
			//if(uniform_p == 0 || uniform_q == 0) {
#pragma omp parallel for schedule(static)
			for(size_t idx = 0; idx < Y.nnz; idx++)
				uplus[idx] *= d2loss[idx];
			//}

			// compute: mk_buf1 = U*H+diag(p)*X*S*H'*H (second term disabled)
			dmat_t &UH = mk_buf1; // utilizing the space as XS is no longer required.
			//if(uniform_p == 0 || uniform_q == 0)
			smat_x_dmat(Uplus, H, UH);
			//else {
			//smat_x_dmat(Uplus, H, UH); do_scale(1.0-uniform_p*uniform_q, UH);
			//smat_x_dmat(1.0-uniform_p*uniform_q, Uplus, H, 0.0, UH, UH);
			//}
			/*
			if(uniform_p != 0.0) {
				do_axpy(uniform_p, XSHTH, UH);
			} else {
#pragma omp parallel for schedule(static)
				for(size_t i = 0; i < m; i ++) {
					val_type &pi = p.data()[i];
					for(size_t t = 0; t < k; t++)
						XSHTH.at(i,t) *= pi;
				}
				do_axpy(1.0, XSHTH, UH);
			}
			*/

			// compute: X'*mk_buf1
			if(prob->X->is_dense()) {
				const dmat_t &X = prob->X->get_dense();
				dmat_x_dmat(X.transpose(), mk_buf1, HS);
			} else if(prob->X->is_identity()){
				memcpy(Hs, mk_buf1.data(), sizeof(val_type)*m*k);
			} else if(prob->X->is_sparse()) {
				const smat_t &X = prob->X->get_sparse();
				smat_x_dmat(X.transpose(), mk_buf1, HS);
			}
			do_axpy(param->lambda, S, HS);
			if(param->lambda_graph != 0 && prob->L != NULL) {
				const smat_t &L = *(prob->L);
				smat_x_dmat(param->lambda_graph, L, S, 1.0, HS, HS);
			}
		} // }}}
}; // }}}

class o2r_ls_fY_gX : public l2r_ls_fY_gX { // {{{
	private:
		dmat_t HTH_lambda;
	public :
		o2r_ls_fY_gX(const prob_t* prob, const param_t *param): l2r_ls_fY_gX(prob, param) { // {{{
			HTH_lambda = HTH.get_view();
		} // }}}
		void init() { // {{{
			l2r_ls_fY_gX::init();
			for(size_t t = 0; t < k; t++)
				HTH_lambda.at(t,t) += param->lambda;
		} // }}}
		double fun(val_type* w) { // {{{
			double obj = trYTY;
			dmat_t W(d, k, w, ROWMAJOR);
			dmat_t &XTXW = dk_buf, &WTXTXW = kk_buf;
			if(prob->X->is_dense()) {
				dmat_x_dmat(XTX, W, XTXW);
				dmat_x_dmat(W.transpose(), XTXW, WTXTXW);
			} else if (prob->X->is_identity()) {
				dmat_x_dmat(W.transpose(), W, WTXTXW);
			}
			obj += do_dot_product(WTXTXW.data(), HTH.data(), k*k);
			obj -= 2.0*do_dot_product(XTYH.data(), W.data(), d*k);
			obj *= 0.5;
			return obj;
		} // }}}
		void grad(val_type *w, val_type *g) { // {{{
			W.buf = w; G.buf = g;
			// assume fun(w) just being called
			dmat_t &XTXW = prob->X->is_identity()? W: dk_buf;
			do_copy(XTYH.data(), g, d*k);
			dmat_x_dmat((val_type)(1.0), XTXW, HTH_lambda, (val_type)(-1.0), G);
		} // }}}
		void Hv(val_type *s, val_type *Hs) { // {{{
			S.buf = s; HS.buf = Hs;
			if(prob->X->is_dense()) {
				dmat_t &XTXS = dk_buf;
				dmat_x_dmat(XTX, S, XTXS);
				dmat_x_dmat(XTXS, HTH_lambda, HS);
			} else if (prob->X->is_identity()) {
				dmat_x_dmat(S, HTH_lambda, HS);
			}
		} // }}}
}; // }}}

/*
 *  Case with X = I
 *  W = argmin_{W} 0.5*|Y-WH'|^2 + 0.5*lambda*|W|^2
 *
 *  W = argmin_{W}  C * |Y - W*H'|^2 +  0.5*|W|^2
 *             where C = 1/(2*lambda)
 * */


// Cholesky-based solver for the fully observed squared loss with X = I.
struct l2r_ls_fY_IX_chol : public solver_t { // {{{
	const smat_t &Y;
	const dmat_t &H;
	dmat_t HTH;      // k*k: H'H + lambda*I after init_prob()
	dmat_t YH;       // m*k: Y*H after init_prob()
	dmat_t kk_buf;   // k*k scratch, holds W'W inside fun()
	double trYTY;    // dot product of the stored values of Y with themselves
	val_type lambda;
	const size_t &m, &k;
	bool done_init;  // true while HTH / YH are valid for the next solve()/fun()
	l2r_ls_fY_IX_chol(const smat_t &Y, const dmat_t &H, val_type lambda): Y(Y), H(H), HTH(), YH(), kk_buf(), trYTY(0), lambda(lambda), m(Y.rows), k(H.cols), done_init(false) { // {{{
		trYTY = do_dot_product(Y.val, Y.val, Y.nnz);
		kk_buf = dmat_t(k, k, ROWMAJOR);
		YH = dmat_t(m, k, ROWMAJOR);
		HTH = dmat_t(k, k, ROWMAJOR);
	} // }}}
	// Recomputes HTH = H'H + lambda*I and YH = Y*H from the current H.
	void init_prob() { // {{{
		dmat_x_dmat(H.transpose(), H, HTH);
		size_t diag = 0;
		while(diag < k) {
			HTH.at(diag,diag) += lambda;
			++diag;
		}
		smat_x_dmat(Y, H, YH);
		done_init = true;
	} // }}}
	// Copies YH into w and hands (HTH, w) to ls_solve_chol_matrix_colmajor.
	// The cached matrices are marked stale afterwards.
	void solve(val_type *w) { // {{{
		if(!done_init)
			init_prob();
		do_copy(YH.data(), w, m*k);
		ls_solve_chol_matrix_colmajor(HTH.data(), w, k, m);
		done_init = false;
	} // }}}
	// Returns 0.5*(trYTY + <W'W, HTH> - 2*<W, YH>).
	double fun(val_type *w) { // {{{
		if(!done_init)
			init_prob();
		dmat_t Wmat(m, k, w, ROWMAJOR);
		dmat_x_dmat(Wmat.transpose(), Wmat, kk_buf);
		const double quad = do_dot_product(kk_buf.data(), HTH.data(), k*k);
		const double cross = do_dot_product(w, YH.data(), m*k);
		return 0.5*((trYTY + quad) - 2.0*cross);
	} // }}}
}; // }}}

struct l2r_ls_mY_IX_chol : public solver_t { // {{{
	const smat_t &Y;
	const dmat_t &H;
	std::vector<dvec_t> Hessian_set;
	val_type lambda;
	const size_t &m, &k;
	size_t nr_threads;
	l2r_ls_mY_IX_chol(const smat_t& Y, const dmat_t& H, val_type lambda): Y(Y), H(H), Hessian_set(), lambda(lambda), m(Y.rows), k(H.cols) { // {{{
		nr_threads = omp_get_max_threads();
		Hessian_set.resize(nr_threads, dvec_t(k*k));
	} // }}}
	void init_prob() {}
	void solve(val_type *w) { // {{{
#pragma omp parallel for schedule(dynamic,64)
		for(size_t i = 0; i < Y.rows; i++) {
			size_t nnz_i = Y.nnz_of_row(i);
			if(nnz_i == 0) continue;
			int tid = omp_get_thread_num(); // thread ID
			val_type *Wi = &w[i*k];
			val_type *Hessian = Hessian_set[tid].data();

			val_type *y = Wi;
			memset(Hessian, 0, sizeof(val_type)*k*k);
			memset(y, 0, sizeof(val_type)*k);
			for(size_t idx = Y.row_ptr[i]; idx != Y.row_ptr[i+1]; idx++) {
				const val_type *Hj = H[Y.col_idx[idx]].data();
				for(size_t s = 0; s < k; s++){
					y[s] += Y.val_t[idx]*Hj[s];
					for(size_t t = s; t < k; t++)
						Hessian[s*k+t] += Hj[s]*Hj[t];
				}
			}
			for(size_t s = 0; s < k; s++) {
				for(size_t t = 0; t < s; t++)
					Hessian[s*k+t] = Hessian[t*k+s];
				Hessian[s*k+s] += lambda;
			}
			ls_solve_chol(Hessian, y, k);
		}
	} // }}}
	double fun(val_type *w) { // {{{
		double loss = 0;
#pragma omp parallel for reduction(+:loss) schedule(dynamic,32)
		for(size_t i = 0; i < Y.rows; i++) {
			val_type* Wi = &w[i*k];
			for(size_t idx = Y.row_ptr[i]; idx != Y.row_ptr[i+1]; idx++) {
				double err = -Y.val_t[idx];
				const val_type* Hj = H[Y.col_idx[idx]].data();
				for(size_t s = 0; s < k; s++)
					err += Wi[s]*Hj[s];
				loss += err*err;
			}
		}
		double reg = do_dot_product(w,w,m*k);
		double obj = 0.5 * (loss + lambda*reg);
		return obj;
	} // }}}
}; // }}}

// loss on all entries 0.5*(||diag(p)^1/2*(abar*ones(m,n) - AW*BH')*diag(q)^1/2||^2)
//  = 0.5* {abar^2*sum(p)*sum(q) + <AW, diag(p) AW*BH'*diag(q)*BH> - 2*abar* p'*AW*BH'*q}
static double compute_loss_neg(const dmat_t &AW, const dmat_t &BH, double rho, double abar) { // {{{
	size_t m = AW.rows, n = BH.rows, k = AW.cols;
	double loss_neg = 0;
	// compute abar^2*sum(p)*sum(q) - 2*abar* p'*AW*BH'*q
	if(abar) { // {{{
		double sump = rho*m, sumq = ((rho!=0)?(double)n:0);
		loss_neg += abar*abar*sump*sumq;
		for(size_t t = 0; t < k; t++) {
			double tmp1 = 0, tmp2 = 0;
			if(rho!=0) {
				for(size_t i = 0; i < m; i++)
					tmp1 += AW.at(i,t);
				for(size_t j = 0; j < n; j++)
					tmp2 += BH.at(j,t);
			}
			loss_neg -= 2.0*abar*rho*tmp1*tmp2;
		}
	} // }}}
	// compute <AW, diag(p) AW*BH'*diag(q)*BH
	if(rho) loss_neg += rho*do_dot_product(AW*(BH.transpose()*BH), AW);
	return 0.5*loss_neg;
} // }}}

// Row-by-row Cholesky solver for the one-class squared loss with X = I.
// For every row i with stored entries, solve() builds and solves
//     ((1-rho)*sum_j h_j h_j' + rho*H'H + lambda*I) w_i
//         = sum_j (Y_ij - abar*rho) h_j + abar*rho*H'e
// where j runs over the stored entries of row i.
struct l2r_ls_oY_IX_chol : public solver_t { // {{{
	const smat_t &Y;
	const dmat_t &H;
	dmat_t rhoHTH, kk_buf;           // k*k: rho*H'H after init_prob(); scratch
	dvec_t abar_rhoHTe;              // k: abar*rho*(column sums of H) after init_prob()
	std::vector<dvec_t> Hessian_set; // one k*k scratch matrix per thread
	val_type lambda;
	val_type rho, abar;
	const size_t &m, &k;
	size_t nr_threads;
	bool done_init;                  // true while rhoHTH / abar_rhoHTe are valid for the next solve()
	l2r_ls_oY_IX_chol(const smat_t& Y, const dmat_t& H, val_type lambda, val_type rho, val_type abar=0): Y(Y), H(H), rhoHTH(), kk_buf(), abar_rhoHTe(), lambda(lambda), rho(rho), abar(abar), m(Y.rows), k(H.cols), done_init(false) { // {{{
		nr_threads = omp_get_max_threads();
		rhoHTH = dmat_t(k, k, ROWMAJOR);
		kk_buf = dmat_t(k, k, ROWMAJOR);
		abar_rhoHTe = dvec_t(k);
		Hessian_set.resize(nr_threads, dvec_t(k*k));
	} // }}}
	// Solves the per-row systems in place in w.  Rows without stored entries are
	// left untouched.  The cached quantities are computed on demand, so calling
	// solve() without a preceding init_prob() no longer reads unset buffers.
	void solve(val_type *w) { // {{{
		if(!done_init) {init_prob();}
		const val_type one_minus_rho = 1-rho;
		const val_type abar_rho = abar*rho;
#pragma omp parallel for schedule(dynamic,64)
		for(size_t i = 0; i < Y.rows; i++) {
			size_t nnz_i = Y.nnz_of_row(i);
			if(nnz_i == 0) continue;
			int tid = omp_get_thread_num(); // thread ID
			val_type *Wi = &w[i*k];
			val_type *Hessian = Hessian_set[tid].data();
			val_type *y = Wi;
			// start from the parts shared by all rows
			memcpy(Hessian, rhoHTH.data(), sizeof(val_type)*k*k);
			memcpy(y, abar_rhoHTe.data(), sizeof(val_type)*k);

			// add the contribution of the stored entries (upper triangle only)
			for(size_t idx = Y.row_ptr[i]; idx != Y.row_ptr[i+1]; idx++) {
				const val_type *Hj = H[Y.col_idx[idx]].data();

				for(size_t s = 0; s < k; s++){
					y[s] += (Y.val_t[idx]-abar_rho)*Hj[s];
					for(size_t t = s; t < k; t++)
						Hessian[s*k+t] += one_minus_rho*Hj[s]*Hj[t];
				}
			}
			// mirror into the lower triangle, then add the ridge term
			for(size_t s = 0; s < k; s++) {
				for(size_t t = 0; t < s; t++)
					Hessian[s*k+t] = Hessian[t*k+s];
				Hessian[s*k+s] += lambda;
			}
			ls_solve_chol(Hessian, y, k);
		}
		// mark the cache stale (same convention as l2r_ls_fY_IX_chol::solve)
		done_init = false;
	} // }}}
	// Objective value 0.5*(loss + lambda*<w,w>) with
	//   loss = sum over stored entries of (w_i'h_j - Y_ij)^2
	//        - rho * sum over stored entries of (w_i'h_j - abar)^2
	//        + 2*compute_loss_neg(W, H, rho, abar)
	// Does not use the quantities cached by init_prob().
	double fun(val_type *w) { // {{{
		dmat_t W(m, k, w, ROWMAJOR);
		// per-thread partial sums, combined serially after the parallel loop
		dvec_t omega_parts(nr_threads), zero_parts(nr_threads);
		memset(omega_parts.data(), 0, nr_threads*sizeof(val_type));
		memset(zero_parts.data(), 0, nr_threads*sizeof(val_type));
#pragma omp parallel for schedule(dynamic,64)
		for(size_t c = 0; c < Y.cols; c++) {
			val_type omega_part = val_type(0.0), zero_part = val_type(0.0);

			const dvec_t &Hj = H[c];
			for(size_t idx = Y.col_ptr[c]; idx != Y.col_ptr[c+1]; ++idx) {
				size_t r = Y.row_idx[idx];
				val_type sum = 0.0;
				const dvec_t &Wi = W[r];
				for(size_t t = 0; t < k; ++t)
					sum += Wi[t] * Hj[t];
				zero_part -= (sum-abar)*(sum-abar);
				sum -= Y.val[idx];
				omega_part += sum*sum;
			}
			omega_parts[omp_get_thread_num()] += omega_part;
			zero_parts[omp_get_thread_num()] += zero_part;
		}
		val_type omega_part = val_type(0.0), zero_part = val_type(0.0);
		for(size_t tid = 0; tid < nr_threads; tid++) {
			omega_part += omega_parts[tid];
			zero_part += zero_parts[tid];
		}

		//dmat_t &WTW = kk_buf;
		//dmat_x_dmat(W.transpose(), W, WTW);
		//zero_part = zero_part*rho + do_dot_product(WTW.data(), rhoHTH.data(), k*k);
		zero_part = zero_part*rho + 2.0*compute_loss_neg(W, H, rho, abar);

		double loss = (double)omega_part + (double)zero_part;
		double reg = do_dot_product(w,w,m*k);
		double obj = 0.5 * (loss + lambda*reg);
		return obj;
	} // }}}
	// Recomputes rhoHTH = rho*H'H and abar_rhoHTe[t] = abar*rho*sum_c H(c,t)
	// from the current H and marks them valid.
	void init_prob() { // {{{
		printf("hh init_prob\n");
		dmat_x_dmat(H.transpose(), H, rhoHTH);
		for(size_t idx = 0; idx < k*k; idx++)
			rhoHTH.buf[idx] *= rho;
#pragma omp parallel for schedule(static)
		for(size_t t = 0; t < k; t++) {
			val_type sum = 0.0;
			for(size_t c = 0; c < Y.cols; c++)
				sum += H.at(c,t);
			abar_rhoHTe[t] = abar*rho*sum;
		}
		done_init = true;
	} // }}}
}; // }}}

// Selects the solver implementation for param->solver_type.
//
// A closed-form Cholesky solver (solver_obj) is used for ALS / PU_ALS when X is
// the identity, param->use_chol is set and no graph regularizer L is given.
// Every other supported configuration gets an objective (fun_obj), and a TRON
// optimizer is created for it - including LR_ALS / LR_PU_ALS in the
// identity + Cholesky configuration, which have no Cholesky solver.
// For an unsupported solver_type a message is printed and fun_obj, tron_obj and
// solver_obj all stay NULL.
leml_solver::leml_solver(bilinear_prob_t *prob, bilinear_param_t *param): prob(prob), param(param), fun_obj(NULL), tron_obj(NULL), solver_obj(NULL), done_init(false){ // {{{
	printf("identity %d use_chol %d\n", prob->X->is_identity(), param->use_chol);
	const bool use_chol_solver = prob->X->is_identity() && param->use_chol && prob->L == NULL;
	switch(param->solver_type) {
		case ALS:               // missing
			if(use_chol_solver)
				solver_obj = new l2r_ls_mY_IX_chol(*(prob->Y), *(prob->H), param->lambda);
			else
				fun_obj = new l2r_ls_mY_gX(prob, param);
			break;
		case PU_ALS:
			if(param->rho < 0) { // full
				if(use_chol_solver)
					solver_obj = new l2r_ls_fY_IX_chol(*(prob->Y), *(prob->H), param->lambda);
				else
					fun_obj = new l2r_ls_fY_gX(prob, param);
			} else {             // one-class
				if(use_chol_solver)
					solver_obj = new l2r_ls_oY_IX_chol(*(prob->Y), *(prob->H), param->lambda, param->rho, param->abar);
				else
					fun_obj = new l2r_ls_oY_gX(prob, param);
			}
			break;
		case LR_ALS:
			fun_obj = new l2r_lr_pY_gX(prob, param);
			break;
		case LR_PU_ALS:
			fun_obj = new l2r_lr_oY_gX(prob, param);
			break;
		default :
			fprintf(stderr, "Solver not supported\n");
			break;
	}
	fflush(stdout);
	// Any objective needs an optimizer; guard against the unsupported case in
	// which fun_obj is still NULL.
	if(fun_obj != NULL) {
		// CG cannot usefully run more iterations than there are variables
		const int nr_variable = fun_obj->get_nr_variable();
		int max_cg_iter = param->max_cg_iter;
		if(max_cg_iter >= nr_variable)
			max_cg_iter = nr_variable;
		printf("max_cg_iter %d\n", max_cg_iter);
		tron_obj = new TRON<val_type>(fun_obj, param->eps, param->eps_cg, param->max_tron_iter, max_cg_iter);
		tron_obj->set_print_string(get_print_fun(param));
	}
} // }}}
