Skip to content

How to use sparse vae? #61

@wusize

Description

@wusize

Hi, nice work!

I am wondering how to use the sparse VAE to encode and decode meshes, given that the encoder reads its inputs from a batch, as in this line of the code:

feat, xyz, batch_idx = batch['sparse_sdf'], batch['sparse_index'], batch['batch_idx']

May I know how to prepare the input for the encoder?

I tried the following approach but got unexpected results.

import torch
import trimesh
from omegaconf import OmegaConf
from direct3d_s2.utils import instantiate_from_config

from direct3d_s2.utils.mesh import compute_valid_udf, normalize_mesh


def preprocess(mesh, size=512, device="cuda:0"):
    """Turn a mesh into the sparse (index, value) pair fed to the sparse VAE.

    The mesh is sampled onto a ``size``**3 grid with ``compute_valid_udf``
    and only the cells whose value is below ``4 / size`` are kept.

    Args:
        mesh: Object exposing ``vertices`` and ``faces`` array-likes
            (a ``trimesh`` mesh in the calling script).
        size: Grid resolution along each axis.
        device: Torch device the tensors are created on.

    Returns:
        A ``(sparse_index, sparse_sdf)`` tuple:

        * ``sparse_index`` -- ``(N, 4)`` integer tensor; column 0 indexes the
          leading singleton axis added below, columns 1-3 index the grid.
        * ``sparse_sdf`` -- 1-D tensor with the ``N`` selected grid values, in
          the same order as the rows of ``sparse_index``.
    """
    # NOTE(review): the 0.5 factor presumably maps a [-1, 1] normalized mesh
    # into [-0.5, 0.5] -- confirm against what compute_valid_udf expects.
    vertices = torch.tensor(mesh.vertices, dtype=torch.float32, device=device) * 0.5
    # Build the integer tensor directly: going through torch.Tensor(...) first
    # would cast the indices to float32, which cannot represent every integer
    # above 2**24 and would corrupt face indices on very large meshes.
    faces = torch.tensor(mesh.faces, dtype=torch.int32, device=device)

    # NOTE(review): the helper name says UDF (unsigned) while the variables
    # say SDF -- verify which one the encoder was trained on.
    sdf = compute_valid_udf(vertices, faces, dim=size, threshold=4.0)
    # Add a leading singleton axis so that nonzero() yields 4 index columns.
    sdf = sdf.reshape(size, size, size).unsqueeze(0)

    # Compute the selection mask once and reuse it for both outputs; nonzero()
    # and boolean indexing both walk the grid in the same row-major order.
    near_surface = sdf < 4 / size
    sparse_index = near_surface.nonzero()
    sparse_sdf = sdf[near_surface]

    return sparse_index, sparse_sdf



# Round-trip script: load a mesh, encode it with the 512-resolution sparse VAE,
# decode the latent back to a mesh and export the result.

# NOTE(review): trimesh.load may return a Scene instead of a single Trimesh for
# multi-part files -- confirm 'output_512.obj' loads as one mesh.
mesh = trimesh.load('output_512.obj')

# NOTE(review): normalize_mesh is a project helper; the coordinate range it
# produces is not visible here -- confirm it matches what preprocess assumes.
mesh = normalize_mesh(mesh)
# mesh.show()

# Sparse encoder inputs: (N, 4) grid indices and the N matching grid values.
sparse_index, sparse_sdf = preprocess(mesh, size=512, device="cuda:0")


# Filesystem paths of the sparse-VAE checkpoint and its config file.
model_sparse_512_path = 'wushuang98/Direct3D-S2/direct3d-s2-v-1-1/model_sparse_512.ckpt'
config_path = 'wushuang98/Direct3D-S2/direct3d-s2-v-1-1/config.yaml'

cfg = OmegaConf.load(config_path)

# Load the checkpoint onto the CPU first; weights_only=True restricts
# unpickling to tensors and plain containers.
state_dict_sparse_512 = torch.load(model_sparse_512_path, map_location='cpu', weights_only=True)
print(f"Load sparse vae 512: {cfg.sparse_vae_512}")

# Build the model from the `sparse_vae_512` config node, load the weights
# stored under the checkpoint's "vae" key (strict=True: keys must match
# exactly), then switch to eval mode and move the model to the GPU.
sparse_vae_512 = instantiate_from_config(cfg.sparse_vae_512)
sparse_vae_512.load_state_dict(state_dict_sparse_512["vae"], strict=True)
sparse_vae_512 = sparse_vae_512.eval().to("cuda:0")

# Use the dtype of the model's first parameter for the input values.
dtype = next(sparse_vae_512.parameters()).dtype

sparse_sdf = sparse_sdf.to(dtype)

# Encode. preprocess() returns 4 index columns: column 0 comes from the leading
# singleton axis and is passed as batch_idx; columns 1-3 are the grid
# coordinates and are passed as sparse_index.
# NOTE(review): the key names mirror the batch['sparse_sdf'] /
# batch['sparse_index'] / batch['batch_idx'] line quoted in the issue; the
# expected shapes and dtypes are not visible here -- confirm with the encoder.
batch = dict(sparse_sdf=sparse_sdf, sparse_index=sparse_index[:, 1:], batch_idx=sparse_index[:, 0])


with torch.no_grad():
    # `posterior` is returned alongside the latent but is not used below.
    z, posterior = sparse_vae_512.encode(batch)


# Decode the latent back to a mesh; [0] takes the first element of the result.
# NOTE(review): mc_threshold is presumably the marching-cubes iso-level and
# should be consistent with how the input values were scaled -- confirm.
with torch.no_grad():
    mesh = sparse_vae_512.decode_mesh(latents=z,
                                      voxel_resolution=512,
                                      mc_threshold= 0.2,
                                      return_feat=False,
                                      factor=1.0)[0]

# Re-normalize the decoded mesh and write it to disk.
mesh = normalize_mesh(mesh)
mesh.export('reconstructed.obj')

input:

Image

output:

Image

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions