Skip to content

Added device option to Wasserstein#3

Open
sAbhay wants to merge 1 commit into
dfdazac:masterfrom
sAbhay:master
Open

Added device option to Wasserstein#3
sAbhay wants to merge 1 commit into
dfdazac:masterfrom
sAbhay:master

Conversation

@sAbhay

@sAbhay sAbhay commented Jun 20, 2022

Copy link
Copy Markdown

Updated wasserstein submodule to include device

@dfdazac

dfdazac commented Aug 25, 2022

Copy link
Copy Markdown
Owner

@sAbhay thank you for your contribution! Sorry for the long delay on my response.
To keep compatibility with distributed training, where computations could run on different devices, I think it would be better to grab the device during the forward pass, rather than fixing it during initialization. For example,

def forward(self, x, y):
    device = x.device     
    ...

    mu = torch.empty(batch_size, x_points, dtype=torch.float,
                     requires_grad=False).fill_(1.0 / x_points).squeeze()
                     requires_grad=False, device=device).fill_(1.0 / x_points).squeeze()
    ...

This way, computations will run in whatever device x might be. What do you think?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants