import os
import numpy as np

from sparsekmeans import LloydKmeans, ElkanKmeans
from sklearn.cluster import KMeans

import time
import argparse
import pickle as pkl

import urllib.request
import zipfile

# Command-line interface of the benchmark script.
# --dataset, --n_clusters and --backend are mandatory: the script cannot do
# anything useful without them, so argparse rejects the call up front instead
# of letting it crash later with an unrelated error.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, required=True, help='name of datasets to run cluster')
parser.add_argument('--n_clusters', type=int, required=True, help='number of clusters')
# os.cpu_count() may return None when the count cannot be determined; fall
# back to a single thread in that case instead of raising TypeError.
parser.add_argument('--n_threads', type=int, help='number of threads', default=max(1, (os.cpu_count() or 1) // 2))
parser.add_argument('--n_iter', type=int, help='maximum number of iterations', default=300)
parser.add_argument(
    '--backend',
    type=str,
    required=True,
    choices=('sparse-lloyd', 'sparse-elkan', 'sklearn-lloyd', 'sklearn-elkan'),
    help='running with different backends (sparse-lloyd, sparse-elkan, sklearn-lloyd, sklearn-elkan)',
)

def load_dataset(dataset_name, data_dir="./data"):
    """Return the pickled dataset *dataset_name*, downloading it if it is not cached.

    The pickle is looked up at ``<data_dir>/pkl/<dataset_name>.pkl``. When it
    is missing, ``<dataset_name>.zip`` is fetched from the sparse_kmeans
    dataset page and extracted into *data_dir*, after which the pickle is
    loaded from the same location.

    Args:
        dataset_name: Base name of the dataset (no extension).
        data_dir: Directory holding the downloaded archive and the ``pkl``
            sub-directory.

    Returns:
        Whatever object is stored in the pickle file.

    Raises:
        FileNotFoundError: If the pickle is still missing after extraction.
        urllib.error.URLError: If the download fails.
    """
    pkl_path = os.path.join(data_dir, "pkl", f"{dataset_name}.pkl")

    if not os.path.exists(pkl_path):
        url = f"https://www.csie.ntu.edu.tw/~cjlin/datasets/sparse_kmeans/{dataset_name}.zip"
        zip_path = os.path.join(data_dir, f"{dataset_name}.zip")
        os.makedirs(data_dir, exist_ok=True)

        urllib.request.urlretrieve(url, zip_path)

        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(data_dir)

        # NOTE(review): assumes the archive unpacks to pkl/<dataset_name>.pkl
        # (the path the cache check uses) -- TODO confirm against the server.
        if not os.path.exists(pkl_path):
            raise FileNotFoundError(
                f"{pkl_path} not found after extracting {zip_path}"
            )

    # NOTE(review): pickle.load can execute arbitrary code; the file may have
    # just been downloaded over the network, so only use trusted sources.
    with open(pkl_path, "rb") as f:
        return pkl.load(f)


def build_estimator(backend, n_clusters, n_threads, n_iter, random_state):
    """Create the (unfitted) k-means estimator selected by *backend*.

    Args:
        backend: One of ``sparse-lloyd``, ``sparse-elkan``, ``sklearn-lloyd``
            or ``sklearn-elkan``.
        n_clusters: Number of clusters.
        n_threads: Thread count; only passed to the sparsekmeans backends.
        n_iter: Maximum number of iterations.
        random_state: Integer seed handed to the estimator.

    Returns:
        An estimator object exposing ``fit``.

    Raises:
        ValueError: If *backend* is not one of the supported names.
    """
    if backend == "sparse-lloyd":
        return LloydKmeans(
            n_clusters=n_clusters,
            n_threads=n_threads,
            max_iter=n_iter,
            tol=0.0001,
            random_state=random_state,
            verbose=False,
        )
    if backend == "sparse-elkan":
        return ElkanKmeans(
            n_clusters=n_clusters,
            n_threads=n_threads,
            max_iter=n_iter,
            tol=0.0001,
            random_state=random_state,
            verbose=False,
        )
    if backend in ("sklearn-lloyd", "sklearn-elkan"):
        return KMeans(
            n_clusters,
            random_state=random_state,
            n_init=1,
            max_iter=n_iter,
            tol=0.0001,
            # "sklearn-lloyd" -> "lloyd", "sklearn-elkan" -> "elkan"
            algorithm=backend.split("-", 1)[1],
        )
    raise ValueError(
        f"unknown backend {backend!r}; expected one of "
        "sparse-lloyd, sparse-elkan, sklearn-lloyd, sklearn-elkan"
    )


def main():
    """Parse the command line, run one clustering backend and print its runtime."""
    # Derive a fixed integer seed so every backend is run with the same value.
    np.random.seed(42)  # RandomState
    random_state = np.random.randint(2**31 - 1)
    opt = parser.parse_args()

    dataset_name = opt.dataset
    n_clusters = opt.n_clusters
    n_threads = opt.n_threads
    n_iter = opt.n_iter
    backend = opt.backend

    # Build the estimator first so an unsupported backend fails before any
    # dataset download is attempted.
    kmeans = build_estimator(backend, n_clusters, n_threads, n_iter, random_state)

    label_representation = load_dataset(dataset_name)

    print("------------------Running clustering------------------")
    print("Dataset: ", dataset_name, flush=True)
    print("#Cluster: ", n_clusters, flush=True)
    print("#Threads: ", n_threads, flush=True)
    print("Backend: ", backend, flush=True)

    # Only fit() is timed, identically for every backend; perf_counter is
    # monotonic and therefore safe for measuring intervals.
    start = time.perf_counter()
    kmeans.fit(label_representation)
    end = time.perf_counter()

    print("Total clustering runtime: ", end - start, flush=True)


if __name__ == "__main__":
    main()