{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Hyperparameter Search for Tree-Based Linear Methods\n<div class=\"alert alert-danger\"><h4>Warning</h4><p>If you are using one-vs-rest linear methods,\n    please see [Hyperparameter Search for One-vs-rest Linear Methods](../auto_examples/plot_linear_gridsearch_tutorial.html).</p></div>\n\nTo apply tree-based linear methods,\nwe first convert raw text into numerical TF-IDF features.\nDuring training, the method builds a label tree and trains linear classifiers for its nodes.\nAt inference, the model traverses the tree from the root and keeps\nonly a few candidate nodes at each level, which speeds up prediction.\n\nBecause model performance depends on these choices, we need to search the hyperparameter space.\nThis guide shows how to tune the hyperparameters of the tree-based linear method.\n\n.. seealso::\n\n    [Implementation Document](https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/libmultilabel_implementation.pdf):\n        For more details about the implementation of tree-based linear methods and hyperparameter search.\n\nHere we show an example of tuning a tree-based linear text classifier on the [rcv1 dataset](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel.html#rcv1v2%20(topics;%20full%20sets)).\nWe start by loading the data:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import logging\n\nfrom libmultilabel import linear\n\nlogging.basicConfig(level=logging.INFO)\n\ndatasets = linear.load_dataset(\"txt\", \"data/rcv1/train.txt\", \"data/rcv1/test.txt\")\n# Number of training instances; used below to compute the candidate values of K.\nL = len(datasets[\"train\"][\"y\"])"
      ]
    },
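    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before setting up the search space, it may help to see what the two feature-extraction hyperparameters do.\nThe sketch below calls scikit-learn's ``TfidfVectorizer`` directly on three made-up sentences (toy data, not rcv1)\nand compares the vocabularies obtained with ``ngram_range=(1, 1)`` and ``ngram_range=(1, 2)``, both with ``stop_words=\"english\"``.\nIt assumes scikit-learn 1.0 or later (for ``get_feature_names_out``) and only illustrates the vectorizer itself,\nnot how LibMultiLabel's ``Preprocessor`` configures it internally.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from sklearn.feature_extraction.text import TfidfVectorizer\n",
        "\n",
        "# Toy documents, made up for illustration only.\n",
        "toy_docs = [\n",
        "    \"the stock market fell sharply\",\n",
        "    \"the central bank raised interest rates\",\n",
        "    \"stock prices rose after the bank report\",\n",
        "]\n",
        "\n",
        "\n",
        "def toy_vocabulary(ngram_range):\n",
        "    # Fit a TF-IDF vectorizer; return its sorted feature names and the sparse feature matrix.\n",
        "    vectorizer = TfidfVectorizer(ngram_range=ngram_range, stop_words=\"english\")\n",
        "    features = vectorizer.fit_transform(toy_docs)\n",
        "    return sorted(vectorizer.get_feature_names_out()), features\n",
        "\n",
        "\n",
        "unigrams, unigram_features = toy_vocabulary((1, 1))\n",
        "uni_bigrams, uni_bigram_features = toy_vocabulary((1, 2))\n",
        "\n",
        "# Stop words such as \"the\" are dropped, and adding bigrams enlarges the vocabulary.\n",
        "print(len(unigrams), unigram_features.shape)\n",
        "print(len(uni_bigrams), uni_bigram_features.shape)\n",
        "print([term for term in uni_bigrams if \" \" in term][:5])"
      ]
    },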
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Next, we set up the search space.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\ndmax = 10  # maximum depth of the label tree\nK_factors = [-2, 5]  # exponents alpha used in the formula for K\nsearch_space_dict = {\n    # Text feature extraction\n    \"ngram_range\": [(1, 1), (1, 2), (1, 3)],\n    \"stop_words\": [\"english\"],\n    # Label tree structure\n    \"dmax\": [dmax],\n    \"K\": [max(2, int(np.round(np.power(L, 1 / dmax) * np.power(2.0, alpha) + 0.5))) for alpha in K_factors],\n    # Linear classifier (LIBLINEAR options)\n    \"s\": [1],\n    \"c\": [0.5, 1, 2],\n    \"B\": [1],\n    # Prediction\n    \"beam_width\": [10],\n    \"prob_A\": [3],\n}"
      ]
    },
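    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A grid like this is the Cartesian product of its value lists, so its size is the product of their lengths.\nThe standalone sketch below (an illustration, not LibMultiLabel's own code) expands a grid with the same value lists,\nrecomputes ``K`` with the same expression, and formats ``s``, ``c`` and ``B`` in LIBLINEAR's command-line option syntax.\nThe value ``toy_L = 23149`` is an illustrative stand-in for ``L``, and the helper names are hypothetical;\nthe exact option string LibMultiLabel builds internally may differ.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import itertools\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def k_candidates(L, dmax, k_factors):\n",
        "    # Same expression as in search_space_dict: max(2, round(L^(1/dmax) * 2^alpha + 0.5)).\n",
        "    return [max(2, int(np.round(np.power(L, 1 / dmax) * np.power(2.0, alpha) + 0.5))) for alpha in k_factors]\n",
        "\n",
        "\n",
        "def expand_grid(grid):\n",
        "    # Yield one dict per combination of values (the Cartesian product of the value lists).\n",
        "    keys = list(grid)\n",
        "    for values in itertools.product(*(grid[key] for key in keys)):\n",
        "        yield dict(zip(keys, values))\n",
        "\n",
        "\n",
        "def liblinear_option_string(config):\n",
        "    # Hypothetical helper: render s, c and B as LIBLINEAR command-line options.\n",
        "    return f\"-s {config['s']} -c {config['c']} -B {config['B']}\"\n",
        "\n",
        "\n",
        "toy_L = 23149  # illustrative stand-in for L\n",
        "toy_grid = {\n",
        "    \"ngram_range\": [(1, 1), (1, 2), (1, 3)],\n",
        "    \"stop_words\": [\"english\"],\n",
        "    \"dmax\": [10],\n",
        "    \"K\": k_candidates(toy_L, 10, [-2, 5]),\n",
        "    \"s\": [1],\n",
        "    \"c\": [0.5, 1, 2],\n",
        "    \"B\": [1],\n",
        "    \"beam_width\": [10],\n",
        "    \"prob_A\": [3],\n",
        "}\n",
        "\n",
        "toy_configs = list(expand_grid(toy_grid))\n",
        "print(toy_grid[\"K\"], len(toy_configs))  # [2, 88] 18\n",
        "print(liblinear_option_string(toy_configs[0]))  # -s 1 -c 0.5 -B 1"
      ]
    },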
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Following the suggestions in the [implementation document](https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/libmultilabel_implementation.pdf),\nwe define 18 configurations (3 values of ``ngram_range`` × 2 values of ``K`` × 3 values of ``c``),\nwhich form a simple yet strong baseline.\n\nThe search space covers four parts of the pipeline:\n\n- Text feature extraction: (``ngram_range``, ``stop_words``)\n\n    - We use ``TfidfVectorizer`` from ``sklearn`` to generate TF-IDF features from raw text.\n\n- Label tree structure: (``dmax``, ``K``)\n\n    - ``dmax`` is the maximum depth of the label tree, and ``K`` is the number of clusters (child nodes) per node. ``K`` is computed with the formula from the [implementation document](https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/libmultilabel_implementation.pdf).\n\n- Linear classifier: (``s``, ``c``, ``B``)\n\n    - These are combined into a LIBLINEAR option string such as ``-s 1 -c 1 -B 1`` (``-s`` solver type, ``-c`` cost parameter, ``-B`` bias term) for training the linear classifiers; see *train Usage* in the [LIBLINEAR](https://github.com/cjlin1/liblinear) README.\n\n- Prediction: (``beam_width``, ``prob_A``)\n\n    - ``beam_width`` is the number of candidates kept at each level of the tree during prediction, and ``prob_A`` is the parameter of the probability estimation function applied at each level.\n\n.. tip::\n\n    Available hyperparameters (and their defaults) are defined in the class variables of :py:class:`~libmultilabel.linear.TreeGridParameter`.\n\n:py:class:`~libmultilabel.linear.TreeGridSearch` evaluates each configuration by cross-validation.\nThe training data are split into ``n_folds`` folds,\nand each fold in turn serves as the validation set while a model is trained on the remaining folds.\nThe validation outputs from all folds are then aggregated to compute the ``monitor_metrics``.\nInitializing :py:class:`~libmultilabel.linear.TreeGridSearch` requires the dataset, the number of cross-validation folds, and the evaluation metrics.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "n_folds = 3\nmonitor_metrics = [\"P@1\", \"P@3\", \"P@5\"]\nsearch = linear.TreeGridSearch(datasets, n_folds, monitor_metrics)\ncv_scores = search(search_space_dict)"
      ]
    },
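    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the cross-validation procedure concrete, here is a self-contained NumPy sketch of out-of-fold evaluation\nwith precision at k (P@k: the fraction of the k highest-scored labels that are relevant, averaged over instances).\nIt is not LibMultiLabel's implementation: the tiny label matrix is made up, and the \"model\" fitted on each split\nis a hypothetical label-frequency baseline standing in for the tree model.\nThe point is the bookkeeping: every instance is scored exactly once by a model that did not see it,\nand each metric is computed once on the aggregated scores.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def precision_at_k(y_true, scores, k):\n",
        "    # For each instance, count relevant labels among the k highest scores, divide by k, then average.\n",
        "    top_k = np.argsort(-scores, axis=1)[:, :k]\n",
        "    hits = np.take_along_axis(y_true, top_k, axis=1).sum(axis=1)\n",
        "    return float(np.mean(hits / k))\n",
        "\n",
        "\n",
        "def out_of_fold_scores(y, num_folds, seed=0):\n",
        "    # Split instance indices into num_folds folds; each fold is the validation set exactly once.\n",
        "    rng = np.random.default_rng(seed)\n",
        "    folds = np.array_split(rng.permutation(y.shape[0]), num_folds)\n",
        "    scores = np.zeros(y.shape, dtype=float)\n",
        "    for i, val_idx in enumerate(folds):\n",
        "        train_idx = np.concatenate([fold for j, fold in enumerate(folds) if j != i])\n",
        "        # Hypothetical stand-in model: score every label by its frequency in the training folds.\n",
        "        label_frequency = y[train_idx].mean(axis=0)\n",
        "        scores[val_idx] = label_frequency\n",
        "    return scores\n",
        "\n",
        "\n",
        "# Made-up label matrix: 6 instances, 3 labels (1 = relevant).\n",
        "toy_y = np.array([[1, 0, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0], [1, 1, 0], [1, 0, 0]])\n",
        "toy_scores = out_of_fold_scores(toy_y, num_folds=3)\n",
        "print({f\"P@{k}\": precision_at_k(toy_y, toy_scores, k) for k in (1, 3)})"
      ]
    },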
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "``cv_scores`` is a dictionary whose keys are :py:class:`~libmultilabel.linear.TreeGridParameter` instances and whose values map each metric in ``monitor_metrics`` to its cross-validation score.\n\nThe following code sorts the results in descending order of the first metric in ``monitor_metrics`` and retrieves the best parameters:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "sorted_cv_scores = sorted(cv_scores.items(), key=lambda x: x[1][monitor_metrics[0]], reverse=True)\nprint(sorted_cv_scores)\n\n# The first entry has the highest score on the first metric.\nbest_params, best_cv_scores = sorted_cv_scores[0]\nprint(best_params, best_cv_scores)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In our run, the best parameters were::\n\n  {'ngram_range': (1, 3), 'stop_words': 'english', 'dmax': 10, 'K': 88, 's': 1, 'c': 1, 'B': 1, 'beam_width': 10, 'prob_A': 3}\n\nwith cross-validation scores::\n\n  {'P@1': 0.9669, 'P@3': 0.8137, 'P@5': 0.5640}\n\nWe then retrain on the full training set with the best parameters\nand use :py:meth:`~libmultilabel.linear.linear_test` and :py:meth:`~libmultilabel.linear.get_metrics` to evaluate on the test set.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from dataclasses import asdict\n\n# Build TF-IDF features with the best feature-extraction settings.\npreprocessor = linear.Preprocessor(tfidf_params=asdict(best_params.tfidf))\ntransformed_dataset = preprocessor.fit_transform(datasets)\n\n# Retrain on the full training set with the best tree and LIBLINEAR settings.\nmodel = linear.train_tree(\n    transformed_dataset[\"train\"][\"y\"],\n    transformed_dataset[\"train\"][\"x\"],\n    best_params.linear_options,\n    **asdict(best_params.tree),\n)\n\n# Predict on the test set with the best prediction settings and compute the metrics.\nmetrics, _, _, _ = linear.linear_test(\n    y=transformed_dataset[\"test\"][\"y\"],\n    x=transformed_dataset[\"test\"][\"x\"],\n    model=model,\n    metrics=linear.get_metrics(monitor_metrics, num_classes=-1),\n    predict_kwargs=asdict(best_params.predict),\n)\n\nprint(metrics.compute())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The test results with the best parameters should look similar to::\n\n  {'P@1': 0.9554, 'P@3': 0.7968, 'P@5': 0.5576}\n\n"
      ]
    }
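    ,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As background for the ``beam_width`` and ``prob_A`` hyperparameters, the sketch below shows a generic beam search over a label tree.\nThe tree, the decision values, and the probability estimate ``p = exp(-A * max(0, 1 - d)^2)`` (with ``A`` playing the role of ``prob_A``)\nare assumptions made for illustration; LibMultiLabel's actual tree construction and probability estimation\nare described in the implementation document and may differ in detail.\nAt each level only the ``beam_width`` internal nodes with the highest cumulative log-probability are expanded,\nso labels under pruned nodes are never scored.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "# Hypothetical label tree: internal nodes map to their children; leaves are labels.\n",
        "TOY_TREE = {\n",
        "    \"root\": [\"n1\", \"n2\"],\n",
        "    \"n1\": [\"a\", \"b\"],\n",
        "    \"n2\": [\"c\", \"d\"],\n",
        "}\n",
        "# Hypothetical decision values of each node's classifier for one test instance.\n",
        "TOY_DECISION = {\"n1\": 1.2, \"n2\": 0.1, \"a\": 0.9, \"b\": -0.5, \"c\": 2.0, \"d\": 0.3}\n",
        "\n",
        "\n",
        "def log_prob(decision, A):\n",
        "    # Assumed estimator p = exp(-A * max(0, 1 - d)^2), returned in log space.\n",
        "    return -A * max(0.0, 1.0 - decision) ** 2\n",
        "\n",
        "\n",
        "def beam_search(tree, decisions, beam_width, A):\n",
        "    beam = [(\"root\", 0.0)]  # (node, cumulative log-probability)\n",
        "    leaves = []\n",
        "    while beam:\n",
        "        candidates = []\n",
        "        for node, score in beam:\n",
        "            for child in tree[node]:\n",
        "                child_score = score + log_prob(decisions[child], A)\n",
        "                if child in tree:\n",
        "                    candidates.append((child, child_score))\n",
        "                else:\n",
        "                    leaves.append((child, child_score))\n",
        "        # Expand only the beam_width most probable internal nodes at the next level.\n",
        "        candidates.sort(key=lambda item: item[1], reverse=True)\n",
        "        beam = candidates[:beam_width]\n",
        "    # Rank the reached labels by their estimated probability.\n",
        "    leaves.sort(key=lambda item: item[1], reverse=True)\n",
        "    return [(label, math.exp(score)) for label, score in leaves]\n",
        "\n",
        "\n",
        "for width in (1, 2):\n",
        "    print(width, [label for label, _ in beam_search(TOY_TREE, TOY_DECISION, width, A=3)])\n",
        "# 1 ['a', 'b']\n",
        "# 2 ['a', 'c', 'd', 'b']"
      ]
    }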
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.11"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}