Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 75 additions & 80 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,24 +1,75 @@
# Monte Carlo Tree Search for Classifier Chain
# Monte Carlo Tree Search for Classifier Chains

This repository contains the source code of the implementation of MCTS for Classifier Chains with several [examples](./examples/) of how to use it with Classifier Chains. We only support models which compute probabilities for now. See my Bachelor Thesis report for a detailed explaination of the method.
[![Tests](https://github.com/rompoggi/MCTS_ClassifierChain/actions/workflows/tests.yml/badge.svg)](https://github.com/rompoggi/MCTS_ClassifierChain/actions)
[![codecov](https://codecov.io/gh/rompoggi/MCTS_ClassifierChain/graph/badge.svg?token=N9FSNH021E)](https://codecov.io/gh/rompoggi/MCTS_ClassifierChain)
[![License: MIT](https://img.shields.io/badge/License-MIT-orange.svg)](https://opensource.org/licenses/MIT)

This repository provides an implementation of Monte Carlo Tree Search (MCTS) for inference in Multi-Label Classifier Chains, a novel approach developed as part of a Bachelor Thesis at Ecole Polytechnique.

Classifier Chains are a popular method for multi-label classification, but they traditionally use a greedy approach for inference, which can lead to suboptimal predictions. This project frames the inference problem as a search problem and uses MCTS to explore the label space more intelligently, leading to significant performance improvements over the greedy baseline and achieving results competitive with state-of-the-art methods.

For a detailed explanation of the method, please see the full **[Bachelor Thesis Report](https://drive.google.com/file/d/1-gmiogobxYQINJDOgnwJ1kZrVSOHIX2b/view?usp=sharing)**.

## Key Features

* **Novel Inference Strategy**: A new application of Monte Carlo Tree Search to improve predictions for Classifier Chains.
* **High Performance**: Outperforms the standard greedy Classifier Chain and achieves results competitive with state-of-the-art methods like Monte Carlo Classifier Chains (MCC).
* **Flexible Policies**: Easily experiment with different MCTS selection and exploration policies, such as UCB and Epsilon-Greedy.
* **Visualization Tools**: Includes tools to visualize the MCTS search tree, providing insight into the decision-making process.

The repository is part of my Bachelor Thesis submitted for the degree of Bachelor in Mathenmatics and Computer Science at Ecole Polytechnique. It consists of an 8 to 10 week long full time research internship following a topic linked to one of our double major. I was under supervision of Professor Jesse READ, from LIX. See his [webpage](https://jmread.github.io/index.html) for more details about his works in research and teaching.
## Quick Start

#### Monte Carlo Tree Search for Multi-Dimensional Learning with Classifier Chains
*Romain Poggi*, Bachelor of Science at Ecole Polytechnique <br>
*Jesse Read*, Computer Science Laboratory of the École polytechnique <br>
*Bachelor Thesis Report*, [https://drive.google.com/file/d/1-gmiogobxYQINJDOgnwJ1kZrVSOHIX2b/view?usp=sharing](https://drive.google.com/file/d/1-gmiogobxYQINJDOgnwJ1kZrVSOHIX2b/view?usp=sharing)
The following example shows how to train a `ClassifierChain` and use MCTS for inference on a synthetic dataset.

```python
from sklearn.datasets import make_multilabel_classification
from sklearn.model_selection import train_test_split
from sklearn.multioutput import ClassifierChain
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import hamming_loss

MCTS for Classifier Chains makes use of the Monte Carlo Tree Search algorithm, a heuristic search algorithm used in decision-making processes. We are see inference as search, where a path is a sequence of labels.
from mcts_inference import MCTS, MCTSConfig, Constraint
from mcts_inference.policy import UCB

# 1. Create a synthetic dataset
X, Y = make_multilabel_classification(n_samples=100, n_features=20, n_classes=5, n_labels=2, random_state=0)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=0)

# 2. Train a standard Classifier Chain
base_classifier = LogisticRegression(solver="liblinear")
chain = ClassifierChain(base_classifier).fit(X_train, Y_train)

# 3. Use MCTS for inference
config = MCTSConfig(
n_classes=Y.shape[1],
selection_policy=UCB(c=2.0),
constraint=Constraint(max_iter=True, n_iter=100)
)
y_pred_mcts = MCTS(X_test, chain, config)

# 4. Compare with greedy inference
y_pred_greedy = chain.predict(X_test)

print(f"Hamming Loss (Greedy): {hamming_loss(Y_test, y_pred_greedy):.4f}")
print(f"Hamming Loss (MCTS): {hamming_loss(Y_test, y_pred_mcts):.4f}")
```

## Installation

It builds onto the original classifier chains which usses a greedy policy and choses the next node based on its likelihood, which may is often not optimal.
To get started, clone the repository and install it in editable mode. This will also install all the required dependencies from `requirements.txt`.

We also try to improve the PCC and MCC methods, which respectively find in a brute force manner the bayesian optimal label combination, while the other samples different paths based on the node's marginal probability. The first method might not always terminate due to the exponential nature of the label space, though it is optimal when it does terminate. The MCC method is the current state-of-the-art for Classifier Chains, which we try to attain in this work.
```bash
git clone https://github.com/rompoggi/MCTS_ClassifierChain.git
cd MCTS_ClassifierChain
pip install -e .
```
You may need to use `pip3` depending on your Python installation.

## Results

Here are the rankings obtained from our tests, which were made in the [data](./data/) directory, precisely in the [evaluation.ipynb](/data/evaluation.ipynb) notebook. For more information on how to reproduce the obtained results, please refer to [data/README.md](/data/README.md).
The MCTS-based approach was benchmarked against several other methods, including standard Classifier Chains (CC), Probabilistic Classifier Chains (PCC), and Monte Carlo Classifier Chains (MCC). The tables below show the average performance rankings across multiple datasets. Our method (`MUCB(2)`) achieves the second-best performance, close to the state-of-the-art, without extensive hyperparameter tuning.

For details on how to reproduce these results, please refer to the notebooks in the [`data/`](./data/) directory.

<div align="center"><strong>Ranking by Exact Match Score</strong></div>

Expand Down Expand Up @@ -46,89 +97,33 @@ Here are the rankings obtained from our tests, which were made in the [data](./d
| Genbase | 9 | 2 | 1 | 4 | 6 | 5 | 3 | 8 | 7 |
| avg. rank | 4.57 | 3.57| ***2*** | ***3.29*** | 7.29 | 5.14 | 4.29 | 6.429 | 8.0 |

Our method therefore achieves 2nd best performance against state-of-the-art methods, all without tuning the hyperparameters. Thus, this repository invites for contributions to be made to further study the method.

## Repository Overview

There are several directories in this repo:

[src/mcts_inference/](src/mcts_inference/): This directory contains the source code for the project. It includes several modules such as `constraints`, `mcts`, `policy`, `utils`, `mcc`, and `pcc`. These modules likely contain the implementation of the Monte Carlo Tree Search (MCTS), Monte Carlo Classifier Chains (MCC), and Probabilistic Classifier Chains (PCC) algorithms, as well as utility functions and constraints used in the project.
* [`src/mcts_inference/`](./src/mcts_inference/): Source code for the MCTS implementation, policies, and related utilities.
* [`examples/`](./examples/): Jupyter notebooks demonstrating how to use the library and comparing it with other methods.
* [`data/`](./data/): Datasets, preprocessing notebooks, and evaluation results.
* [`tests/`](./tests/): Unit tests for the project.

[examples/](examples/): This directory contains notebooks that demonstrate how to use the framework built in the project. The notebooks include [mcts.ipynb](/examples/mcts.ipynb), [mcts_vs_mcc.ipynb](/examples/mcts_vs_mcc.ipynb), and [mcts_vs_others.ipynb](/examples/mcts_vs_others.ipynb).

[data/](data/): This directory contains the datasets used to evaluate the methods implemented in the project. It includes raw datasets in `.csv `or `.arff` formats, preprocessed datasets, and results of evaluations. The preprocessing and evaluation are done using the notebooks [data_preprocessing.ipynb](/data/data_preprocessing.ipynb) and [evaluation.ipynb](/data/evaluation.ipynb) respectively.

[tests/](tests/): This directory contains the test files for the project. These tests can be run using pytest.

## Installation

There are different ways to use the package in its current form. One can either install it locally as a package, or simply import it in Python files from the same directory.

1. Clone the repository to your local machine:

```bash
git clone https://github.com/rompoggi/MCTS_ClassifierChain.git
```

2. Navigate to the project directory:

```bash
cd MCTS_ClassifierChain
```

3. Install the dependencies:

You can directly install the package with the following command, which would automatically install the requirements listed in the [```requirements.txt```](/requirements.txt) file.

```bash
pip install -e .
```
## Testing

Otherwise, one should install the dependencies via the following command:
The project uses `pytest` for testing. You can run the tests from the root directory:

```bash
pip install -r requirements.txt
pytest
```

One can then use the source code via ```import src.mcts_inference.*```, where * can be replaced by the module you want to use.

Of course, you may use `pip3` instead of `pip` depending on your Python installation.

**Note**: While the core data structures and utilities are well-tested, the main inference functions currently have limited test coverage. Contributions to improve this are welcome.

## Data
The project uses datasets located in the data directory. The raw datasets are in the raw_datasets subdirectory and are in `.csv` or `.arff` formats. The [data_preprocessing.ipynb](/data/data_preprocessing.ipynb) notebook is used to preprocess these datasets and store them in the datasets directory. For more detail, refer the [report](https://drive.google.com/file/d/1-gmiogobxYQINJDOgnwJ1kZrVSOHIX2b/view?usp=sharing) or to the [README](/data/README.md) in the directory.

The [evaluation.ipynb](/data/evaluation.ipynb) notebook is used to run evaluations on the preprocessed data. The results of these evaluations are stored as JSON files in the [.results/](/data/.results/) directory.

## Testing
Tests for the project can be run using pytest. The test files are made accessible in the [tests/](tests/) directory. The testing parameters can be found in the [pyproject.toml](pyproject.toml), [setup.cfg](setup.cfg), and [tox.ini](tox.ini) files. Code-cov reports are also used for future maintenance of the project. We invite contributors to be provide tests when proposing new code.

The project uses GitHub Actions for continuous integration. The configuration used can be found in the [/.github/workflows/tests.yml](/.github/workflows/tests.yml) file.
## Contributing

Contributions to this project are more than welcome. The aim is to further study and improve the method used in this project. Please feel free to open an issue or submit a pull request.

## Contact

Please contact us or post an issue if you have any questions.

For questions related to the package `mcts_inference`:
* *Romain Poggi* ([romain.poggi@polytechnique.edu](romain.poggi@polytechnique.edu) or [romainpoggi323@gmail.com](romainpoggi323@gmail.com))

For questions related to the theoretical aspect of the method:
* *Romain Poggi* ([romain.poggi@polytechnique.edu](romain.poggi@polytechnique.edu) or [romainpoggi323@gmail.com](romainpoggi323@gmail.com))
* *Jesse Read* ([jesse.read@polytechnique.edu](jesse.read@polytechnique.edu))

## Contributing
Contributions to this project are more than welcome. The aim is to further study and improve the method used in this project. Please make sure to follow the coding standards specified in the [`setup.cfg`]("setup.cfg") file under the `[flake8]` section.
For questions about the project, please contact **Romain Poggi** ([romain.poggi@polytechnique.edu](mailto:romain.poggi@polytechnique.edu)).

For questions related to the theoretical aspects of the method, you can also contact **Professor Jesse Read** ([jesse.read@polytechnique.edu](mailto:jesse.read@polytechnique.edu)).

## License
[![License: MIT](https://img.shields.io/badge/License-MIT-orange.svg)](https://opensource.org/licenses/MIT)

This project is licensed under the MIT License. See the [`LICENSE`](LICENSE) file for more details.


## Status

[![Tests](https://github.com/rompoggi/MCTS_ClassifierChain/actions/workflows/tests.yml/badge.svg)](https://github.com/rompoggi/MCTS_ClassifierChain/actions)
[![codecov](https://codecov.io/gh/rompoggi/MCTS_ClassifierChain/graph/badge.svg?token=N9FSNH021E)](https://codecov.io/gh/rompoggi/MCTS_ClassifierChain)
[![License: MIT](https://img.shields.io/badge/License-MIT-orange.svg)](https://opensource.org/licenses/MIT)
This project is licensed under the MIT License. See the [`LICENSE`](./LICENSE) file for more details.
Loading