#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import re
from bs4 import BeautifulSoup
import pandas as pd
from tqdm import tqdm
import pickle

import lucene
from org.apache.lucene.analysis.tokenattributes import CharTermAttribute as CharAttr
from org.apache.lucene.analysis import StopFilter, CharArraySet
from org.apache.lucene.analysis.en import PorterStemFilter
from org.apache.lucene.analysis.core import LetterTokenizer, LowerCaseFilter
from org.apache.pylucene.analysis import PythonAnalyzer, PythonFilteringTokenFilter, PythonTokenFilter
from lucene import JArray


# Input artifacts (paths relative to the working directory).
# Qrels file, one "label doc_id <ignored>" triple per line (read by load_qrels).
LABELS_FILE = 'original/id2class_eurlex_eurovoc.qrels'
# Tab-separated table; main() reads its DocID, Filename and remove columns.
ID_MAPPINGS_FILE = 'original/eurlex_ID_mappings.csv'
# Directory containing the per-document HTML files.
HTML_DIR = 'original/htmls/'
# Stopword list, one word per line (read by PorterStemmerAnalyzer).
STOPWORD_FILE = 'english.stop'

# Train/test split of the dataset.
# Pickled sequence of indices used to reorder the extracted documents; the
# first TRAIN_SIZE reordered documents become the train split, the rest the
# test split.
PERMUTATION_FILE = 'perm.pkl'
TRAIN_SIZE = 15449

# Output settings.
# Output directory (created if missing) and per-split file template: '{}' is
# filled with the split name ('train' or 'test'). Each output line holds the
# space-joined labels, a tab, then the preprocessed text.
OUTPUT_PREFIX = './data/'
OUTPUT_FILE = OUTPUT_PREFIX + 'eurlex_raw_texts_{}.txt'


def load_qrels(filename):
    """Load a qrels file into a mapping from document id to its labels.

    Each non-blank line must contain exactly three whitespace-separated
    fields, ``label doc_id <ignored>``. Labels of the same document are
    collected in file order. Blank lines (e.g. a trailing empty line) are
    skipped.

    Args:
        filename: Path to the qrels file.

    Returns:
        dict[int, list[str]]: document id -> list of labels.

    Raises:
        ValueError: if a non-blank line does not have exactly three fields
            or its second field is not an integer.
    """
    labels = {}
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate blank lines instead of failing the 3-way unpack.
                continue
            label, doc_id, _ = fields
            labels.setdefault(int(doc_id), []).append(label)
    return labels

def parse_html(filename):
    """Return the text of the first ``div`` with class ``texte`` in an HTML file.

    Args:
        filename: Path to a UTF-8 encoded HTML file.

    Returns:
        str: the text content of that ``div``.

    Raises:
        ValueError: if the file contains no ``div`` with class ``texte``.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        soup = BeautifulSoup(f, 'html.parser')
    body = soup.find('div', class_="texte")
    if body is None:
        # Report the offending file instead of an opaque
        # "'NoneType' object has no attribute 'text'".
        raise ValueError(f'no <div class="texte"> found in {filename}')
    return body.text


class NoEnglishFilter(PythonFilteringTokenFilter):
    """Filtering token filter intended to drop tokens with no English letters.

    ``accept`` rejects a token when, starting at its first character, a run
    of one or more characters outside ``a-zA-Z`` is immediately followed by
    a regex word boundary.

    NOTE(review): for a token made only of word characters, the only such
    boundary is the end of the token. In that case the rule amounts to
    "reject tokens that contain no ASCII letter at all". Tokens that contain
    non-word characters may be treated differently; confirm against the
    tokenizer feeding this filter.
    """

    def __init__(self, ts):
        super().__init__(ts)

    def accept(self):
        term_text = self.getAttribute(CharAttr.class_).toString()
        leading_non_english = re.match(r'[^a-zA-Z]+\b', term_text)
        # Keep the token only when no such leading run was found.
        return not leading_non_english


class FixSigmaFilter(PythonTokenFilter):
    """Token filter that rewrites a token-final 'σ' as the final form 'ς'.

    Every token of two or more characters ending in 'σ' gets that last
    character replaced by 'ς'. No check is made that the token is Greek.

    Two cases are left unchanged:
    - a trailing 'σ' preceded by 'd', so notation such as the 'dσ' in
      ``dH/dσ`` is not altered;
    - a one-character token 'σ', which is treated as a symbol.
    """

    def __init__(self, ts):
        super().__init__(ts)
        # Upstream token stream whose tokens this filter rewrites.
        self.ts = ts

    def incrementToken(self):
        if not self.ts.incrementToken():
            return False

        term = self.ts.getAttribute(CharAttr.class_)
        text = term.toString()
        # The length check guards text[-2] for one-character (and empty) tokens.
        if len(text) >= 2 and text[-1] == 'σ' and text[-2] != 'd':
            chars = JArray('char')(text[:-1] + 'ς')
            # Use the Java array's own length so the copied range always
            # matches the buffer actually handed to copyBuffer.
            term.copyBuffer(chars, 0, len(chars))

        return True
        
    
class PorterStemmerAnalyzer(PythonAnalyzer):
    """Lucene analyzer used by this script to normalise document text.

    The token chain built by ``createComponents`` is:
    LetterTokenizer -> LowerCaseFilter -> StopFilter (words from
    ``STOPWORD_FILE``) -> PorterStemFilter -> NoEnglishFilter ->
    FixSigmaFilter.
    """

    def __init__(self):
        super().__init__()

        # One stopword per line. strip() replaces the original approach of
        # chopping the last character, for three reasons:
        # - the final word stays intact when the file has no trailing newline;
        # - '\r' is removed from CRLF files;
        # - blank lines can be skipped.
        with open(STOPWORD_FILE, 'r', encoding='utf-8') as f:
            stopword_list = [line.strip() for line in f if line.strip()]

        # The second argument is Lucene's ignoreCase flag.
        self.stopwords = CharArraySet(len(stopword_list), True)
        for word in stopword_list:
            self.stopwords.add(word)

    def createComponents(self, fieldName):
        """Build the tokenizer/filter chain (``fieldName`` is not used)."""
        source = LetterTokenizer()
        stream = LowerCaseFilter(source)
        stream = StopFilter(stream, self.stopwords)
        stream = PorterStemFilter(stream)
        stream = NoEnglishFilter(stream)
        stream = FixSigmaFilter(stream)
        return self.TokenStreamComponents(source, stream)

    def initReader(self, fieldName, reader):
        """Return ``reader`` unchanged (no character filtering)."""
        return reader

def replace_all(str, repls):
    """Replace every occurrence of each key of ``repls`` in ``str`` by its value.

    All replacements are done in one left-to-right pass, so inserted text is
    never rewritten again. When keys overlap, the longest key matching at a
    position wins, e.g. ``{"a": "1", "ab": "2"}`` turns "ab" into "2".

    Args:
        str: Input text. The name shadows the builtin; it is kept for
            compatibility with keyword callers.
        repls: Mapping from substring to its replacement.

    Returns:
        The rewritten text, or ``str`` unchanged when ``repls`` is empty.
    """
    if not repls:
        # An empty alternation would match the empty string everywhere and
        # the lookup repls[''] would raise KeyError.
        return str
    # Longest keys first: regex alternation takes the first alternative that
    # matches, so a shorter prefix listed earlier would hide a longer key.
    keys = sorted(repls, key=len, reverse=True)
    pattern = '|'.join(re.escape(key) for key in keys)
    return re.sub(pattern, lambda m: repls[m.group(0)], str)

def preprocess_text(text, analyzer):
    """Escape markup characters in ``text`` and run it through ``analyzer``.

    First, ``>``, ``<``, ``"`` and ``&`` are replaced by the words ``gt``,
    ``lt``, ``quot`` and ``amp``, each padded with spaces. The text is then
    tokenized under the dummy field name "dummy". The resulting terms are
    returned joined by single spaces.

    Args:
        text: Raw document text.
        analyzer: Lucene analyzer, e.g. ``PorterStemmerAnalyzer``.

    Returns:
        str: space-separated terms ('' when no terms are produced).
    """
    text = replace_all(text, {
        ">": " gt ",
        "<": " lt ",
        '"': " quot ",
        '&': " amp "
    })
    # Process the text using Lucene.
    ts = analyzer.tokenStream("dummy", text)
    try:
        ts.reset()
        # The attribute instance is shared across tokens; fetch it once.
        term_attr = ts.getAttribute(CharAttr.class_)
        terms = []
        while ts.incrementToken():
            terms.append(term_attr.toString())
        ts.end()
    finally:
        # Always close. A stream left open violates the TokenStream
        # contract, and the analyzer's next tokenStream() call would fail.
        ts.close()
    return ' '.join(terms)


def main():
    """Extract, preprocess and split the documents, then write the output files.

    Steps:
    1. Start the JVM.
    2. Read the id/filename mapping and the qrels labels.
    3. Preprocess every non-removed, labelled HTML file in ``HTML_DIR``,
       visiting files in sorted filename order.
    4. Reorder the samples with the pickled permutation.
    5. Write the train/test splits to ``OUTPUT_FILE``.
    """
    lucene.initVM(vmargs=['-Djava.awt.headless=true'])

    df = pd.read_csv(ID_MAPPINGS_FILE, sep='\t')
    id_filename = pd.Series(df.Filename.values, index=df.DocID).to_dict()
    filename_removed = pd.Series(df.remove.values, index=df.Filename).to_dict()
    id_labels = load_qrels(LABELS_FILE)
    # Ids without a filename in the mapping cannot be located on disk. They
    # are skipped instead of aborting with a bare KeyError.
    file_to_labels = {
        id_filename[doc_id]: labels
        for doc_id, labels in id_labels.items()
        if doc_id in id_filename
    }
    # Files absent from the mapping are treated as removed. They can never
    # appear in file_to_labels, so they would be skipped below anyway.
    files = [
        name for name in sorted(os.listdir(HTML_DIR))
        if not filename_removed.get(name, True)
    ]

    analyzer = PorterStemmerAnalyzer()

    data = []
    print("Extracting and preprocessing text from source files...")
    for filename in tqdm(files):
        labels = file_to_labels.get(filename)
        if labels is None:
            continue
        text = parse_html(os.path.join(HTML_DIR, filename))
        data.append((preprocess_text(text, analyzer), labels))

    # NOTE: pickle.load can execute arbitrary code, so PERMUTATION_FILE must
    # be a trusted local artifact.
    # NOTE(review): the indices refer to positions in `data` as built above.
    # Any change in which files are kept shifts them; confirm the pickle
    # matches this data.
    with open(PERMUTATION_FILE, 'rb') as f:
        permutation = pickle.load(f)
    data = [data[i] for i in permutation]

    splits = {'train': data[:TRAIN_SIZE], 'test': data[TRAIN_SIZE:]}

    os.makedirs(OUTPUT_PREFIX, exist_ok=True)

    print("Outputting data...")
    for split, samples in splits.items():
        # `with` guarantees the file is flushed and closed even if a write fails.
        with open(OUTPUT_FILE.format(split), 'w', encoding='utf-8') as out:
            for text, labels in samples:
                out.write(" ".join(labels) + "\t" + text + '\n')


if __name__ == "__main__":
    main()

