Keras 에서 Jax, Flax 로 코드를 옮겨본 소감

성능이 더 개선된다는 이유로, 그리고 기술적인 호기심으로 Keras 코드를 Jax로 옮기고 있다. 그러면서 느낀 점들을 공유해서 같은 migration을 하고 싶은 경우에 참고가 되었으면 한다.

Jax 의 첫인상은 배우기 어렵단 것이었다. 특히 Sharpbits 에 그 내용이 잘 정리되어있는데 대부분은 데이터가 immutable 하다는 것으로 요약된다. 이것이 처음에는 매우 당혹스럽다가 서서히 익숙해졌다. 그러나 몇가지는 생각지도 못한 문제가 있었다. 예를들어 list 는 hashable 하지 않은데 따라서 @jax.jit 함수의 인자로 사용할 수 없어 대신 tuple 을 이용해야 한다는 것도 메뉴얼만 읽었을때는 알지 못했었다. 시간이 지나면 Jax 의 방식은 이해가 되지만 그것을 실제로 적용하는 과정에 시행착오가 많았다.

두번째는 당장 tf.data 패키지에서 벗어나야한다는 점이었다. 이론적으로는 jax, flax, tf 를 모두 설치할 수 있어야하지만 실제로는 tf는 numpy 2.14를 필요로 하고 jax, flax는 numpy 2.15를 사용하고 있어 함께 설치할 수 없었다. 이것은 기존 generator가 그렇게 복잡하지는 않았기에 너무 어려운 migration은 아니었다. 데이터가 조금 복잡할때 tensor spec을 작성하는 고통도 더이상 없다. 그러나 prefetch() 등을 구현하는 것은 만만치 않을 것이다.

세번째는 당연하게 제공되는 각종 loss, metric 함수들이 flax에는 없단 점이다. 예를들어 categorical focal cross entropy는 당연히 없어 구현해야했다. 심지어 cross entropy마저 다들 만들어 쓰고 있단 점은 example에서도 쉽게 확인된다. 라이브러리는 존재하지만 굳이 갖다 쓰지는 않았다. (loss 구현은 ChatGPT 가 도와주었다.)

Flax는 그 구현의 추상화가 굉장히 깊다. 문서도 부족해서 코드를 봐야하는데 코드는 너무나 복잡했다. 예를들어 각 nn.Module 이 hierarchy를 가진 scope로 구성된다는 것은 어느 문서에도 없고 시행착오와 코드 읽기를 통해 알아내야했다. 하지만 각잡고 함수들을 분석하기엔 너무 벅차다. 나는 migration만 생각했던건데. 그럼에도 결국은 코드를 보는 것이 진짜 해결책인 경우가 많았다. 예를들어 2개 이상의 dropout을 쓸때 각 레이어마다 dropout rng 를 줘야하는지 등은 코드를 읽기전까지 알기 어렵다. (답은 아니다이다. dropout은 rng는 하나만 주면 rng 를 split 할때 모듈 hash 값을 고려한다.)

당연하게 생각하는 쉬운 구현도 없다. 예를들어 Keras의 LSTM 은 없고 LSTMCell 은 있다. 즉, 주어진 시퀀스를 순회하며 LSTMCell 을 적용하는게 구현자의 책임이란 뜻이다. 이것도 for 문을 써야하는지로 알았다가 뒤늦게 nn.scan 을 알게되었다. 예제는 너무 복잡해 이런 것들을 쉽게 찾기 어려웠다. 당연한 것 중 없는 것이 또 있다. 예를들어 각 step 마다 progress bar 를 증가시키는 것도 구현자 책임이다.

Stateful LSTM 은 악몽이나 마찬가지였다. 하나하나 hidden, state 를 보관해야한다. 만들고 나서, 알고나서는 쉽다. 처음하기는 만만치 않다. 너무 괴로웠던 탓에 코드를 github에 정리해서 공개해두었다.

그런데 왜 이 고생을 하며 Jax, Flax 를 썼을까. 서두에 적은대로 속도와 재미 때문이었는데 아직 코드를 모두 옮기지 못해서 속도는 확인을 하지 못했다. 하지만 재미는 있었고 Keras 처럼 model.fit() 에서 코드가 죽을때의 답답함도 덜했다. (그러나 Jax 도 만만치 않다. jax.disable_jit() 을 알기전까진 디버깅이 너무나 어려웠다.) 코드가 이렇게 돌아야하는데 그게 맞나 싶은 부분도 덜했다. Flax 의 추상화는 굉장히 깊어 이해하기 어렵지만 각 레이어의 동작은 생각대로다.

아직 코드를 다 옮기지는 못했지만 코드가 정상 동작한다면 다시 keras 로 돌아갈 생각은 아직 없다. 어차피 keras도 다음 버전부터는 tf 의 일부가 아닌 모든 백엔드에서 돌아가는 독립적인 코드가 되서 어떤 종류든 migration 은 예상된 일이었다. flax 는 keras 와 달리 복잡하지 않아 무엇이든 내 마음대로 할 수 있고 세세하게 모델을 수정할 수 있다. 한번 코드를 e2e로 작성해두면 대부분은 템플릿과 마찬가지일것이므로 model apply, loss, training loop을 크게 수정할 일도 없다.

p.s. 그러나 이 작업은 취소하게 되었다. M1 맥북에서 Flax가 GPU를 지원하지 않는다는걸 뒤늦게 알게 되었다.

Comments

Leave a Reply

Your email address will not be published. Required fields are marked *