#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import pickle
from sklearn.feature_extraction.text import TfidfVectorizer,TfidfTransformer
from sklearn.feature_extraction.text import TfidfVectorizer, TfidfTransformer, _document_frequency
from sklearn.utils.validation import check_array, FLOAT_DTYPES
from sklearn.utils.fixes import _astype_copy_false
import scipy.sparse as sp
import numpy as np
from tqdm import tqdm

# inputs
# ARFF file holding the document texts; parsed by load_arff()
STRING_FILE = 'original/eurlex_tokenstring.arff'
# qrels file assigning string labels to document ids; parsed by load_qrels()
LABELS_FILE = 'original/id2class_eurlex_eurovoc.qrels'
# pickled sequence of indices into the labelled documents; fixes their order
PERMUTATION_FILE = './perm.pkl'
# one label per line; the line index becomes the label id (see load_label_map)
LABEL_MAP_FILE = 'eurovocs.txt'
# the first TRAIN_SIZE documents of the permuted order form the train split
TRAIN_SIZE = 15449

# outputs
# directory the generated files are written to (created by main() if missing)
OUTPUT_PREFIX = './data/'
# output path template; '{}' is replaced by the partition name (train / test)
BOW_OUTPUT = OUTPUT_PREFIX + 'eurlex_tfidf_{}.svm'


def load_arff(filename):
    """Load ``(id, text)`` rows from the token-string ARFF file.

    This is not a general ARFF parser: every data row is expected to look
    like ``<integer id>,"<text>"...`` and only the text between the first
    pair of double quotes after the first comma is kept.

    Parameters
    ----------
    filename : str
        Path to the ARFF file; it is read as UTF-8.

    Returns
    -------
    list of (int, str)
        One ``(id, text)`` tuple per data row, in file order.

    Raises
    ------
    ValueError
        If a data row has no comma or its id is not an integer.
    IndexError
        If a data row contains no double-quoted text.
    """
    rows = []
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            # Blank lines, ARFF declarations ('@...') and ARFF comments
            # ('%...') carry no document, so they are skipped.
            if not stripped or stripped.startswith(('@', '%')):
                continue
            num, rest = line.split(',', 1)
            rows.append((int(num), rest.split('"')[1]))
    return rows

def load_label_map(filename):
    """Map every string label to its line index in the label file.

    Parameters
    ----------
    filename : str
        Path to a text file with one label per line.

    Returns
    -------
    dict of str to str
        ``{label: str(index)}`` where ``index`` is the zero-based line number
        of the label. If a label occurs on several lines the last one wins.
    """
    with open(filename, 'r') as f:
        # Strip only the line terminator: slicing off the last character
        # would truncate the final label when the file lacks a trailing
        # newline.
        return {line.rstrip('\n'): str(i) for i, line in enumerate(f)}

def load_qrels(filename):
    """Read the qrels file and group string labels by document id.

    Every non-blank line must hold exactly three whitespace-separated fields,
    ``<label> <document id> <ignored>``; the third field is discarded.

    Parameters
    ----------
    filename : str
        Path to the qrels file.

    Returns
    -------
    dict of int to list of str
        For each document id, its labels in the order they appear in the file.

    Raises
    ------
    ValueError
        If a non-blank line does not have exactly three fields or the
        document id is not an integer.
    """
    labels = {}
    with open(filename, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank (e.g. trailing) lines carry no label assignment.
                continue
            label, doc_id, _ = fields
            labels.setdefault(int(doc_id), []).append(label)
    return labels

def build_vocabulary(texts):
    """Assign consecutive integer ids to the whitespace-separated tokens.

    Documents are visited in order. Within a document the distinct tokens are
    visited in sorted order, and every token not seen before receives the
    next free id (0, 1, 2, ...).

    Parameters
    ----------
    texts : iterable of str
        The documents to collect tokens from.

    Returns
    -------
    dict of str to int
        Mapping from token to its id.
    """
    token_ids = {}
    for document in texts:
        for token in sorted(set(document.split())):
            # setdefault leaves tokens that already have an id untouched.
            token_ids.setdefault(token, len(token_ids))
    return token_ids


def fit(self, X, y=None):
    """Learn the idf vector (global term weights) as ``log(n_samples / df)``.

    Written to be assigned to ``TfidfTransformer.fit``. When ``self.use_idf``
    is true, the per-feature idf values are stored on ``self._idf_diag`` as a
    diagonal CSR matrix; otherwise nothing is stored.

    Parameters
    ----------
    X : sparse matrix of shape (n_samples, n_features)
        A matrix of term/token counts. Dense input is converted to CSR.
    y : object, optional
        Unused.

    Returns
    -------
    self
        The transformer itself.
    """
    X = check_array(X, accept_sparse=('csr', 'csc'))
    if not sp.issparse(X):
        X = sp.csr_matrix(X)

    # Keep a floating input dtype, otherwise compute in float64.
    if X.dtype in FLOAT_DTYPES:
        dtype = X.dtype
    else:
        dtype = np.float64

    if not self.use_idf:
        return self

    n_samples, n_features = X.shape
    doc_freq = _document_frequency(X)
    doc_freq = doc_freq.astype(dtype, **_astype_copy_false(doc_freq))

    # NOTE(review): no smoothing is applied, so a feature whose document
    # frequency is 0 gets an infinite idf here.
    idf = np.log(n_samples / doc_freq)
    self._idf_diag = sp.diags(
        idf,
        offsets=0,
        shape=(n_features, n_features),
        format='csr',
        dtype=dtype,
    )
    return self


def main():
    """Generate the tf-idf train and test files from the raw data files.

    Steps:

    1. Load the texts, the label map, the pickled permutation and the
       per-document labels.
    2. Drop documents without labels, reorder the rest by the permutation and
       split them into the first ``TRAIN_SIZE`` documents (train) and the
       remainder (test).
    3. Fit a ``TfidfVectorizer`` on the train texts, using the vocabulary from
       ``build_vocabulary`` and the module-level ``fit`` as idf formula.
    4. Write one file per partition to ``BOW_OUTPUT``; every line has the form
       ``<label ids joined by ','> <feature id>:<value> ...``.

    Raises
    ------
    KeyError
        If a document carries a label that is missing from the label map.
    """
    data = load_arff(STRING_FILE)
    label_map = load_label_map(LABEL_MAP_FILE)

    # NOTE: unpickling can execute arbitrary code; only use a permutation
    # file that comes from a trusted source.
    with open(PERMUTATION_FILE, 'rb') as f:
        permutation = pickle.load(f)

    id_labels = load_qrels(LABELS_FILE)
    # Keep only documents that have labels; the permutation indexes this
    # filtered list.
    data = [row for row in data if row[0] in id_labels]
    texts = [data[i][1] for i in permutation]
    labels = [id_labels[data[i][0]] for i in permutation]

    text_splits = {'train': texts[:TRAIN_SIZE], 'test': texts[TRAIN_SIZE:]}
    label_splits = {'train': labels[:TRAIN_SIZE], 'test': labels[TRAIN_SIZE:]}

    vocab = build_vocabulary(text_splits['train'])

    # Monkey patch to use a different idf formula. The assignment replaces
    # the method on the class itself, not just on one instance.
    TfidfTransformer.fit = fit
    vectorizer = TfidfVectorizer(token_pattern=r'(?u)\b\w+\b', vocabulary=vocab)
    vectorizer.fit(text_splits['train'])

    os.makedirs(OUTPUT_PREFIX, exist_ok=True)

    for partition in ['train', 'test']:
        bow_file = BOW_OUTPUT.format(partition)
        print(f"Generating {bow_file}...")

        # LIL format exposes, per row, the list of stored values (.data) and
        # the list of their column indices (.rows).
        vectors = vectorizer.transform(text_splits[partition]).tolil()
        rows = zip(label_splits[partition], vectors.data, vectors.rows)

        # The context manager closes the file even if writing fails midway.
        with open(bow_file, 'w') as bow_output:
            for doc_labels, values, columns in tqdm(rows, total=vectors.shape[0]):
                assert len(values) == len(columns)
                # Indexing (not .get) so an unknown label fails with its name.
                labels_str = ",".join(label_map[label] for label in doc_labels)

                # Matrix column j is written as feature id j + 2, so feature
                # ids start at 2 and id 1 is never used.
                feature_str = ' '.join(
                    f'{col + 2}:{val}' for val, col in zip(values, columns)
                )
                bow_output.write(f'{labels_str} {feature_str}\n')


# Run the conversion only when executed as a script, not when imported.
if __name__ == '__main__':
    main()