File size: 1,632 Bytes
3f9c425
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import time
import torch
import torch.utils.benchmark as benchmark

from triton.testing import do_bench


def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **kwinputs):
    """Time ``fn(*inputs, **kwinputs)`` with ``torch.utils.benchmark.Timer``.

    Args:
        fn: Callable to be timed.
        *inputs: Positional arguments forwarded to ``fn`` on every call.
        repeats: Value handed to ``Timer.timeit``.
        desc: Label printed ahead of the run when ``verbose`` is true.
        verbose: When true, print the label before timing and the
            measurement afterwards.
        **kwinputs: Keyword arguments forwarded to ``fn`` on every call.

    Returns:
        A ``(timer, measurement)`` tuple: the ``Timer`` that was built and
        the object returned by its ``timeit`` call.
    """

    def _report(*parts):
        # Single gate for all optional console output.
        if verbose:
            print(*parts)

    _report(desc, '- Forward pass')

    # The statement is evaluated inside the Timer's own namespace, so the
    # callable and its arguments are handed over through ``globals``.
    namespace = dict(fn=fn, inputs=inputs, kwinputs=kwinputs)
    timer = benchmark.Timer(
        stmt='fn(*inputs, **kwinputs)',
        globals=namespace,
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)

    _report(measurement)
    return timer, measurement


# ---- Benchmark configuration ------------------------------------------------
torch.manual_seed(0)  # fixed seed so the random operands are reproducible
repeats = 30
dtype = torch.float16
device = 'cuda'
verbose = False
m = n = 8192

# Achieved TFLOPS keyed by the inner dimension k, one dict per timing method.
tflops_matmul, tflops_matmul1 = {}, {}

# Sweep the inner dimension k from 512 to 8192 (inclusive) in steps of 512.
for k in range(512, 8192 + 1, 512):
    # a is (m, k); b is created as (n, k) and then viewed transposed as (k, n).
    a = torch.randn(m, k, device=device, dtype=dtype)
    b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2)
    # FLOP count used for the TFLOPS figures below.
    nFLOPS_matmul = 2 * m * n * k

    # -- Measurement 1: torch.utils.benchmark via benchmark_forward ----------
    time.sleep(2)  # to reduce power throttling
    timing = benchmark_forward(torch.matmul, a, b, desc='cuBLAS', verbose=verbose, repeats=repeats)[1]
    # timing.mean is treated as seconds here: * 1e3 prints ms, * 1e-12 gives TFLOPS.
    tflops_matmul[k] = nFLOPS_matmul / timing.mean * 1e-12
    print(f'[torch.utils.benchmark] cuBLAS, {m = }, {n = }, {k = }: {timing.mean * 1e3:.3f}ms, {tflops_matmul[k]:.1f} TFLOPS')

    # -- Measurement 2: triton.testing.do_bench ------------------------------
    time.sleep(2)  # to reduce power throttling
    # NOTE(review): triton documents do_bench's `warmup` and `rep` as durations
    # in ms rather than iteration counts -- confirm `rep=repeats` is intended.
    ms = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=repeats)
    # ms is treated as milliseconds, so FLOPs / ms * 1e-9 gives TFLOPS.
    tflops_matmul1[k] = nFLOPS_matmul / ms * 1e-9
    print(f'[triton.test.do_bench]  cuBLAS, {m = }, {n = }, {k = }: {ms:.3f}ms, {tflops_matmul1[k]:.1f} TFLOPS')