[TVM_0.20] 03 Relax

Relax 상세

Graph Abstraction for ML Models

Graph abstraction은 데이터 flow와 구조를 나타내기 위한 Machine Learning 컴파일러의 주요 기술이다. 모델을 Graph representaion으로 추상화 함으로써 컴파일러는 다양한 최적화나 성능 향상을 수행할 수 있다.

What is Graph Abstraction?

Graph abstraction은 ML model을 Graph로 나타내기 위한 프로세스이며 컴파일러가 모델의 파트 사이의 dependency와 relation을 분석하게 해준다.

  • node : computational operations (e.g., matrix multiplication, convolution)
  • edge : Operation 간 data의 흐름을 나타냄

Key Features of Relax

  • First-class symbolic shape : Relax 는 Tensor의 차원을 표현할 때 Symblic shape를 사용. tensor operators와 function calls 간의 dynamic shape relationship을 전역 추적할 수 있게 함
    • (역자주) First-class는 프로그래밍에서 해당 요소가 함수의 인자 리턴 값으로 자유롭게 사용될 수 있고 컴파일러가 최적화에 활용될 수 있음을 의미, 즉 relax는 동적입력을 지원하고 이것이 컴파일러 단에서 최적화가 가능하며 전체 모델에서 Shape 분석이 가능함을 의미한다. TIR도 Dynamic Shape를 지원하지만 이는 함수에서 단순한 변수 추적일뿐 shape 추론은 수동으로 해야한다. 즉 shape관계추적, 전역 최적화는 어렵다. 즉 Relax는 이 텐서의 크기를 나중에 정할 수 있으며 심볼로 최적화가 가능)
  • Multi-level abstractions : Relax는 high-level neural network layer부터 low-level tensor operation까지 포함하는 cross-level abstraction을 지원
    • (역자주) [Relax] 신경망 레이어 단위(Dense, ReLU, Conv2D) -> [Relax Dataflow] 텐서 연산 단위 (matmul, add, relu) -> [TIR]루프/인덱싱 기반 연산 의 변환이 하나의 통합된 시스템 안에서 자유롭게 오갈 수 있음)
  • Composable transformations : relax는 모델 컴포넌트에 선택적으로 적용가능한 transformation을 제공, partial lowering / partial specialization 같은 유연한 최적화 옵션을 포함
    • (역자주) 한 번에 전체 모델을 “한 가지 방식"으로 변환하거나, 특정 최적화를 적용하면 다른 최적화와 충돌하는 경우가 많으나 relax는 모델 전체가 아닌, 특정 함수/연산에만 적용 가능하다. 즉 모델을 하드웨어/용도에 맞춰 유연하게 최적화하고 커스터마이징 가능)

Understand Relax Abstraction

Relax는 ML모델에 대해 end-to-end optimize를 돕기 위한 graph abstraction. Relax는 ML모델의 structure와 data flow를 묘사한다.(모델 파트간의 dependency와 relationship 및 HW에서 실행되는 방법)

End to End Model Execution

이제부터 linear->relu->linear 모델을 활용하여 Relax를 설명한다.

High-Level Operations Representation

위 모델을 Numpy와 Relax 모델로 하기와 같이 표현할 수 있다.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# numpy
def numpy_mlp(data, w0, b0, w1, b1):
    lv0 = data @ w0 + b0
    lv1 = np.maximum(lv0, 0)
    lv2 = lv1 @ w1 + b1
    return lv2

# Relax
from tvm.script import relax as R

@R.function
def relax_mlp(
    data: R.Tensor(("n", 784), dtype="float32"),
    w0: R.Tensor((784, 128), dtype="float32"),
    b0: R.Tensor((128,), dtype="float32"),
    w1: R.Tensor((128, 10), dtype="float32"),
    b1: R.Tensor((10,), dtype="float32"),
) -> R.Tensor(("n", 10), dtype="float32"):
    with R.dataflow():
        lv0 = R.matmul(data, w0) + b0
        lv1 = R.nn.relu(lv0)
        lv2 = R.matmul(lv1, w1) + b1
        R.output(lv2)
    return lv2

Low-Level Integration

머신러닝 컴파일러의 관점에서 array computation의 세부 사항을 살펴보기 위해 Numpy 코드를 low-level로 풀어보자(배열 함수 대신 루프 사용, numpy.empty를 통해 배열을 명시적으로 할당)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def lnumpy_linear(X: np.ndarray, W: np.ndarray, B: np.ndarray, Z: np.ndarray):
    n, m, K = X.shape[0], W.shape[1], X.shape[1]
    Y = np.empty((n, m), dtype="float32")
    for i in range(n):
        for j in range(m):
            for k in range(K):
                if k == 0:
                    Y[i, j] = 0
                Y[i, j] = Y[i, j] + X[i, k] * W[k, j]

    for i in range(n):
        for j in range(m):
            Z[i, j] = Y[i, j] + B[j]


def lnumpy_relu0(X: np.ndarray, Y: np.ndarray):
    n, m = X.shape
    for i in range(n):
        for j in range(m):
            Y[i, j] = np.maximum(X[i, j], 0)

def lnumpy_mlp(data, w0, b0, w1, b1):
    n = data.shape[0]
    lv0 = np.empty((n, 128), dtype="float32")
    lnumpy_matmul(data, w0, b0, lv0)

    lv1 = np.empty((n, 128), dtype="float32")
    lnumpy_relu(lv0, lv1)

    out = np.empty((n, 10), dtype="float32")
    lnumpy_matmul(lv1, w1, b1, out)
    return out

위의 코드를 활용해 Relax로 표현할 수 있다(TVMScript 구현)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def linear(x: T.handle, w: T.handle, b: T.handle, z: T.handle):
        M, N, K = T.int64(), T.int64(), T.int64()
        X = T.match_buffer(x, (M, K), "float32")
        W = T.match_buffer(w, (K, N), "float32")
        B = T.match_buffer(b, (N,), "float32")
        Z = T.match_buffer(z, (M, N), "float32")
        Y = T.alloc_buffer((M, N), "float32")
        for i, j, k in T.grid(M, N, K):
            with T.block("Y"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[v_i, v_j] = T.float32(0.0)
                Y[v_i, v_j] = Y[v_i, v_j] + X[v_i, v_k] * W[v_k, v_j]
        for i, j in T.grid(M, N):
            with T.block("Z"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                Z[v_i, v_j] = Y[v_i, v_j] + B[v_j]

    @T.prim_func(private=True)
    def relu(x: T.handle, y: T.handle):
        M, N = T.int64(), T.int64()
        X = T.match_buffer(x, (M, N), "float32")
        Y = T.match_buffer(y, (M, N), "float32")
        for i, j in T.grid(M, N):
            with T.block("Y"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                Y[v_i, v_j] = T.max(X[v_i, v_j], T.float32(0.0))

    @R.function
    def main(
        x: R.Tensor(("n", 784), dtype="float32"),
        w0: R.Tensor((784, 256), dtype="float32"),
        b0: R.Tensor((256,), dtype="float32"),
        w1: R.Tensor((256, 10), dtype="float32"),
        b1: R.Tensor((10,), dtype="float32")
    ) -> R.Tensor(("n", 10), dtype="float32"):
        cls = Module
        n = T.int64()
        with R.dataflow():
            lv = R.call_tir(cls.linear, (x, w0, b0), out_sinfo=R.Tensor((n, 256), dtype="float32"))
            lv1 = R.call_tir(cls.relu, (lv0,), out_sinfo=R.Tensor((n, 256), dtype="float32"))
            lv2 = R.call_tir(cls.linear, (lv1, w1, b1), out_sinfo=R.Tensor((b, 10), dtype="float32"))
            R.output(lv2)
        return lv2

위으 코드는 primitive tensor functions (T.prim_func) 과 R.function (relax function)을 포함(Relax 함수는 high-level neural network execution를 나타내기 위한 새로운 타입)
Relax Module이 symbolic shape를 지원하는 것이 중요 (main 함수의 n, linear 함수의 M,N,K), 이것은 tensor operators와 function calls 간의 dynamic shape 추적을 위한 key feature임
numpy 코드와 TVMScript를 1:1로 비교해보면 세부적인 사항을 알 수 있다.

Key Elements of Relax

Structure Info

Structure info는 relax expression의 type을 표현하기 위한 새로운 컨셉 ( TensorStructInfo, TupleStructInfo 등)

  • 위의 예제어서는 inputs, outputs, 중간 결과의 shape와 dtype을 표현하기 위해 TensorStructInfo (R.Tensor)를 사용

R.call_tir

R.call_tir 함수는 primitive tensor 함수 호출을 위한 새로운 추상화, cross-level abstraction을 위한 key feature.

1
2
3
4
5
6
#Relax
lv = R.call_tir(cls.linear, (x, w0, b0), out_sinfo=R.Tensor((n, 256), dtype="float32"))

#Numpy
lv0 = np.empty((n, 256), dtype="float32")
lnumpy_linear(x, w0, b0, lv0)

위의 relax코드와 이에 대응하는 numpy 코드를 비교해보자. call_tir은 destination passing을 사용한다.

  • input / output은 low-level primitive function 외부에 명시적으로 할당(저수준 라이브러리 설계에서 일반적으로 사용되는 방법으로 고수준 프레임워크가 메모리 할당 결정을 처리할 수 있음)
  • 모든 텐서 연산을 이 스타일로 표현할 수 있는 것은 아니나 (예를 들어 입력에 따라 출력 형태가 달라지는 연산) 일반적으로 가능하면 저수준 함수를 이 스타일로 작성하는 것이 일반적

Dataflow Block

relax function의 중요한 다른 element는 R.dataflow()

1
2
3
4
5
with R.dataflow():
    lv = R.call_tir(cls.linear, (x, w0, b0), out_sinfo=R.Tensor((n, 256), dtype="float32"))
    lv1 = R.call_tir(cls.relu, (lv0,), out_sinfo=R.Tensor((n, 256), dtype="float32"))
    lv2 = R.call_tir(cls.linear, (lv1, w1, b1), out_sinfo=R.Tensor((b, 10), dtype="float32"))
    R.output(lv2)

Relax의 dataflow를 설명하기 전 pure와 side-effect의 개념에 대해 알아야 한다

  • pure 함수 : 입력만을 읽고, 출력을 만들어내는 함수 (입력 및 외부 메모리 영역을 변경하지 않음)
  • Side-effect : 함수가 단순히 결과를 반환하는 것 외에, 프로그램의 다른 부분(메모리, 전역 변수 등)에 영향을 미치는 것
  • 즉 pure(side-effect free) 하다는 것은 입력을 읽어 출력을 내보낼때 입력이나 다른 외부 메모리를 변경하지 않는다.(inplace operations(A += 1)은 side-effet가 발생)

dataflow block은 side-effect free 함수만 허용함, side-effect가 있는 함수는 dataflow block에서 처리해야함

  • (역자주) Relax는 모델의 순수 계산(graph) 과 운영/제어(control flow) 를 명확히 분리하려고 설계되었다. 그래서 R.dataflow() 내부는 최적화에 최적화된 구간, 외부는 학습/관리/제어 코드가 들어가는 구간으로 구분하는 것이 자연스러운 패턴

Dataflow Block을 자동으로 나누지 않고 수동으로 표시해야 하는 이유는

  • auto inference 는 부정확 할수 있음 : packed function 호출(cuBLAS, cuDNN 등 외부 라이브러리 호출) 같은 경우 컴파일러 입장에서 확실하게 pure or not을 판단하기 어렵다.
  • 많은 최적화는 Dataflow Block 안에서만 가능 : Fusion optimization 등은 Dataflow Block안에서만 가능(pure function 들만 모여 있기 때문에 연산 순서를 바꾸거나 합치는(fusion) 것이 안전합니다.)
  • (역자 주) 컴파일러가 잘못 Dataflow Block 경계를 잡으면 최적화가 잘못되거나 성능에 영항을 줄 수 있다.)

Relax Creation

Relax functions을 정의하는 다양한 방법에 대해 다룬다.

Create Relax programs using TVMScript

TVMScript는 TVM IR을 표현하기 위한 domain-specific language이다.

  • python 형태의 언어이며 TensorIR and Relax function 둘다 포함한다.
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from tvm import relax, topi
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T


@I.ir_module
class RelaxModule:
    @R.function
    def forward(
        data: R.Tensor(("n", 784), dtype="float32"),
        w0: R.Tensor((128, 784), dtype="float32"),
        b0: R.Tensor((128,), dtype="float32"),
        w1: R.Tensor((10, 128), dtype="float32"),
        b1: R.Tensor((10,), dtype="float32"),
    ) -> R.Tensor(("n", 10), dtype="float32"):
        with R.dataflow():
            lv0 = R.matmul(data, R.permute_dims(w0)) + b0
            lv1 = R.nn.relu(lv0)
            lv2 = R.matmul(lv1, R.permute_dims(w1)) + b1
            R.output(lv2)
        return lv2

RelaxModule.show() #출력확인

Relax는 graph-level IR 뿐만 아니라 cross-level representation과 transformation도 지원한다. 구체적으로 말하자면 Relax 함수에서 TensorIR 함수를 직접 호출할 수 있다.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
@I.ir_module
class RelaxModuleWithTIR:
    @T.prim_func
    def relu(x: T.handle, y: T.handle):
        n, m = T.int64(), T.int64()
        X = T.match_buffer(x, (n, m), "float32")
        Y = T.match_buffer(y, (n, m), "float32")
        for i, j in T.grid(n, m):
            with T.block("relu"):
                vi, vj = T.axis.remap("SS", [i, j])
                Y[vi, vj] = T.max(X[vi, vj], T.float32(0))

    @R.function
    def forward(
        data: R.Tensor(("n", 784), dtype="float32"),
        w0: R.Tensor((128, 784), dtype="float32"),
        b0: R.Tensor((128,), dtype="float32"),
        w1: R.Tensor((10, 128), dtype="float32"),
        b1: R.Tensor((10,), dtype="float32"),
    ) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        cls = RelaxModuleWithTIR
        with R.dataflow():
            lv0 = R.matmul(data, R.permute_dims(w0)) + b0
            lv1 = R.call_tir(cls.relu, lv0, R.Tensor((n, 128), dtype="float32"))
            lv2 = R.matmul(lv1, R.permute_dims(w1)) + b1
            R.output(lv2)
        return lv2

RelaxModuleWithTIR.show() #출력확인

show()로 출력을 확인해보면 작성한 TVMScript 코드와 출력이 다름을 볼수 있는데 이는 출력 시 syntax sugar 등이 표준 포맷으로 출력되기 때문이다. 예를 들어 작성 시 한라인에 여러 operation을 결합하여 작성할 수 있으나 출력시에는 한라인에 하나의 오퍼레이션이 결합되도록 출력된다.

1
2
3
4
5
6
7
# writen
lv0 = R.matmul(data, R.permute_dims(w0)) + b0

# printed
lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(w0, axes=None)
lv1: R.Tensor((n, 128), dtype="float32") = R.matmul(data, lv, out_dtype="void")
lv0: R.Tensor((n, 128), dtype="float32") = R.add(lv1, b0)

Create Relax programs using NNModule API

TVM은 Relax 프로그래밍을 위한 PyTorch-like API인 Relax NNModule API 지원한다. NNModule을 정의한 후 이를 export_tvm을 활용하여 TVM IRModule로 변환할 수 있다.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
from tvm.relax.frontend import nn

class NNModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x

mod, params = NNModule().export_tvm({"forward": {"x": nn.spec.Tensor(("n", 784), "float32")}})
mod.show()

또한 NNModule에 customized function call을 삽입할 수 있다.

  • Tensor Expression(TE), TensorIR functions, other TVM packed functions 등
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
@T.prim_func
def tir_linear(x: T.handle, w: T.handle, b: T.handle, z: T.handle):
    M, N, K = T.int64(), T.int64(), T.int64()
    X = T.match_buffer(x, (M, K), "float32")
    W = T.match_buffer(w, (N, K), "float32")
    B = T.match_buffer(b, (N,), "float32")
    Z = T.match_buffer(z, (M, N), "float32")
    for i, j, k in T.grid(M, N, K):
        with T.block("linear"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                Z[vi, vj] = 0
            Z[vi, vj] = Z[vi, vj] + X[vi, vk] * W[vj, vk]
    for i, j in T.grid(M, N):
        with T.block("add"):
            vi, vj = T.axis.remap("SS", [i, j])
            Z[vi, vj] = Z[vi, vj] + B[vj]


class NNModuleWithTIR(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        n = x.shape[0]
        # We can call external functions using nn.extern
        x = nn.extern(
            "env.linear",
            [x, self.fc1.weight, self.fc1.bias],
            out=nn.Tensor.placeholder((n, 128), "float32"),
        )
        # We can also call TensorIR via Tensor Expression API in TOPI
        x = nn.tensor_expr_op(topi.nn.relu, "relu", [x])
        # We can also call other TVM packed functions
        x = nn.tensor_ir_op(
            tir_linear,
            "tir_linear",
            [x, self.fc2.weight, self.fc2.bias],
            out=nn.Tensor.placeholder((n, 10), "float32"),
        )
        return x


mod, params = NNModuleWithTIR().export_tvm(
    {"forward": {"x": nn.spec.Tensor(("n", 784), "float32")}}
)
mod.show()

Create Relax programs using Block Builder API

TVM은 Relax 프로그래밍을 위한 Block Builder API 를 제공한다. 이는 IR builder API로 좀 더 low-level이며 customized pass를 기술하기 위한 TVM 내부 로직에 널리 쓰인다.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
bb = relax.BlockBuilder()
n = T.int64()
x = relax.Var("x", R.Tensor((n, 784), "float32"))
fc1_weight = relax.Var("fc1_weight", R.Tensor((128, 784), "float32"))
fc1_bias = relax.Var("fc1_bias", R.Tensor((128,), "float32"))
fc2_weight = relax.Var("fc2_weight", R.Tensor((10, 128), "float32"))
fc2_bias = relax.Var("fc2_bias", R.Tensor((10,), "float32"))
with bb.function("forward", [x, fc1_weight, fc1_bias, fc2_weight, fc2_bias]):
    with bb.dataflow():
        lv0 = bb.emit(relax.op.matmul(x, relax.op.permute_dims(fc1_weight)) + fc1_bias)
        lv1 = bb.emit(relax.op.nn.relu(lv0))
        gv = bb.emit(relax.op.matmul(lv1, relax.op.permute_dims(fc2_weight)) + fc2_bias)
        bb.emit_output(gv)
    bb.emit_func_output(gv)

mod = bb.get()
mod.show()

Block Builder API는 유저 친화적이지 않지만 가장 낮은 수준의 API이며 IR definition과 밀접하게 작동한다. TVM은 ML 모델을 정의하거나 transform하고자 하는 사용자들은 TVMScript나 NNModule API를 사용하는 것을 추천한다. 그러나 복잡한 transformation을 원한다면 Block Builder API가 좀 더 유연한 선택이다.

Transformation

Transformation은 Hardware Backend와 최적화 및 통합하기 위한 컴파일 flow의 핵심 요소이다. 2항의 NNModule API 예제인 class NNModule(nn.Module)을 이용하여 예제를 진행한다.

Apply transformations

Pass는 Transformation을 Relax 프로그램에 적용하기 위한 주요 방법이다. 첫번째 단계로 built-in pass인 LegalizeOps를 적용하여 high-level operator들을 low-level operator로 lowering 할 수 있다. (본문을 통해 pass가 적용된 결과를 확인하면 add, matmul등이 TensorIR로 변환되고 R.call_tir을 통해 호출되는 형태로 변환된 것을 확인할 수 있다.)

1
2
mod = tvm.relax.transform.LegalizeOps()(origin_mod)
mod.show()

결과로 부터 high-level operator(aka relax.op)가 이에 대응되는 low-level operator(aka relax.call_tir)로 교체된 것을 볼 수 있다. fusion optimization(연산자 융합)은 Pass의 집합을 적용하여 수행할 수 있다.

1
2
3
4
5
6
7
8
mod = tvm.ir.transform.Sequential(
    [
        tvm.relax.transform.AnnotateTIROpPattern(),
        tvm.relax.transform.FuseOps(),
        tvm.relax.transform.FuseTIR(),
    ]
)(mod)
mod.show()

본문의 결과로 부터 matmul, add, relu 연산자가 하나의 커널(aka call_tir)로 합쳐진 것을 볼 수 있다. 지원하는 Built-in pass들은 relax.transform을 참조하면 확인할 수 있다.

Custom Passes

Custom pass를 정의하는 방법을 확인하기 위해 relu를 gelu로 변환하는 예제를 수행해보자

  • (역자주) gelu : transform 등 최신 모델의 활성함수로 GELU 함수는 표준 가우신안 누적 분포 함수 인 xΦ(x)로 정의
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
# ReluRewriter클래스는 PyExprMutator클래스를 상속받아 visit_call_을 오버라이딩
from tvm.relax.expr_functor import PyExprMutator, mutator

@mutator
class ReluRewriter(PyExprMutator):
    def __init__(self, mod):
        super().__init__(mod)

    def visit_call_(self, call: relax.Call) -> relax.Expr:
        # visit the relax.Call expr, and only handle the case when op is relax.nn.relu
        if call.op.name == "relax.nn.relu":
            return relax.op.nn.gelu(call.args[0])

        return super().visit_call_(call)

위의 mutator를 적용한 pass를 이용해 trasnformation

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
@tvm.transform.module_pass(opt_level=0, name="ReluToGelu")
class ReluToGelu:  # pylint: disable=too-few-public-methods
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """IRModule-level transformation"""
        rewriter = ReluRewriter(mod)
        for g_var, func in mod.functions_items():
            if isinstance(func, relax.Function):
                func = rewriter.visit_expr(func)
                rewriter.builder_.update_func(g_var, func)
        return rewriter.builder_.get()

mod = ReluToGelu()(origin_mod)
mod.show()

결과를 확인해보면 relax.nn.relu operator가 relax.nn.gelu 로 변경된 겻을 확인할 수 있다. 자세한 내용은 relax.expr_functor.PyExprMutator 참고

Licensed under CC BY-NC-SA 4.0
comments powered by Disqus
Built with Hugo
Theme Stack designed by Jimmy