Commit db6e37c: Release (initial commit, 0 parents)

25 files changed: +5734, -0 lines
LICENSE (+25 lines)

MIT License

Copyright (c) 2021 Imant Daunhawer

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Note that individual files (e.g., fid_score.py, inception.py, etc.) derive from
other projects that were publicly accessible at the time of writing and might
have their own licensing.

README.md (+67 lines)

# Disentangling Multimodal Variational Autoencoder

Official code to supplement the paper [Self-supervised Disentanglement of
Modality-specific and Shared Factors Improves Multimodal Generative
Models](https://mds.inf.ethz.ch/fileadmin/user_upload/gcpr_daunhawer_camera_ready.pdf)
published at [GCPR 2020](https://link.springer.com/chapter/10.1007/978-3-030-71278-5_33).
This repository contains a PyTorch implementation of the disentangling
multimodal variational autoencoder (DMVAE) and the code to run the experiments
from our paper.

## Installation

```bash
# set up environment
$ conda env create -f environment.yml  # install dependencies
$ conda activate dmvae                 # activate environment
```

## Paired MNIST experiment

```bash
$ cd mmmnist
$ ./run_jobs                     # create dataset and run experiment
$ tensorboard --logdir runs/tmp  # monitor training
```

## MNIST/SVHN experiment

```bash
$ cd mnist_svhn
$ python make_mnist_svhn.py      # create dataset
$ ./run_jobs                     # run experiment
$ tensorboard --logdir runs/tmp  # monitor training
```

## Post-hoc analysis

The TensorBoard logs contain many metrics (likelihood values, classification
accuracies, etc.), but not the complete evaluation; for instance, they include
neither the coherence values nor the unconditionally generated samples and FID
values with ex-post density estimation. To compute these, run the post-hoc
analysis using the script `post_hoc_analysis.py` or, more conveniently, using
the bash script `post_hoc_analysis_batch` as follows:
```
$ ./post_hoc_analysis_batch <path_to_experiment> <logdir>
```
where `path_to_experiment` is the directory of the experiment (e.g.,
`$PWD/mmmnist`) and `logdir` denotes the directory with the logfiles for the
respective experiment (e.g., `$PWD/mmmnist/runs/tmp/version_x`). Results from
the post-hoc analysis are saved to the respective `logdir`. There, you will
find quantitative results in `results.txt` and qualitative results in the form
of PNG images.

## BibTeX

If you find this project useful, please cite our paper:
```bibtex
@inproceedings{daunhawer2020dmvae,
  author    = {Imant Daunhawer and
               Thomas M. Sutter and
               Ricards Marcinkevics and
               Julia E. Vogt},
  title     = {Self-supervised Disentanglement of Modality-Specific and Shared Factors
               Improves Multimodal Generative Models},
  booktitle = {German Conference on Pattern Recognition},
  year      = {2020},
}
```

abstract_getters.py (+46 lines)

class AbstractGetters:
    """
    This abstract class defines the getter methods that need to be implemented
    for every multimodal dataset separately.
    """

    def get_encs_decs(self, flags, liks):
        """
        Getter for lists with encoders and decoders for all modalities.

        Args:
            flags: argparse.Namespace with input arguments.
            liks: List with likelihoods for every modality.

        Returns:
            Lists with newly initialized encoders and decoders for all modalities.
        """
        raise NotImplementedError

    def get_img_to_digit_clfs(self, flags):
        """
        Getter for the list with pre-trained image-to-digit classifiers.

        Args:
            flags: argparse.Namespace with input arguments.

        Returns:
            A list with pre-trained image-to-digit classifiers for all modalities.
        """
        raise NotImplementedError

    def get_data_loaders(self, batch_size, num_modalities, num_workers,
                         shuffle=True, device="cuda", random_noise=False):
        """
        Getter for train and test set DataLoaders.

        Args:
            batch_size: Batch size to use when loading data.
            num_modalities: Number of modalities.
            num_workers: How many subprocesses to use for data loading.
            shuffle: Flag identifying whether to shuffle the data.
            device: Which device to use for storing tensors, "cuda" (by default) or "cpu".
            random_noise: Flag identifying whether to augment images with Gaussian white noise.

        Returns:
            DataLoaders for the training and test sets.
        """
        raise NotImplementedError
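Each multimodal dataset in the repository subclasses this interface. As an
illustration of how the three getters fit together, here is a minimal sketch of
a hypothetical subclass; `ToyGetters`, the linear stubs, the synthetic data,
and the `flags.latent_dim` attribute are all assumptions for this example and
are not part of this commit:

```python
# Hypothetical sketch only: ToyGetters and flags.latent_dim are illustrative
# placeholders, not part of this repository.
import torch
from torch.utils.data import DataLoader, TensorDataset

from abstract_getters import AbstractGetters


class ToyGetters(AbstractGetters):
    def get_encs_decs(self, flags, liks):
        # One encoder/decoder pair per modality; trivial linear stubs here.
        encs = [torch.nn.Linear(784, flags.latent_dim) for _ in liks]
        decs = [torch.nn.Linear(flags.latent_dim, 784) for _ in liks]
        return encs, decs

    def get_img_to_digit_clfs(self, flags):
        # In practice these would be pre-trained classifiers, one per modality.
        return [torch.nn.Linear(784, 10) for _ in range(2)]

    def get_data_loaders(self, batch_size, num_modalities, num_workers,
                         shuffle=True, device="cuda", random_noise=False):
        # Random tensors stand in for paired multimodal data; `device` is
        # ignored in this sketch.
        xs = [torch.rand(100, 784) for _ in range(num_modalities)]
        if random_noise:
            xs = [x + 0.1 * torch.randn_like(x) for x in xs]
        dataset = TensorDataset(*xs)
        train = DataLoader(dataset, batch_size=batch_size,
                           shuffle=shuffle, num_workers=num_workers)
        test = DataLoader(dataset, batch_size=batch_size,
                          shuffle=False, num_workers=num_workers)
        return train, test
```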

environment.yml (+53 lines)

name: dmvae
dependencies:
  - pip=19.1=py36_0
  - python=3.6.3=h6c0c0dc_5
  - pip:
    - backcall==0.1.0
    - certifi==2019.11.28
    - chardet==3.0.4
    - cycler==0.10.0
    - decorator==4.4.1
    - idna==2.8
    - ipdb==0.12.3
    - ipython==7.12.0
    - ipython-genutils==0.2.0
    - jedi==0.16.0
    - jsonpatch==1.25
    - jsonpointer==2.0
    - kiwisolver==1.1.0
    - matplotlib==3.1.3
    - numpy==1.18.1
    - opencv-python==4.2.0.32
    - pandas==1.0.1
    - parso==0.6.1
    - pexpect==4.8.0
    - pickleshare==0.7.5
    - Pillow==7.0.0
    - prompt-toolkit==3.0.3
    - protobuf==3.11.3
    - ptyprocess==0.6.0
    - Pygments==2.5.2
    - pyparsing==2.4.6
    - python-dateutil==2.8.1
    - pytz==2019.3
    - pyzmq==18.1.1
    - requests==2.22.0
    - scipy==1.3.3
    - six==1.14.0
    - tensorflow
    - tensorboardX==2.0
    - torch==1.4.0
    - torchfile==0.1.0
    - torchnet==0.0.4
    - torchvision==0.5.0
    - tornado==6.0.3
    - tqdm==4.42.1
    - traitlets==4.3.3
    - urllib3==1.25.8
    - visdom==0.1.8.9
    - wcwidth==0.1.8
    - websocket-client==0.57.0
    - dtw==1.4.0
    - fastdtw==0.3.4
    - scikit-learn==0.22.2
