import argparse
import os
import tarfile
import urllib
import urllib.request

import h5py
import numpy as np
import pandas as pd
import tensorflow as tf
from tqdm import tqdm

def read_h5(h5_path):
    """Load one TCIR h5 file.

    Reads the "info" metadata table (via pandas) and the "matrix" image
    array (via h5py) from the same file, then cleans the matrix in place:
    NaNs become 0 and values above 1000 are zeroed out.

    Returns:
        (matrix, info): the cleaned numpy array and the metadata DataFrame.
    """
    info = pd.read_hdf(h5_path, key="info", mode='r')
    with h5py.File(h5_path, 'r') as handle:
        matrix = handle['matrix'][:]

    if np.isnan(matrix).any():
        print("contains nan, processing...")
        # In-place replacement; avoids a second full-size allocation.
        np.nan_to_num(matrix, copy=False)

    # Values above 1000 are treated as sensor artifacts and zeroed.
    outliers = matrix > 1000
    if outliers.any():
        print("contains > 1000, processing...")
        matrix[outliers] = 0

    return matrix, info

def split_train_valid(data_matrix, data_info, split_id="2015000"):
    """Split samples into train/valid sets by storm ID.

    Rows whose ``ID`` compares below ``split_id`` go to the training set;
    rows comparing at or above it go to validation.  Both info frames are
    returned with a fresh 0..n-1 index so they stay row-aligned with their
    matrices for the downstream group-by step.

    NOTE(review): the comparison assumes ``data_info.ID`` holds strings,
    so the split is lexicographic -- confirm against the h5 schema.
    Also assumes ``data_info`` carries a default RangeIndex so its labels
    can index ``data_matrix`` directly (guaranteed by the caller).
    """
    data_info_train = data_info[data_info.ID < split_id]
    data_matrix_train = data_matrix[data_info_train.index]
    # reset_index(drop=True) replaces assigning .index on a slice, which
    # triggers pandas' chained-assignment (SettingWithCopyWarning) path.
    data_info_train = data_info_train.reset_index(drop=True)

    data_info_valid = data_info[data_info.ID >= split_id]
    data_matrix_valid = data_matrix[data_info_valid.index]
    data_info_valid = data_info_valid.reset_index(drop=True)

    return data_matrix_train, data_info_train, data_matrix_valid, data_info_valid

def normalize(x, mu_std=None, nchannels=4):
    """Standardize each channel of ``x`` to zero mean / unit std, in place.

    Args:
        x: 4-D array whose last axis holds ``nchannels`` channels.
        mu_std: optional ``(mu, std)`` pair of per-channel statistics; when
            omitted, statistics are computed from ``x`` itself.
        nchannels: number of channels to normalize.

    Returns:
        ``(x, mu, std)`` when statistics were computed here (training set),
        otherwise just ``x`` (valid/test sets reusing train statistics).
    """
    stats_computed = mu_std is None
    if stats_computed:
        mu = [x[:, :, :, c].mean() for c in range(nchannels)]
        std = [x[:, :, :, c].std() for c in range(nchannels)]
    else:
        mu, std = mu_std

    # Operate on channel views so x is modified in place.
    for c in range(nchannels):
        channel = x[:, :, :, c]
        channel -= mu[c]
        channel /= std[c]

    return (x, mu, std) if stats_computed else x

def group_by_id(data_matrix, data_info):
    """Partition rows into per-storm sequences, preserving first-seen order.

    Assumes ``data_info`` has a default RangeIndex so groupby labels can be
    used positionally (``iloc``) -- the upstream split/merge guarantees this.

    Returns:
        (matrices, infos): parallel lists with one array / DataFrame per ID.
    """
    grouped = data_info.groupby('ID', sort=False).groups

    matrices = []
    infos = []
    for positions in grouped.values():
        matrices.append(data_matrix[positions])
        infos.append(data_info.iloc[positions])

    return matrices, infos

def write_tfrecord(data_matrix, data_info, tfrecord_path):
    """Serialize per-storm sequences to a TFRecord file.

    Each element of ``data_matrix`` / ``data_info`` describes one storm.
    One tf.train.Example is written per storm with features:
      - "seqlen":    int64, number of frames in the sequence
      - "img":       raw bytes of the image stack (dtype/shape implicit;
                     the reader must know them -- TODO confirm reader side)
      - "intensity": raw bytes of the per-frame ``Vmax`` values
    """

    def _bytes_feature(value):
        """Returns a bytes_list from a string / byte."""
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _int64_feature(value):
        """Returns an int64_list from a bool / enum / int / uint."""
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def _encode_tfexample(data_matrix_elem, data_info_elem):
        """Build one tf.train.Example for a single storm sequence."""
        seqlen_feature = _int64_feature(data_info_elem.shape[0])
        intensity_feature = _bytes_feature(np.ndarray.tobytes(data_info_elem['Vmax'].values))
        flat_img_feature = _bytes_feature(np.ndarray.tobytes(data_matrix_elem))
        features = {
            "seqlen": seqlen_feature,
            "img": flat_img_feature,
            "intensity": intensity_feature,
        }
        return tf.train.Example(features=tf.train.Features(feature=features))

    print(f"Saving {tfrecord_path}...")

    assert len(data_matrix) == len(data_info)
    # tf.python_io was removed in TF 2.x; prefer tf.io.TFRecordWriter
    # (available since TF 1.14) and fall back for older TF 1.x.
    try:
        writer_cls = tf.io.TFRecordWriter
    except AttributeError:
        writer_cls = tf.python_io.TFRecordWriter

    with writer_cls(tfrecord_path) as writer:
        # total= gives tqdm a real progress bar (zip alone has no length).
        for data_matrix_elem, data_info_elem in tqdm(zip(data_matrix, data_info), total=len(data_info)):
            example = _encode_tfexample(data_matrix_elem, data_info_elem)
            serialized = example.SerializeToString()
            writer.write(serialized)

def main():
    """Download the TCIR dataset, preprocess it, and write tfrecord splits.

    Pipeline: download/extract archives -> read h5 -> merge the two training
    regions -> split train/valid by storm ID -> normalize with training
    statistics -> group frames by storm -> serialize to TFRecord files.
    """
    parser = argparse.ArgumentParser()
    _ = parser.add_argument("--data_dir", default=".", help="directory to save data, defaults to current directory")
    args = parser.parse_args()

    urls = [
        ('TCIR-ATLN_EPAC_WPAC.h5.tar.gz', 'http://140.112.145.151/~boyochen/TCIR/TCIR-ATLN_EPAC_WPAC.h5.tar.gz'),
        ('TCIR-CPAC_IO_SH.h5.tar.gz', 'http://140.112.145.151/~boyochen/TCIR/TCIR-CPAC_IO_SH.h5.tar.gz'),
        ('TCIR-ALL_2017.h5.tar.gz', 'http://140.112.145.151/~boyochen/TCIR/TCIR-ALL_2017.h5.tar.gz')
    ]

    for fname, url in urls:
        fpath = os.path.join(args.data_dir, fname)
        if not os.path.isfile(fpath):
            print(f"Downloading {fname}...")
            # Requires `import urllib.request` at the top of the file:
            # `import urllib` alone does not load the request submodule.
            urllib.request.urlretrieve(url, fpath)

            # SECURITY: extractall() on an archive fetched over plain http
            # can write outside data_dir via crafted member paths -- flagged
            # for review rather than silently changed.
            with tarfile.open(fpath) as f:
                f.extractall(path=args.data_dir)
        else:
            # NOTE(review): if the archive exists but extraction previously
            # failed, the .h5 files may still be missing; delete the archive
            # to force a re-download.
            print(f"{fname} already exists...")

    # read
    print("Reading h5...")
    data_matrix_1, data_info_1 = read_h5(os.path.join(args.data_dir, 'TCIR-ATLN_EPAC_WPAC.h5'))
    data_matrix_2, data_info_2 = read_h5(os.path.join(args.data_dir, 'TCIR-CPAC_IO_SH.h5'))
    data_matrix_test, data_info_test = read_h5(os.path.join(args.data_dir, 'TCIR-ALL_2017.h5'))

    # merge the two training-region files into one dataset
    print("Merging...")
    data_matrix = np.concatenate((data_matrix_1, data_matrix_2), axis=0)
    data_info = pd.concat([data_info_1, data_info_2])
    # Fresh RangeIndex so info labels index the matrix positionally.
    data_info.index = range(data_info.shape[0])
    # Release the large intermediates before the next big allocations.
    del data_info_1, data_matrix_1, data_info_2, data_matrix_2

    # split
    print("Splitting train and valid...")
    data_matrix_train, data_info_train, data_matrix_valid, data_info_valid = split_train_valid(data_matrix, data_info)
    del data_matrix, data_info

    # normalize valid/test with statistics computed on train only
    print("Normalizing...")
    data_matrix_train, train_mean, train_std = normalize(data_matrix_train)
    data_matrix_valid = normalize(data_matrix_valid, (train_mean, train_std))
    data_matrix_test = normalize(data_matrix_test, (train_mean, train_std))

    # group
    print("Grouping by ID...")
    data_matrix_train, data_info_train = group_by_id(data_matrix_train, data_info_train)
    data_matrix_valid, data_info_valid = group_by_id(data_matrix_valid, data_info_valid)
    data_matrix_test, data_info_test = group_by_id(data_matrix_test, data_info_test)

    # write
    print("Write to tfrecords...")
    write_tfrecord(data_matrix_train, data_info_train, os.path.join(args.data_dir, 'TCIR-ALL.tfrecord.train'))
    write_tfrecord(data_matrix_valid, data_info_valid, os.path.join(args.data_dir, 'TCIR-ALL.tfrecord.valid'))
    write_tfrecord(data_matrix_test, data_info_test, os.path.join(args.data_dir, 'TCIR-ALL.tfrecord.test'))
