H3の動作確認でエラー: RuntimeError: "reciprocal_cuda" not implemented for 'ComplexHalf'

ブログ機械学習

概要

最近,耳目を集めている言語モデルにHyenaと呼ばれるものがあります.このHyenaのベースとなっているモデルにHungry Hungry Hippos (通称H3)があるのですが,弊社でもこのH3を活用する流れがあります.

弊社の機械学習チームにて,H3論文に関する動作確認を進めている中で,特定の条件下で計算上の問題に直面しました.本記事では,そのエラーの原因と背景,そして可能な対策について検討していきます.

RuntimeError: "reciprocal_cuda" not implemented for 'ComplexHalf'

"reciprocal_cuda" not implemented for 'ComplexHalf'

まずはH3について軽く紹介します.
言語モデリングにおいて,State Space Models (SSMs) は一部のタスクで卓越した性能を示していますが,Attentionと比較すると未だ一歩及ばない部分があるのが現状です.そこで,このH3では,SSMsとAttentionの性能差を埋めるためのアーキテクチャとして,SSMを活用しつつAttentionに着想を得たモデルが提案されました.特に,本モデルは,シーケンス内の過去のトークンをうまく扱い,トークン間の比較を強化することを目標としています.

H3はSSMの系譜を順当に受け継いでおり,HiPPO→LSSL→S4→S4D→H3という流れを継いでいます.H3内部でS4Dの計算を行うのですが,S4Dでは複素数を扱う場面があり,その計算過程において以下のエラーが出てしまいました.

RuntimeError: "reciprocal_cuda" not implemented for 'ComplexHalf'

具体的には以下のようなコードです.

https://github.com/HazyResearch/safari/blob/2a11200629daa55158052e0e4d44cf7478ac7331/src/models/sequence/ssm/ss_kernel_diag.py#L169


このエラーメッセージは,torch.complex32 と等価である ComplexHalf という32bit複素数型での逆数計算時に生じているようです.AMPを用いると,適宜fp16とfp32が入れ替わりつつ計算が行われるのですが,上記処理においてはfp16で変数が保持されてしまうようで,それゆえ計算上のエラーが発生することが判明しました.

参考