import math
import sys
import numpy as np
import matplotlib.pyplot as plt

# Thread counts on the x-axis; one timing per entry is read from each log.
core = [1] + list(range(2, 13, 2))

# Shared figure styling: axis-label font, tick font, marker size, line width.
labelSize, axeSize = 30, 24
markerSize = 13.0
lineSize = 5

def get_logtime(logfile, core):
	"""Sum the timings recorded for one thread count in a benchmark log.

	The section of interest starts at the first line containing ``cores``
	whose third space-separated field equals *core*. It ends at the next
	line containing ``cores``. Within that section:

	* each ``Xvtime`` line adds its second field (with the last character,
	  presumably a separator comma, stripped) to the Xv total, and its last
	  field to the XTv total;
	* each ``train total runtime`` line adds its last field to the
	  training total.

	Args:
		logfile: path of the log file to read.
		core: thread count whose section should be summed.

	Returns:
		Tuple (Xv, XTv, Train). All three are 0 when no matching section
		exists.
	"""
	tmpXv = 0
	tmpXTv = 0
	tmpTrain = 0
	in_section = False
	# ``with`` closes the file even if a malformed line makes int()/float()
	# raise; iterating the handle avoids loading the whole log at once.
	with open(logfile, 'r') as fp:
		for line in fp:
			if 'cores' in line:
				if in_section:
					break  # reached the next thread count's section
				in_section = int(line.split(' ')[2]) == core
			elif not in_section:
				continue
			elif 'Xvtime' in line:
				fields = line.split(' ')
				tmpXv += float(fields[1][:-1])
				tmpXTv += float(fields[-1])
			elif 'train total runtime' in line:
				tmpTrain += float(line.split(' ')[-1])
	return tmpXv, tmpXTv, tmpTrain

def plot_subfig(subplt, mkl, omp, omp1, rsb, rsbXT):
	"""Draw one speedup-vs-threads chart on the current matplotlib figure.

	MKL and RSB are always drawn. Which other series appear depends on
	*subplt*:

	* 0 adds *omp1*, labelled "OpenMP";
	* 2 adds *omp1*, labelled "OpenMP-array";
	* 3 adds *omp* ("OpenMP") and *rsbXT* ("RSBt");
	* any other value adds nothing.

	Args:
		subplt: selector for the extra series (see above).
		mkl, omp, omp1, rsb, rsbXT: per-thread-count values, each the same
			length as the module-level ``core`` list (used as the x values).
	"""
	npcore = np.array(core)
	plt.plot(npcore, np.array(mkl), 'r--s', label='MKL', ms=markerSize, linewidth=lineSize)
	plt.plot(npcore, np.array(rsb), 'g--^', label='RSB', ms=markerSize, linewidth=lineSize)
	# Only the series actually drawn determine the y-tick range below.
	plotted = [mkl, rsb]
	if subplt == 0 or subplt == 2:
		label = 'OpenMP' if subplt == 0 else 'OpenMP-array'
		plt.plot(npcore, np.array(omp1), 'm--D', label=label, ms=markerSize, linewidth=lineSize)
		plotted.append(omp1)
	if subplt == 3:
		plt.plot(npcore, np.array(omp), 'm--D', label='OpenMP', ms=markerSize, linewidth=lineSize)
		plt.plot(npcore, np.array(rsbXT), 'b--o', label='RSBt', ms=markerSize, linewidth=lineSize)
		plotted.extend([omp, rsbXT])
	plt.legend(loc='upper left', prop={'size':20})
	plt.ylim(bottom=0)
	yend = math.ceil(max(max(series) for series in plotted))+1
	# Use a coarser tick step once the axis would otherwise get crowded.
	ystep = 2 if yend > 10 else 1
	plt.yticks(np.arange(0, yend, ystep), fontsize=axeSize)
	plt.xticks(npcore, fontsize=axeSize)
	plt.ylabel('Speedup', fontsize=labelSize)
	plt.xlabel('# threads', fontsize=labelSize)
	plt.tight_layout()

def plot_fig(now, fig_num, orig, mkl, omp, omp1, rsb, rsbXT):
	"""Convert raw timings to speedups and save the Xv, XTv and Train charts.

	Args:
		now: index into ``sys.argv`` of the dataset name; it selects the
			output file prefix.
		fig_num: unused; kept for call compatibility.
		orig: four baseline timings, indexed 0..3. The caller builds this as
			[Xv, XTv, Xv+XTv, Train] from the single-thread run.
		mkl, omp, omp1, rsb, rsbXT: four lists each, indexed like *orig*.
			Each inner list holds one timing per thread count.

	The inputs are not modified. Charts for indexes 0, 1 and 3 are written
	to ./figure/<name>-Xv.png, -XTv.png and -Train.png.
	"""
	savefilename = sys.argv[now]
	# Shorten well-known dataset names. The first matching substring wins;
	# any other name is used verbatim.
	aliases = (('webspam', 'webspam'),
			('rcv1_test.binary', 'rcv1_binary'),
			('rcv1', 'rcv1_multiclass'),
			('covtype.libsvm', 'covtype_binary'),
			('covtype', 'covtype_multiclass'))
	for key, alias in aliases:
		if key in savefilename:
			savefilename = alias
			break

	def speedup(times):
		# Baseline time divided by parallel time, per metric and thread count.
		return [[base / t for t in series] for base, series in zip(orig, times)]

	# omp1 is converted as well, so every series on the "Speedup" axis is a
	# speedup rather than a raw time.
	mkl, omp, omp1, rsb, rsbXT = [speedup(t) for t in (mkl, omp, omp1, rsb, rsbXT)]

	for idx, suffix in ((0, 'Xv'), (1, 'XTv'), (3, 'Train')):
		plot_subfig(idx, mkl[idx], omp[idx], omp1[idx], rsb[idx], rsbXT[idx])
		plt.savefig("./figure/" + savefilename + "-" + suffix + ".png")
		plt.close()

def _collect_times(logfile):
	"""Gather timings from *logfile* for every thread count in ``core``.

	Returns [Xv, XTv, Mix, Train]: four lists with one entry per element of
	``core``. Mix is the element-wise sum of Xv and XTv.
	"""
	xv, xtv, train = [], [], []
	for ncore in core:
		t_xv, t_xtv, t_train = get_logtime(logfile, ncore)
		xv.append(t_xv)
		xtv.append(t_xtv)
		train.append(t_train)
	mix = [a + b for a, b in zip(xv, xtv)]
	return [xv, xtv, mix, train]


if len(sys.argv) < 2:
	print("Usage: python draw.py [dataset]")
	# Nothing to plot: report failure instead of silently succeeding.
	sys.exit(1)

logpath = "../log/"
datanum = len(sys.argv)-1

for data in range(datanum):
	now = data+1
	prefix = logpath + sys.argv[now]

	# Single-threaded run (core count 1) used as the speedup baseline.
	origXv, origXTv, origTrain = get_logtime(prefix + "-liblr.log", 1)
	orig = [origXv, origXTv, origXv+origXTv, origTrain]

	mkl = _collect_times(prefix + "-liblr-mkl.log")

	omp = _collect_times(prefix + "-liblr-omp.log")
	# HACK: the logged Xv/XTv timings of this OpenMP variant are discarded and
	# replaced by huge placeholders, which drive its Xv/XTv speedups to ~0.
	# plot_subfig draws `omp` only on the Train chart.
	omp[0] = [100000000000] * len(core)
	omp[1] = [1000000000] * len(core)
	omp[2] = [a + b for a, b in zip(omp[0], omp[1])]

	omp1 = _collect_times(prefix + "-liblr-ompsp.log")
	rsb = _collect_times(prefix + "-liblr-rsb.log")
	rsbXT = _collect_times(prefix + "-liblr-rsbXT.log")

	plot_fig(now, data, orig, mkl, omp, omp1, rsb, rsbXT)
