概要
最近,耳目を集めている言語モデルにHyenaと呼ばれるものがあります.このHyenaのベースとなっているモデルにHungry Hungry Hippos (通称H3)があるのですが,弊社でもこのH3を活用する流れがあります.
弊社の機械学習チームにて,H3論文に関する動作確認を進めている中で,特定の条件下で計算上の問題に直面しました.本記事では,そのエラーの原因と背景,そして可能な対策について検討していきます.
RuntimeError: "reciprocal_cuda" not implemented for 'ComplexHalf'
"reciprocal_cuda" not implemented for 'ComplexHalf'
まずはH3について軽く紹介します.
言語モデリングにおいて,State Space Models (SSMs) は一部のタスクで卓越した性能を示していますが,Attentionと比較すると未だ一歩及ばない部分があるのが現状です.そこで,このH3では,SSMsとAttentionの性能差を埋めるためのアーキテクチャとして,SSMを活用しつつAttentionに着想を得たモデルが提案されました.特に,本モデルは,シーケンス内の過去のトークンをうまく扱い,トークン間の比較を強化することを目標としています.
H3はSSMの系譜を順当に受け継いでおり,HiPPO→LSSL→S4→S4D→H3という流れを継いでいます.H3内部でS4Dの計算を行うのですが,S4Dでは複素数を扱う場面があり,その計算過程において以下のエラーが出てしまいました.
RuntimeError: "reciprocal_cuda" not implemented for 'ComplexHalf'
具体的には以下のようなコードです.
このエラーメッセージは,torch.complex32 と等価である ComplexHalf という32bit複素数型での逆数計算時に生じているようです.AMPを用いると,適宜fp16とfp32が入れ替わりつつ計算が行われるのですが,上記処理においてはfp16で変数が保持されてしまうようで,それゆえ計算上のエラーが発生することが判明しました.
参考