import xml.etree.ElementTree as ET
from bs4 import BeautifulSoup
from pathlib import Path
from os.path import isfile

# Directory holding one HTML document per article; each file is looked up
# by the text of the article's <hash> element.
DOCUMENT_DIR = Path("data/documents/")
# XML file whose root element's children are iterated as articles.
DATA_PATH = "data/tag-data.xml"
# Text file with one tag per line; the order of first occurrence defines
# each tag's integer index.
TAG_FILE = "tags.txt"
# Text file read one line at a time alongside the articles: a line equal to
# "0" routes the article to the train output, anything else to the test output.
PARTITION_FILE = "partition.txt"

# Output path template; it is formatted with 'train' or 'test'.
OUTPUT = 'data/wiki10_31k_raw_texts_{}.txt'

# Warnings (missing documents, articles without labels) are appended here.
LOG_FILE = "preprocess.log"


def get_tags_of_article(article):
    """Return the lower-cased tag names attached to an article element.

    The function looks up the ``tags`` child of *article* and, for each of
    its children, reads the text of the ``name`` sub-element. The text is
    lower-cased and split on commas, so a single ``name`` may contribute
    several labels.

    Args:
        article: An ElementTree element describing one article.

    Returns:
        A list of label strings in document order. The list is empty when
        the article has no ``tags`` child. Tag entries without a ``name``
        element, or whose ``name`` has no text, are skipped instead of
        raising ``AttributeError``.
    """
    tags = article.find('tags')

    # Identity test on purpose: ``find`` returns None when nothing matches,
    # and an Element's truthiness depends on whether it has children.
    if tags is None:
        return []

    labels = []
    for tag in tags:
        name = tag.find('name')
        if name is None or name.text is None:
            # Malformed entry: nothing usable to extract.
            continue
        labels.extend(name.text.lower().split(','))

    return labels

def parse_file(path):
    """Extract the paragraph text of the document stored at *path*.

    The file is parsed as HTML, the single ``div`` with id ``bodyContent``
    is located, and the text of all its ``p`` elements is joined with
    single spaces. Newlines and tabs in the result are replaced by spaces
    so the text fits on one tab-separated output line.

    Raises:
        AssertionError: if the document does not contain exactly one
            ``div`` with id ``bodyContent``.
    """
    with open(path) as handle:
        document = BeautifulSoup(handle, 'html.parser')

    body_divs = document.find_all('div', id='bodyContent')
    assert len(body_divs) == 1
    body = body_divs[0]

    joined = " ".join(p.get_text() for p in body.find_all(['p']))

    # Flatten line breaks and tabs to plain spaces.
    return joined.replace("\n", " ").replace("\t", " ")

def get_tag_map(tag_file=None):
    """Build the map from tag text to its integer index.

    Each line of the tag file is one tag. Tags are numbered from 0 in order
    of first appearance; repeated tags keep the index of their first
    occurrence and do not consume a new index.

    Args:
        tag_file: Path of the tag list to read. Defaults to the module-level
            ``TAG_FILE`` when omitted.

    Returns:
        A dict mapping each distinct tag string to its index.
    """
    if tag_file is None:
        tag_file = TAG_FILE

    tag_map_31k = {}
    with open(tag_file) as f:
        for line in f:
            # Strip only the line terminator. Slicing off the last character
            # unconditionally would truncate the final tag when the file
            # does not end with a newline.
            tag = line.rstrip("\n")
            if tag not in tag_map_31k:
                # The next free index always equals the current map size.
                tag_map_31k[tag] = len(tag_map_31k)

    return tag_map_31k

def main():
    """Write one "labels<TAB>text" line per article to the train/test files.

    For every article in ``DATA_PATH`` the function:

    * collects its tag names and maps them to indices via the tag map,
      silently dropping tags that are not in the map;
    * extracts the article text from the document named after its hash;
    * appends the sorted, space-separated label indices, a tab, and the
      text to the train or test output according to the partition file.

    Articles without labels are logged but still written. Articles whose
    document file is missing are logged and skipped.
    """
    articles = ET.parse(DATA_PATH).getroot()
    tag_map_31k = get_tag_map()

    # Context managers guarantee that all four files are flushed and closed
    # even if processing an article raises part-way through the loop.
    with open(OUTPUT.format('train'), 'w') as train_output, \
            open(OUTPUT.format('test'), 'w') as test_output, \
            open(PARTITION_FILE, 'r') as partition, \
            open(LOG_FILE, 'a') as log:

        # Strip the terminator before comparing so that a final partition
        # line without a trailing newline is still recognised as "0".
        is_train = partition.readline().rstrip("\n") == "0"

        for i, article in enumerate(articles):
            name = article.find('hash').text
            path = DOCUMENT_DIR / name
            labels_text = get_tags_of_article(article)

            if not labels_text:
                log.write(f"WARNING: article with hash {name} has no labels.\n")

            if not isfile(path):
                log.write(f"WARNING: cannot open file {path}, ignoring it.\n")
                print("\r" + str(i), end="")
                # NOTE(review): no partition line is consumed for a skipped
                # article, so the partition file appears to list only the
                # articles whose document exists. Confirm against the data.
                continue

            # Transform labels from texts to numbers; unknown tags are dropped.
            label_ids = {tag_map_31k[l] for l in labels_text if l in tag_map_31k}
            labels = " ".join(map(str, sorted(label_ids)))

            text = parse_file(path)

            output = train_output if is_train else test_output
            output.write(labels + "\t" + text + "\n")

            is_train = partition.readline().rstrip("\n") == "0"
            print("\r" + str(i), end="")

    print("\nRaw text and label extraction is done.")


    
# Run the extraction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()


