# This script is adapted from https://github.com/iliaschalkidis/lmtc-eurlex57k
import csv
import glob
import json
import logging
import os
import re

import pandas as pd
import tqdm

logging.basicConfig(level=logging.INFO)

# Root directory of the downloaded EURLEX57K dataset (one sub-directory of JSON files per split).
DATA_PATH = "data/datasets/EURLEX57K"
# NOTE: the setup message in the `__main__` block refers to the assignment above
# by line number ("@L15"); keep the two in sync when moving it.

# Directory the converted split files are written to, anchored at the current
# working directory at the time this module is loaded.
OUTPUT_DIR = f"{os.getcwd()}/data/EUR-Lex-57k"
# Created eagerly when the module is loaded (no error if it already exists).
os.makedirs(OUTPUT_DIR, exist_ok=True)


def load_json(path):
    """Load one EURLEX57K JSON document and flatten it to ``(text, labels)``.

    The document is expected to provide the keys ``header``, ``recitals``,
    ``main_body`` (a list of strings), ``attachments`` and ``concepts``.

    Args:
        path: Path of the JSON file to read.

    Returns:
        A tuple ``(text, concepts)``: ``text`` is the header, recitals,
        space-joined main body and attachments concatenated with single
        spaces; ``concepts`` is the document's ``concepts`` value (its labels).

    Raises:
        KeyError: If one of the expected keys is missing.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, which would garble or reject non-ASCII text on some systems.
    with open(path, encoding="utf-8") as file:
        data = json.load(file)

    # concat sentences to string
    main_body = " ".join(data["main_body"])
    text = f'{data["header"]} {data["recitals"]} {main_body} {data["attachments"]}'

    return text, data["concepts"]  # label


if __name__ == "__main__":
    assert os.path.exists(DATA_PATH), """Before running this script, please take the following steps:
    1. Follow the instructions in Chalkidis's README: https://github.com/iliaschalkidis/lmtc-eurlex57k#download-dataset-eurlex57k
    2. Put data directory (i.e., data/datasets/EURLEX57k) to `DATA_PATH` @L15.
    3. Modify `OUTPUT_DIR` to your data directory (e.g., data/EUR-Lex-57k)
    """

    for split in ["train", "dev", "test"]:
        filenames = glob.glob(os.path.join(DATA_PATH, split, "*.json"))

        data = {"label": [], "text": []}
        for filename in tqdm.tqdm(filenames):
            text, label = load_json(filename)
            data["text"].append(text)
            data["label"].append(label)

        split = "valid" if split == "dev" else split
        output_path = os.path.join(OUTPUT_DIR, f"{split}.txt")
        logging.info(f"Generating {output_path} ...")

        df = pd.DataFrame(data=data)

        # preprocess text and label to LibMultiLabel format
        # replace one or more "\s" or "\xa0" (&nbsp) with space
        df["text"] = df["text"].apply(lambda x: re.sub(r"[\s|\xa0]+", " ", x))
        df["text"] = df["text"].str.lower()
        df["label"] = df["label"].str.join(" ")
        df.to_csv(output_path, sep="\t", header=False, quoting=csv.QUOTE_NONE)
