Skip to content

FactorizedLinear is slower than torch.nn.Linear #35

@KaidDuong

Description

@KaidDuong
import tltorch
import torch
from torch.profiler import profile, record_function, ProfilerActivity

# Benchmark setup: a random float32 input created with shape (4, 16).
data = torch.randn((4, 16), dtype=torch.float32)
# Dense baseline layer mapping 16 input features to 10 output features.
linear = torch.nn.Linear(16, 10)

# Build a Tucker-factorized layer from the dense layer above. The tensorized
# shapes are given explicitly (auto_tensorize=False): 4 * 4 = 16 inputs and
# 2 * 5 = 10 outputs, matching the dense layer's dimensions.
fact_linear = tltorch.FactorizedLinear.from_linear(linear, auto_tensorize=False,
                    in_tensorized_features=(4, 4), out_tensorized_features=(2, 5), rank=0.1, factorization="tucker")

# Move the input and both layers to the GPU before profiling.
data = data.to("cuda")
linear = linear.to("cuda")
fact_linear = fact_linear.to("cuda")
# Profile one forward call of the dense layer, recording CPU and CUDA activity.
# NOTE(review): this is a single call with no warm-up iterations, so the
# reported CPU times may include one-time CUDA start-up cost rather than
# steady-state latency -- confirm by running warm-up calls before profiling.
with profile(activities=[
        ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    with record_function("model_inference"):
        linear(data)

# Print the ten most expensive entries, sorted by total CUDA time.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference        16.99%       1.054ms        99.71%       6.186ms       6.186ms       0.000us         0.00%       4.000us       4.000us             1  
                                           aten::linear         0.29%      18.000us        82.72%       5.132ms       5.132ms       0.000us         0.00%       4.000us       4.000us             1  
                                            aten::addmm        60.54%       3.756ms        81.43%       5.052ms       5.052ms       4.000us       100.00%       4.000us       4.000us             1  
void gemmSN_TN_kernel<float, 128, 16, 2, 4, 4, 4, tr...         0.00%       0.000us         0.00%       0.000us       0.000us       4.000us       100.00%       4.000us       4.000us             1  
                                                aten::t         0.63%      39.000us         1.00%      62.000us      62.000us       0.000us         0.00%       0.000us       0.000us             1  
                                        aten::transpose         0.24%      15.000us         0.37%      23.000us      23.000us       0.000us         0.00%       0.000us       0.000us             1  
                                       aten::as_strided         0.13%       8.000us         0.13%       8.000us       8.000us       0.000us         0.00%       0.000us       0.000us             1  
                                       cudaLaunchKernel        20.89%       1.296ms        20.89%       1.296ms       1.296ms       0.000us         0.00%       0.000us       0.000us             1  
                                  cudaDeviceSynchronize         0.29%      18.000us         0.29%      18.000us      18.000us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.204ms
Self CUDA time total: 4.000us

# Profile one forward call of the factorized layer, using the same profiler
# activities (CPU and CUDA) and shape recording as the dense-layer run.
# NOTE(review): this is a single call with no warm-up, and it runs after the
# dense-layer profile in the same script, so the two measurements may not be
# taken under comparable conditions -- confirm with warm-up iterations and
# repeated calls before comparing the numbers.
with profile(activities=[
        ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    with record_function("model_inference"):
        fact_linear(data)

# Print the ten most expensive entries, sorted by total CUDA time.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference        15.19%       1.098ms        99.79%       7.215ms       7.215ms       0.000us         0.00%      27.000us      27.000us             1  
                                           aten::matmul         0.40%      29.000us        62.63%       4.528ms       1.132ms       0.000us         0.00%      12.000us       3.000us             4  
                                               aten::mm        48.37%       3.497ms        62.23%       4.499ms       1.125ms      12.000us        44.44%      12.000us       3.000us             4  
                                          aten::reshape         0.91%      66.000us         6.10%     441.000us      44.100us       0.000us         0.00%      10.000us       1.000us            10  
                                            aten::clone         0.55%      40.000us         3.91%     283.000us      94.333us       0.000us         0.00%      10.000us       3.333us             3  
                                            aten::copy_         1.40%     101.000us         2.28%     165.000us      55.000us      10.000us        37.04%      10.000us       3.333us             3  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      10.000us        37.04%      10.000us       3.333us             3  
void gemmk1_kernel<int, float, 256, 5, false, false,...         0.00%       0.000us         0.00%       0.000us       0.000us       9.000us        33.33%       9.000us       3.000us             3  
                                           aten::linear         0.36%      26.000us         2.23%     161.000us     161.000us       0.000us         0.00%       5.000us       5.000us             1  
                                            aten::addmm         1.00%      72.000us         1.27%      92.000us      92.000us       5.000us        18.52%       5.000us       5.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.230ms
Self CUDA time total: 27.000us

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions