-
[yongggg's] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Review)Machine & Deep Learning 2024. 7. 26. 09:56
Mamba는 transformer의 아키텍처의 문제점을 보완하고 이 모델의 성능을 능가할 수도 있다라는 새로운 아키텍처 입니다.
공식의 구현과 모델의 Checkpoint는 다음 github에서 확인할 수 있습니다.
https://github.com/state-spaces/mamba
GitHub - state-spaces/mamba: Mamba SSM architecture
Mamba SSM architecture. Contribute to state-spaces/mamba development by creating an account on GitHub.
github.com
그럼 바로 설명하겠습니다!
1. Transformer의 문제점
Transformer는 위의 그림처럼 모든 text 입력을 token 단위의 Sequence를 input으로 받는다.
그래서 Transformer는 입력받는 것이 어떤 것이든 간에, 시퀀스의 이전 토큰들을 돌아볼 수 있어 그 표현을 도출할 수 있다. (self-attention)
Transformer의 핵심 구성 요소
transformer는 text를 representate하기 위한 encoder blocks과 text를 생성하기 위한 decoder blocks의 두 구조로 구성되어 있다. 특히, Generative Pre-trained Transformer(GPT)는 transformer의 decoder blocks만을 사용하여 입력과 출력을 만들어낸다.
작동 원리
하나의 Decoder block은 Masked Self-Attention과 Feed-forward neural network로 구성되어 있다.
Self-Attention은 이 모델들이 잘 작동할 수 있게하는 이유 중 하나이다. 이 모듈은 전체 Sequence에 대한 압축되지 않은 관점을 제공하며, 훈련 속도를 빠르게 한다.
이 모듈은 각 token을 그 이전에 나온 모든 tokens과 비교하는 행렬을 만든다. 행렬의 가중치는 각 자리의 Token-pair가 얼마나 관련이 있는지에 대해 결정이 된다.
훈련하는 과정에서 이 행렬은 한 번에 생성될 수 있다. ("My"와 "name"사이의 attention weights이 먼저 계산될 필요가 없이 "name"과 "is" 사이의 attention weights을 계산할 수 있다.) 이렇게 병렬처리를 가능하게 하며, Training 속도를 크게 높인다.
Inference Phase에서의 문제
위처럼 학습하는 단계에서는 빠르게 학습이 가능하지만, 추론 시에서는 문제가 발생한다.
다음 토큰을 생성할 때, 이미 일부 토큰을 생성했더라도 전체 Sequence에 대한 Attention을 다시 계산해야한다.
길이가 $L$인 Sequence에 대한 token을 생성하는 데에는 대략 $L^{2}$의 계산이 필요하며, 이에 따라 Sequence 길이가 증가하면, 당연히 계산 비용이 많이 든다.
전체 Sequence를 다시 계산해야 하는 것은 transformer architecture의 주요 병목 현상의 원인이 된다.
이러한 느린 추론의 문제를 RNN(Recuurent Neural Networks)는 어떻게 해결했는지 살펴보자.
RNN
순환신경망은 Sequence 기반의 Networks이다. 각 time step에서 두 가지 입력을 받는데, 현재 time step $t$의 입력과 이전 time step $t-1$의 hidden state로 다음 hidden state를 생성하고 출력을 예측한다.
RNN은 이전 step의 정보를 다음 step으로 전달할 수 있는 looping 메커니즘을 갖고 있다. 출력을 생설 할 때, RNN은 이전 step의 hidden state와 현재의 input만 고려해야 한다. 이는 transformer에서 모든 이전 hidden state를 다시 계산하는 것을 방지한다. 즉, RNN은 Sequence 길이와 선형적으로 비례하는 빠른 추론을 할 수 있다. 또한 이론적으로는 무한한 context length를 가진다.
이해를 돕기 위해, 위의 transformer에서 사용한 입력 text에 RNN을 적용해본다.
각 hidden state는 모든 이전 step의 hidden state가 압축된 representation이다. 하지만, Sequence 길이가 길어질 수록 뒤의 생성되는 단어는 매우 앞의 step의 정보를 잊어버리는 경향이 있다. (위의 그림에선 "Maarten"을 생성할 때, "Hello"에 대한 정보가 없어짐.)
RNN의 이러한 Sequence 특성은 또 다른 문제를 만든다. 훈련이 병렬로 수행될 수 없으며, 순차적으로 각 Step을 거쳐야 한다.
RNN의 문제는 transformer의 문제와 정 반대이다. 추론은 빠르지만 training phase에서 병렬로 처리할 수 없다.
transformer처럼 훈련은 병렬화할 수 있으며, Sequence 길이에 비례한 선형적 추론 시간을 갖는 architecture가 있을까?
바로 Mamba가 제공한다. 이 architecture를 알아보기 전에 먼저 State Space Models(상태 공간 모델)의 세계를 살펴보자.
State Space Models(SSM; 상태 공간 모델)
State Space Model(SSM; 상태 공간 모델)은 transformer와 RNN과 마찬가지로 정보의 Sequence를 처리한다.
이는 Text뿐만 아니라 신호와 같은 것들도 포함한다.
이 섹션에서는 SSM의 기초와 text 데이터와의 관련성을 살펴본다.
State Space(상태 공간)
State Space(상태 공간)이란 시스템을 완전히 설명하는 최소한의 변수를 포함하며, 시스템의 가능한 상태들을 정의함으로써 문제를 수학적으로 표현하는 방법이다.
예를 들어, 미로를 탐색한다고 하자. "State Space"은 모든 가능한 위치(상태)의 지도라고 할 수 있다. 각 지점은 미로에서의 고유한 위치를 나타내며, 출구까지 얼마나 떨어져 있는지와 같은 구체적인 세부 사항을 포함한다. "State Space Representation"은 이런 지도에 대한 간단한 설명이다. 여기에는 현재 위치(현재 상태), 다음에 갈 수 있는 곳(가능한 미래 상태), 다음 상태로 이동하는 변화(오른쪽 혹은 왼쪽으로 가는 것)가 표시된다.
State Space Model은 방정식과 행렬을 사용하여 이러한 행동을 추적하지만, 그것은 단순히 현재 위치를 추적하고 어디로 갈 수 있는지, 그리고 어떻게 거기에 도달할 수 있는지를 추적하는 방법일 뿐이다.
상태를 설명하는 변수들, 예제에서는 X와 Y좌표와 출구까지의 거리는 "State Vector"로 표현될 수 있다.
이와 같은 State Vector는 언어 모델에서 Embedding 혹은 hidden state vector도 입력 Sequence의 "State"를 설명하는 데 자주 사용되기 때문이다. 예를 들어, 현재 위치(State) vector는 다음과 같은 형상을 띌 수 있다.
신경망 측면에서는 시스템의 "State"는 일반적으로 그것의 Hidden State이며, 대규모 언어 모델의 맥락에서는 새로운 Token을 생성하는 중요한 representation 중 하나이다.
State Space Model
SSM은 이러한 State representation을 설명하고 입력에 따라 다음 state가 어떻게 될 지 예측하는 모델이다.
전통적으로, 시간 $t$에서 SSM은:
- 입렵 Sequence $x(t)$를 mapping 한다. (예: 미로의 왼쪽으로, 아래로 움직임)
- hidden state representation $h(t)$로 (예: 출구까지의 거리와 $x, y$ 좌표)
- 예측된 출력 Sequence $y(t)$를 도출한다. (예: 출구에 더 빨리 도달하기 위해 다시 왼쪽으로 이동)
하지만, 한 번에 왼쪽으로 이동하는 것과 같은 이산 sequence를 사용하는 대신, 연속 sequence를 입력으로 받아 출력 sequence를 예측한다.
SSM은 동적 시스템, 예를 들어 3D 공간에서 움직이는 객체와 같은 것들이 시간 $t$에서의 state를 통해 두 방정식을 거쳐 예측될 수 있다고 가정한다.
이 방정식을 풀면, 관측된 데이터(입력 Sequence와 이전 State)를 기반으로 시스템의 상태를 예측하는 통계 원칙을 밝혀낼 수 있다고 가정한다. 그 목적은 입력에서 출력 sequence로 이동할 수 있도록 이 state representation $h(t)$를 찾는 것이다. 이 두 방정식은 State Space Model의 핵심이다. (이 두 방정식을 직관적으로 보기 위해 색상으로 표시하였다.)
$h(t)$는 주어진 시간 $t$에서의 hidden state representation을 나타내며, $x(t)$는 어떠한 입력을 나타낸다. output 방정식은 state가 어떻게 출력으로 변환되는지 (행렬 C를 통해)와 입력이 출력에 어떻게 영향을 주는지 (행렬 D)를 통해 설명한다.
(단, 행렬 A, B, C, D는 일반적으로 trainable weights로 취급한다.) 이 두 방정식을 시각화하면, 다음과 같은 architecture가 나온다.
이러한 행렬들이 학습 과정에 어떻게 영향을 미치는지 단계별로 살펴보겠다.
어떤 입력 신호 $x(t)$가 있다고 가정하자. 이 신호는 먼저 시스템에 입력이 어떻게 영향을 미치는지 설명하는 행렬 B에 곱해진다.
업데이트된 state(신경망의 hidden state와 유사)는 환경의 핵심 "knowledge"를 포함하는 hidden state이다. 우리는 state에 행렬 A를 곱하여 시스템 내부 state가 어떻게 연결되어 있는지를 나타낸다. 이는 시스템의 기본 역학을 나타낸다.
위의 그림을 보면 알겠지만, 행렬 A는 state representation을 만들기 전에 적용되며, state representation이 업데이트된 이후에 업데이트가 된다. 이 다음 우리는 행렬 C를 사용하여 state를 output으로 변환하는 방법을 설명한다.
마지막으로, 우리는 행렬 D를 사용하여 input에서 output으로 직접 신호를 제공할 수 있다. 이는 skip-connection이라고도 한다.
행렬 D가 skip-connection과 유사하기 때문에, SSM은 종종 skip-connection이 없는 부분 처럼 간주된다.
이제 skip-connection을 제외하고 SSM의 핵심 부분인 행렬 A, B, C에 집중하자.
원래의 방정식을 업데이트하고 각 행렬의 목적을 나타낼 수 있다.
이 두 방정식은 함꼐 관측된 데이터로부터 시스템의 state를 예측하려고 한다. 입력이 연속적으로 예상되기 때문에, SSM의 주요 representation은 continuous-time representation(연속 시간 표현)이다.
연속 신호에서 이산 신호로
연속 신호를 가지고 있을 때, state representation $h(t)$를 분석하는 것은 challenging하다. 또한 우리가 일반적으로 이산적인 input(예: text sequence)를 가지고 있기 때문에, 모델을 이산화 하고 싶다.
이를 위해, 우리는 Zero-Order hold 기술을 사용한다. 이는 다음과 같이 작동한다. 먼저, 우리가 이산 신호를 받을 때마다, 새로운 이산 신호를 받을 때까지 그 값을 유지한다. 이 과정은 SSM이 사용할 수 있는 연속 신호를 생성한다.
값을 유지하는 시간의 길이는 새로운 trainable parameter인 step size $\vartriangle$로 표현한다. 이것은 입력의 해상도를 나타낸다.
이제 입력을 위한 연속 신호를 갖고있음으로 연속 출력을 생성하고 입력의 시간 단계에 따라 값들을 sampling할 수 있다.
이 샘플링된 값들이 우리의 이산화된 출력이다. 수학적으로 Zero-Order hold를 다음과 같이 정의할 수 있다.
또한 이들은 연속 SSM에서 이산 SSM으로 변환할 수 있게 해주며, 이는 함수에서 함수 $( x(t) \rightarrow y(t) )$대신 시퀀스에서 시퀀스 $( x_{k} \rightarrow y_{k} )$로 표현한다.
여기서 이제 행렬 A와 B는 모델의 이산화된 매개변수를 나타낸다. 우리는 연속 SSM과 이산 SSM을 참조할 때 더 명확하게 하기 위해, $t$대신 $k$를 사용하여 이산화된 time step을 나타낸다.
- 훈련 중에는 연속 형태의 행렬 A를 저장하지만 이산화된 버전은 저장하지 않는다. 훈련 중에는 연속 representation이 이산화 된다. 이제 이산 representation의 공식을 갖고 있으니, 모델을 어떻게 계산할 수 있는지 탐구한다.
이산화 작업은 SSM에서 가장 중요한 작업이다. SSM 아키텍쳐의 모든 효율성은 이 단계라고해도 과언이 아니다. SSM의 continuous representation은 recurrent representation 또는 convolution representation의 Discrete 작업으로 넘어갈 수 있다.
재귀
위의 이산화된 SSM은 연속 신호 대신 특정 시간 단계에서 문제를 공식화할 수 있게 한다. RNN에서 봤던 것처럼, 여기에서도 재귀적인 접근 방식은 매우 유용하다.
우리가 연속 신호 대신 이산 Time step을 고려한다고 하면, 문제를 time step과 함께 재공식화 할 수 있다.
각 시간 단계에서 현재 입력 $Bx^{k}$가 이전 state $Ah_{k-1}$에 어떻게 영향을 미치는지를 계산한 다음, 예측된 출력 $Ch_{k}$를 계산한다.
이 표현은 이전의 RNN에서 봤던 것처럼 같은 방식으로 접근할 수 있다.
우리는 이를 다음과 같이 펼칠 수 있다.
우리는 이산화된 버전을 RNN의 기본 방법론을 사용할 수 있음을 알 수있다. 이 기술은 RNN의 빠른 추론과 느린 훈련이라는 장단점을 닮았다.
컨볼루션
우리가 SSM에 사용할 수 있는 또 다른 표현은 컨볼루션이다.
여기서는 이미지가 아닌 text를 다루기 때문에 1차원 관점으로 보겠다.
우리는 이 CNN의 "filter"를 나타내기 위해 사용할 Kernel을 SSM 공식에서 도출한다.
이 Kernel이 실제로 어떻게 작동하는지 보자. Convolution처럼 SS커널을 사용하여 각 token set을 통과하고 출력을 계산할 수 있다.
이것은 또한 padding이 출력에 미칠 수 있는 영향을 보여준다. 이해하기 쉽도록 시각화하기 위해 padding 순서를 변경했지만, 원래의 경우 대부분 문장 끝에 패딩을 적용한다.
다음 단계로, kernel은 다음 단계 계산을 수행하기 위해 이동한다.
마지막 step에서, 우리는 커널의 전체 효과를 볼 수 있다.
SSM을 컨볼루션으로 표현하는 주요한 이점은 병렬적인 훈련이다. 하지만, 고정된 kernel의 크기 때문에, 컨볼루션의 추론은 RNN처럼 빠르고 무한하지 않다.
Convolution for LTI-SSM
Recurrent 모델은 이전 time-step 의 결과를 구해야 다음의 state를 계산할 수 있고, 이는 결국 sequence 길이가 길어질수록 학습이 상당히 느려지게 된다. 반면에 LSSL은 activation을 제거하면서 각 time-step 에 들어올 입력을 미리 알고 있다는 전제 하에, 전체 sequence에 대한 state를 한번에 구해낼 수 있다. 이는 근본적으로 이 모델이 Linear Time-Invariant (LTI) SSM, 즉 time-step t에 관계없이 항상 동일한 weight를 적용하는 형태를 전제했기 때문이다.
학습시에는 모델은 전체 입력을 미리 알고있으므로, 이 convolution 연산을 수행할 수 있다. 물론 완벽히 recurrent 방식처럼 무한한 길이에 대해 수행할 수 있는 것은 아니고, 대신 학습을 위해 모델을 fixed window size 의 convolution network로 근사한다고 이해해야 한다.
예를 들어 window size가 3인 경우, 다음과 같이 크기 3짜리 convolution kernel $K_{3}$을 만들 수 있는데,
$$ K_{3} = [W_{y} \bar{H}^{2} W_{x}, W_{y} \bar{H} W_{x}, W_{y}W_{x} ] $$
이로부터 y3은 다음과 같이 계산될 수 있다.
$$ y_{3} = K_{3} \cdot [x_{3}, x_{2}, x_{1}] = K_{3} X_{:3 \cdot} $$
학습 시에는 모든 입력 $x_{1}, x_{2}, x_{3}$을 이미 알고있기 때문에, 이렇게 전체 크기의 $\bar{K}$를 미리 구해놓으면 병렬연산을 통해 모든 $y_{t}$에 대한 값을 한 번에 계산할 수 있다.
물론 여기에는 이 커널 벡터를 한번에 계산한다는 것이 전제가 된다. $ \bar{H} $가 Hippo로 고정된 값일 경우, 이 값을 미리 계산해놓으면 되지만, S4에서는 $\bar{H} $도 결국 학습되기 때문에 sequence의 길이 $n$만큼의 $\bar{H}^{n}$를 매번 계산해야 하는 문제가 있고, 실제 학습시에는 이 논문에서 새로 제시한 알고리즘을 통해 근사시킨 $\bar{K}$를 사용한다.
를 고정시킨 경우와 학습시킨 경우의 차이는 저자의 LSSL 논문에서 자세히 비교하고 있는데, 결국 이 근사 알고리즘은 H‾가 Hippo Matrix 라는 것을 전제로 유도되었기 때문에 학습을 통해 업데이트된 행렬에 대한 근사와 일치한다는 보장은 없다. 이 커널 근사 방식은 저자들의 매 연구마다 업데이트되고 있다.
세가지 표현
Continuous-time, Recurrent, Convolutional 이 표현 방법들은 모두 다른 장단점을 갖고 있다.
위의 장점을 이용하고자 Recurrent SSM을 통한 효율적인 추론과 Convolutional SSM을 통한 병렬화된 훈련을 할 수 있게 됐다.
이 표현들을 사용하여 특정 임무에 따라 표현을 선택하는 깔끔한 트릭을 사용할 수 있다. 훈련 중에는 병렬화가 가능한 Convolutional 표현을 사용하고, 추론 중에는 효율적인 Recurrent 표현을 사용한다.
이 모델은 Linear State-Space Layer (LSSL; 선형 상태 공간 레이어)로 언급된다.
이러한 표현들은 Linear Time Invariance (LTI; 선형 시간 불변성)이라는 중요한 특성을 공유한다. LTI는 SSM의 매개변수인 A, B, C가 모든 시간 단계에 대해 고정되어 있다고 명시한다. 즉, 행렬 A, B, C는 SSM이 생성하는 모든 토큰에 대해 동일하다. 즉, SSM에 어떤 시퀀스를 제공하던 간에 A, B, C의 값은 동일하다. 이 뜻은 내용 인식(content-awareness)이 없는 정적 표현을 가지고 있다는 것이다.
Mamba가 이 문제를 어떻게 해결하는지 살펴보기 전에, 마지막 퍼즐 조각인 행렬 A에 대해 탐구해 보자.
행렬 A의 중요성
행렬 A는 아마도 SSM 공식의 가장 중요한 측면 중 하나일 것이다. 위의 "재귀" 부분에서 본 것처럼, 이 행렬은 이전 상태에 대한 정보를 포착하여 새로운 상태를 구축한다.
본질적으로, 행렬 A는 hidden state를 생성한다.
행렬 A를 생성하는 것은 몇 개의 이전 토큰만 기억하는 것과 지금까지 본 모든 토큰을 포착하는 것 사이의 차이를 만들 수 있다. 특히 '재귀적'이라는 맥락에서는 이전 상태만을 고려하기 때문이다.
그렇다면 어떻게 큰 메모리(맥락 크기)를 유지하는 방식으로 행렬 A를 생성할 수 있을까?
Hungry Hungry Hippo 또는 HiPPO(High-order Polynomial Projection Operators)를 사용한다. HiPPO는 지금까지 본 모든 입력 신호를 계수의 벡터로 압축하려고 시도한다.
이는 행렬 A를 사용하여 최근 토큰을 잘 포착하고 오래된 토큰을 감소시키는 state representation을 구축한다. 해당 공식은 다음과 같다.
정사각형 행렬 A를 가정한다면, 다음을 얻을 수 있다.
HiPPO를 사용하여 행렬 A를 구축하는 것은 무작위 행렬로 초기화하는 것보다 훨씬 나은 것으로 나타났다. 결과적으론, 오래된 신호(초기 토큰)보다 새로운 신호(최근 토큰)를 더 정확하게 재구성한다.
HiPPO 매트릭스의 아이디어는 이것이 history를 기억하는 hidden state를 생성한다는 것이다. 수학적으로 이는 Legendre polynomial을 추적함으로써 이전 history의 모든 것을 근사화할 수 있게한다.
HiPPO는 장거리 의존성을 처리하기 위해 위에서 본 Recurrent 및 Convolutional representation에 적용되었습니다. 결과적으로 sequence를 효율적으로 처리할 수 있는 SSM 클래스인 Structured State Space for Sequences(S4) https://arxiv.org/abs/2111.00396
Efficiently Modeling Long Sequences with Structured State Spaces
A central goal of sequence modeling is designing a single principled model that can address sequence data across a range of modalities and tasks, particularly on long-range dependencies. Although conventional models including RNNs, CNNs, and Transformers h
arxiv.org
가 등장했다. 이 클래스는 크게 세 부분으로 구성된다.
- State-Space Model (SSM; 상태 공간 모델)
- 장거리 의존성을 처리하기 위한 HiPPO
- Recurrent 및 Convolutional 표현을 생성하기 위한 이산화
이 SSM 클래스는 선택한 (Recurrent 및 Convolutional) 표현에 따라 여러 가지 이점을 갖고 있으며, 텍스트의 긴 시퀀스를 처리하고 HiPPO 매트릭스를 기반으로 메모리를 효율적으로 저장할 수 있다. (참고: HiPPO 매트릭스를 계산하고 S4 모델을 직접 구축하는 데 필요한 기술적 세부 사항에 대해 더 깊이 파고 들고자 하면, https://srush.github.io/annotated-s4/ 를 참고하라.)
Mamba - 선택적(Selective) SSM
마침내 우리는 Mamba가 특별한 이유를 이해하는 데 필요한 모든 기초에 대해 배웠다. State-Space Model을 text sequence를 모델링하는 데 사용될 수 있지만, 단점들 또한 존재한다.
먼저, Mamba의 contribution은 다음과 같다.
- Mamba가 (비)관련 정보를 필터링할 수 있게 하는 선택적 스캔 알고리즘.
- 병렬 스캔, 커널 융합 및 재계산을 통해 (중간) 결과의 효율적인 저장을 가능하게 하는 하드웨어 인식 알고리즘.
이 알고리즘과 함께 선택적 SSM 또는 S6 모델을 생성하는데, 이는 self-attention과 같이 Mamba 블록을 생성하는 데 사용될 수 있다.
두 가지 주요 기여를 탐구하기 전에, 왜 이들이 필요한지 살펴보자.
문제 정의
State Space Model, 심지어 S4(Structured State Space Model)조차도 언어 모델링 및 생성에서 중요한 특정 작업에서 성능이 떨어진다. (특정 입력에 집중하거나 무시하는 능력)
이 능력은 Selective Copying(선택적 복사) 및 induction heads(유도 헤드)라는 두 가지 합성 작업으로 설명할 수 있다.
선택적 복사 작업에서, SSM의 목표는 입력의 일부를 복사하여 순서대로 출력하는 것이다.
그러나 (Recurrent/Convolutional) SSM은 이 작업에서 성능이 떨어진다. 그 이유는 Linear Time Invariance(LTI; 선형 시간 불변성) 때문이다. 앞서 본 것처럼, 행렬 A, B, C는 SSM이 생성하는 모든 토큰에 대해 동일하다. 결과적으로, SSM은 content-awareness inference(내용 인식 추론)을 수행할 수 없다. 고정된 A, B, C 행렬 때문에 각 토큰을 동등하게 취급하기 때문이다. 우리는 SSM이 input(prompt)에 대해 추론하기를 원한다.
SSM 성능이 떨어지는 두 번째 작업은 induction heads(유도 헤드)로, 입력에서 발견된 패턴을 재현하는 것이 목표이다.
바로 위의 예제에서, 우리는 기본적으로 한 번의 prompting을 수행한다. 여기서 우린 모델에게 "Q:" 후에 "A:" 응답을 제공하도록 학습하려고 한다. 그러나 SSM이 시간 불변성을 가지고 있기 때문에, 이전 토큰 중 어떤 것을 기억할지 선택할 수 없다.
이를 행렬 B에 집중해서 다음 그림을 보자. 입력 $x$가 무엇이든 관계없이 행렬 B는 정확히 동일하게 유지되므로 $x$와 독립이다.
마찬가지로 A와 C도 입력과 관계없이 고정되어 있다. 이는 지금까지 설명한 SSM의 정적인 특성을 보여준다.
비교적으로 이러한 작업들은 transformer가 수행하기엔 상대적으로 쉽다. transformer는 입력 sequence에 따라 동적으로 attention을 변경할 수 있다. 그들은 sequence의 다른 부분을 선택적으로 "forcus"하거나 "attention"할 수 있다.
SSM이 이러한 작업에서 성능이 떨어지는 것은 시간 불변성 SSM의 근본적인 문제, 즉 A, B, C 행렬의 정적인 특성으로 인한 content-awareness(내용 인식) 문제를 시사한다.
정보를 선택적으로 유지
SSM의 재귀 방식 표현은 전체 history를 압축하는 매우 효율적인 small state를 생성한다. 그러나 attention 행렬을 통해 history를 전혀 압축하지 않는 transformer 모델과 비교하면 정보 손실이 훨씬 크다.
Mamba는 두 가지의 장점, 즉 transformer의 state를 가지며, 이 state는 강력하고 small(작은) state를 갖는다.
바로 위의 그림에서 보이는 것과 같이, Mamba는 데이터를 state에 선택적으로 압축함으로써 두 가지의 장점을 달성한다. 예를 들어, 입력 문장이 들어 왔을 때, 종종 중요한 의미를 가지지 않는 정보(ex: stop words)를 제외하여 선택적으로 압축하는 것이다.
정보를 선택적으로 압축하기 위해서는 입력에 따라 매개변수가 달라져야 한다. 이를 위해, 훈련 중 SSM의 입력과 출력의 차원을 먼저 살펴 보자.
Structured State Space Model (S4)에서, 행렬 A, B, C는 입력값과 독립적이다. 이 행렬의 차원 N과 D는 정적이며 변하지 않는다.
대신 Mamba는 행렬 B와 C 심지어 step size $ \vartriangle$ 를 입력에 의존하도록 하여, 입력의 sequence 길이와 batch 크기를 포함한다.
이것은 이제 모든 입력 토큰에 대해 서로 다른 B와 C 행렬을 가지고 있다는 것을 의미한다. 즉, content-awareness(내용 인식) 문제를 해결한다는 것이다. (참고: 행렬 A는 동일하게 유지된다. 왜냐하면, 우린 state 자체가 정적으로 유지되기를 원하지만, 이것이 영향을 받는 방식(B와 C를 통해)은 동적으로 유지되기를 원하기 때문이다.
게다가. 이들은 입력에 의존하여 hidden state에서 무엇을 유지하고 무엇을 무시할지 selective하게 selection한다.
더 작은 스텝 크기 $ \vartriangle $ 는 특정 단어를 무시하는 대신 이전 맥락을 사용하는 것에 더욱 초점을 맞춘다. 반면, 큰 step 크기 $ \vartriangle $은 맥락보다 입력 단어에 더 집중한다.
스캔 작업
이제 이런 행렬이 동적이기 때문에, 고정된 커널을 가정하는 convolution을 사용하여 계산할 수 없다. 우리는 convolution의 병렬성을 잃고 Recurrent 표현만을 사용할 수 있다.
$\rightarrow$ 병렬화를 가능하게 하기 위해, 재귀적으로 출력을 계산하는 방법을 살펴보자.
각 state는 이전 state(행렬 A에 의해 곱해진)와 현재 입력(B에 곱해진)의 합이다. 이것을 스캔 작업이라고 하며, For loop를 사용하여 쉽게 계산할 수 있다.
병렬화와는 대조적으로, 이전 state를 가지고 있어야만 각 state를 계산할 수 있으므로 병렬화는 불가능 해보이나, Mamba는 병렬 스캔 알고리즘(https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda)을 통해 이를 가능토록 한다.
이 알고리즘은 연산 순서가 중요하지 않다고 가정함으로써 수행된다. 이 뜻의 keypoint는 Sequence를 부분적으로 계산하고 반복적으로 결합할 수 있다는 것이다.
이와 함께, 동적인 행렬 B와 C, 그리고 병렬 스캔 알고리즘은 Recurrent 표현을 사용하는 동적이고 빠른 특성의 Selective Scan 알고리즘을 생성한다.
하드웨어 인식 알고리즘
최근 GPU의 단점 중 하나는 작지만 매우 효율적인 SRAM과 크지만 약간 덜 효율적인 DRAM 사이의 전송(IO) 속도가 제한되어 있다는 것이다. SRAM과 DRAM 사이에 정보를 자주 복사한다면, 병목 현상이 일어날 수 밖에 없다.
Mamba는 Flash Attention과 같이 DRAM에서 SRAM으로, 또 그 반대로 이동하는 횟수를 제한하고자 했다. 이 것은 커널 융합을 통해 모델이 중간 결과를 쓰는 것을 방지하며, 계산을 완료할 때까지 지속적으로 계산을 수행한다.
Mamba의 기본 아키텍처를 다음과 같이 시각화함으로써 DRAM과 SRAM 할당의 구체적인 instance를 볼 수 있다.
여기서, 다음의 작업은 하나의 커널로 융합된다.
- step 크기 $ \vartriangle $와 함께 이산화 단계
- selective scan 알고리즘
- C와의 곱셈
하드웨어 인식 알고리즘의 마지막 부분은 재계산이다.
중간 state들은 저장되지 않지만, Gradients를 계산하기 위해, backward pass에서 필요하다. 따라서, backward pass 동안 이러한 중간 states를 재계산한다. 이것은 비효율적으로 보일 수 있지만, 상대적으로 느린 DRAM에서 모든 중간 state를 읽는 것보다 훨씬 비용이 들지 않는다.
이제 최종 아키텍처를 아래의 그림과 같이 표현함으로써 묘사한 모든 구성요소를 나타낼 수 있다.
이 아키텍처는 종종 Selective SSM 혹은 S6 모델로 불리며, 이는 본질적으로 Selective Scan 알고리즘으로 계산된 S4 모델이다.
Mamba Block
우리가 지금까지 탐구한 Selective SSM은 Decoder Block에서 Self-attention을 나타내는 것과 같은 방식으로 블록을 구현할 수 있다.
Decoder와 마찬가지로 우리는 여러 Mamba 블록을 쌓을 수 있으며, 그 출력을 다음 Mamba 블록의 입력으로 사용할 수 있다.
이 모델은 Input embedding을 확장하기 위한 Linear projection으로 시작한다. 그 다음, Selective SSM을 적용하기 전에, 독립적인 token 계산을 방지하기 위한 convolution이 사용된다.
Selective SSM은 다음과 같은 특성을 가진다.
- 이산화를 통해 생성된 Recurrent SSM
- 장거리 의존성을 포착하기 위한 HiPPO intialization matrix A
- 정보를 선택적으로 압축하기 위한 Selective Scan 알고리즘
- 계산 속도를 높이기 위한 하드웨어 인식 알고리즘
코드를 구현한 것을 살펴본 뒤 End-to-End 예시를 탐구해보면, 이 아키텍처를 조금 더 확장할 수 있다.
출력 token을 선택하기 위해, 정규화 Layer와 Softmax와 같은 몇 가지 수정 사항을 볼 수 있다. 이 모든 것을 함께 구성하면, 빠른 추론과 훈련뿐만이 아니라 무한히 긴 context도 얻을 수 있다.
이 아키텍처를 사용하여, 저자들은 동일한 크기의 transformer 모델의 성능과 일치하거나 더 좋음을 발견했다.
Reference
- https://srush.github.io/annotated-s4/
- https://www.youtube.com/watch?v=ouF-H35atOY
- https://huggingface.co/state-spaces
- https://github.com/state-spaces/mamba
- https://hazyresearch.stanford.edu/blog/2022-01-14-s4-1
- https://hazyresearch.stanford.edu/blog/2022-01-14-s4-2
- https://hazyresearch.stanford.edu/blog/2022-01-14-s4-3
- https://jameschen.io/jekyll/update/2024/02/12/mamba.html
- https://arxiv.org/abs/2312.00752
- https://junia3.github.io/blog/lssl (LSSL/LMU/HiPPO)
-
'Machine & Deep Learning' 카테고리의 다른 글