import argparse
import bz2
import os

from datasets import load_dataset

from sklearn.feature_extraction.text import TfidfVectorizer 

from ocp import config_list, config2task, split_list, split2name


def get_args():
    """Parse the script's command-line options.

    Returns:
        argparse.Namespace: has a ``data_path`` attribute (``-dp`` /
        ``--data_path``, default ``'lexglue_data'``) naming the output root.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('-dp', '--data_path', type=str, default='lexglue_data')
    return cli.parse_args()


def get_texts(config, dataset):
    """Return the dataset's texts with all whitespace runs collapsed to one space.

    Args:
        config: LexGLUE config name. Any config containing ``'ecthr'`` is
            expected to store each item as a sequence of strings, which are
            space-joined into a single document first.
        dataset: object indexable by ``'text'``.

    Returns:
        list[str]: one normalised string per instance.
    """
    documents = dataset['text']
    if 'ecthr' in config:
        # Each ECtHR item is a sequence of strings (presumably paragraphs);
        # flatten it into one document before normalising.
        documents = [' '.join(parts) for parts in documents]
    return [' '.join(doc.split()) for doc in documents]


def get_labels(config, dataset, task):
    """Return each instance's label(s) as a string.

    Args:
        config: LexGLUE config name (accepted for interface symmetry; unused).
        dataset: object indexable by ``'label'`` (multi-class) or
            ``'labels'`` (otherwise).
        task: ``'multi_class'`` selects the single ``'label'`` column; any
            other value treats ``'labels'`` as a per-instance sequence.

    Returns:
        list[str]: for multi-class, ``str(label)``; otherwise the labels
        joined by single spaces (an empty sequence yields ``''``).
    """
    if task != 'multi_class':
        return [' '.join(str(tag) for tag in tags) for tags in dataset['labels']]
    return [str(label) for label in dataset['label']]


def _load_split(config, split):
    """Load one LexGLUE split and return its normalised texts and labels."""
    dataset = load_dataset('lex_glue', config, split=split)
    texts = get_texts(config, dataset)
    labels = get_labels(config, dataset, config2task[config])
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if len(texts) != len(labels):
        raise ValueError(
            f'{config}/{split}: got {len(texts)} texts but {len(labels)} labels')
    return {'text': texts, 'labels': labels}


def _write_raw_texts(config, config_path, split, data):
    """Write one tab-separated ``label text`` line per instance to a bz2 file."""
    split_file = f'{config}_lexglue_raw_texts_{split2name[split]}.txt.bz2'
    split_path = os.path.join(config_path, split_file)
    with bz2.open(split_path, 'wb') as f:
        for text, label in zip(data['text'], data['labels']):
            f.write(f'{label}\t{text}\n'.encode('utf-8'))


def _write_tfidf(config, config_path, processed_data):
    """Fit TF-IDF on train+validation and write LIBSVM-style train/test files."""
    # The TF-IDF "train" file is train + validation. Build a separate dict so
    # ``processed_data`` is not mutated.
    tfidf_data = {
        'train': {
            'text': processed_data['train']['text'] + processed_data['validation']['text'],
            'labels': processed_data['train']['labels'] + processed_data['validation']['labels'],
        },
        'test': processed_data['test'],
    }

    vectorizer = TfidfVectorizer()
    vectorizer.fit(tfidf_data['train']['text'])

    for split, data in tfidf_data.items():
        split_file = f'{config}_lexglue_tfidf_{split2name[split]}.svm.bz2'
        split_path = os.path.join(config_path, split_file)
        # LIL format exposes, per row, the non-zero values (``data``) and
        # their column indices (``rows``) as parallel lists.
        vectors = vectorizer.transform(data['text']).tolil()
        with bz2.open(split_path, 'wb') as f:
            for labels, values, columns in zip(data['labels'], vectors.data, vectors.rows):
                labels_str = labels.replace(' ', ',')
                # Feature indices are written 1-based.
                feature_str = ' '.join(f'{col + 1}:{val}' for val, col in zip(values, columns))
                f.write(f'{labels_str} {feature_str}\n'.encode('utf-8'))


def main():
    """Write raw-text and TF-IDF files for every LexGLUE config.

    For each config in ``config_list`` a sub-directory of ``--data_path`` is
    created, containing:

    * ``{config}_lexglue_raw_texts_{name}.txt.bz2`` for every split in
      ``split_list``: one tab-separated ``labels text`` line per instance,
      with multiple labels space-separated;
    * ``{config}_lexglue_tfidf_{name}.svm.bz2`` for ``train`` and ``test``:
      lines of comma-separated labels followed by ``index:value`` pairs with
      1-based indices. The vectorizer is fitted on, and the "train" file
      contains, train + validation instances.

    ``name`` is ``split2name[split]``.

    Raises:
        ValueError: if a split yields different numbers of texts and labels.
    """
    args = get_args()
    os.makedirs(args.data_path, exist_ok=True)

    for config in config_list:
        config_path = os.path.join(args.data_path, config)
        os.makedirs(config_path, exist_ok=True)

        # Load every split first, then emit outputs (same order as before).
        processed_data = {split: _load_split(config, split) for split in split_list}

        for split, data in processed_data.items():
            _write_raw_texts(config, config_path, split, data)

        _write_tfidf(config, config_path, processed_data)


if __name__ == '__main__':
    # Run the preprocessing only when executed as a script, not on import.
    main()
