← back

nomic.c

2026-01-27

github.com/hyourindev/nomic.c

純C言語によるNomic Embed Text v1.5の推論

テキスト埋め込みモデルをC言語でフルスクラッチ実装した。PyTorchもONNXもBLASも使っていない。C11とIntelのSIMD組み込み関数だけで、Nomic Embed Text v1.5を単一ファイルで動くように仕上げた。CPU上でHuggingFace Transformersと比較して1.2〜2.1倍の高速化を達成している。

以下では、主要な設計判断の背景、効果のあった最適化、そしていくつかの失敗についてまとめる。

なぜ純Cなのか

埋め込みモデルは、セマンティック検索、RAGパイプライン、クラスタリング、分類といった処理の中核を担っている。ベクトル検索に対応した検索バーにクエリを入力するたびに、どこかで埋め込みモデルが走っている。大抵はPyTorchだ。つまり、2GBのランタイム、Pythonインタプリタ、そして壊れないことを祈るしかない依存関係グラフが付いてくる。

自分が求めていたのは別のものだった。任意のCプログラムから関数をひとつ呼び出し、float配列を受け取って終わり。ランタイムなし。三重の抽象化の裏に隠れたアロケータもなし。動的ディスパッチもなし。#include "nomic.h" して -lm でリンクするだけ。

Nomic Embed v1.5は良い対象だった。パラメータ数1.37億のBERTエンコーダで、重みはfloat32で522MBに収まる程度の規模だ。単純な実装では遅くなり、パフォーマンスについて真剣に考える必要がある、ちょうどよいサイズ感だ。さらにMatryoshka次元にも対応しており、埋め込みを64、128、256、512次元に切り詰めても有用な結果が得られる。実用的な機能なのでサポートする価値がある。

モデルの構造

アーキテクチャはBERTだが、素のBERTではない。Nomicがいくつか変更を加えており、実装に影響する。

学習済みの絶対位置埋め込みの代わりに、Rotary Position Embeddings(RoPE)を使用している。トークン埋め込みに位置ベクトルを加算する代わりに、次元のペアを位置に比例した角度で回転させる。回転周波数にはDynamic NTKスケーリングが適用されており、系列長が学習時の長さを超えた場合に基底周波数を調整する。したがって、静的テーブルの参照ではなく、実際の系列長に基づいてsin/cosテーブルを動的に計算する必要がある。

フィードフォワードネットワークにはGELUの代わりにSwiGLUが使われている。SwiGLUは中間表現を二分割し、片方にSiLU(Sigmoid Linear Unit)を適用してから要素ごとに乗算する。このため、FFNの射影は768→3072が一度ではなく二度(ゲートと値)必要になる。パラメータもFLOPsも増えるが、パラメータあたりの品質は向上する。

QKV射影は融合されている。Q、K、Vそれぞれに768×768の線形層を用意するのではなく、768×2304の単一の射影にまとめている。バイアス項はどこにもない。実装上の些細な点だが、コードは簡潔になる。

残差結合はPost-norm方式を採用している。レイヤー正規化はAttentionやFFNブロックの前ではなく、残差加算の後に適用される。順伝播の順序がわずかに変わる。

最終的な埋め込みパイプラインは、系列全体のmean pooling、最終レイヤー正規化、所望のMatryoshka次元への切り詰め、L2正規化の順で行われる。nomic_embedから出力される埋め込みはすべて単位ノルムであり、内積でそのままコサイン類似度を計算できる。

トークナイザ

トークナイザがこのプロジェクトで最も苦労した部分だった。トークン化の概念が難しいからではない。BERTのトークン化が長年にわたって蓄積してきたエッジケースをすべて正確に再現しなければならないからだ。

パイプラインはUnicode正規化から始まる。具体的にはNFD正規化で、アクセント付き文字を基底文字と結合記号に分解する。その後、結合記号を完全に除去する。つまり「café」は「cafe」になる。語彙がこの方法で構築されているため、トークナイザの挙動が一致しなければ異なるトークンIDが生成され、埋め込みがまったくの別物になる。

正規化の後、CJK文字の前後に空白が挿入される。これはBERT特有の規則で、CJK統合漢字をそれぞれ独立した単語として扱う。その後すべて小文字に変換され、空白と句読点の境界で同時に分割される。

分割された各単語に対してWordPiece分割が行われる。アルゴリズムは語彙中の最長一致接頭辞を探す。単語全体が一致すればそれで良し。一致しなければ、最長の一致接頭辞を出力し、残りに##を付けて処理を続ける。どの接頭辞も一致しなければ[UNK]を出力する。

語彙は30,522エントリある。ソート済み配列に格納し、二分探索で検索している。最初はハッシュテーブルを試したが、この語彙サイズではキャッシュの局所性のおかげでソート済み配列のほうが速かった。各検索がメモリの連続領域にアクセスするのに対し、ハッシュテーブルはヒープ全体に散らばったアクセスになる。

トークン化を完全に正しく動作させるのに、SIMDの実装全体より多くの時間を費やした。トークナイザ専用のテストを20件以上書き、CJKとラテン文字の混在テキスト、連続する句読点を含む文字列、多数のサブワードに分解される単語、空文字列などをカバーした。トークンがひとつでも違えば埋め込みは使い物にならない。緩やかな劣化など存在しない。

SIMD戦略

モデル内の計算負荷の高い処理は、行列積か要素ごとのベクトル演算のいずれかだ。どちらもSIMDに自然に写像できる。

ターゲットはFMA付きAVX2とした。過去10年間に製造されたほぼすべてのx86-64 CPUで利用可能だ。__m256レジスタは8個のfloat32を保持する。モデル内のすべての次元(768, 2304, 3072)が8で割り切れるため、端数処理は一切不要。すべてのループが1回のイテレーションで正確に8個のfloatを処理し、余りもマスクもない。

すべての処理の基盤となる3つのユーティリティ関数を紹介する。

hsum_avx__m256内の8つのfloatを単一のスカラに畳み込む。2回の_mm256_hadd_psの後、上位・下位128ビットレーンを抽出して加算する。すべてのドット積がこの水平加算で終わるため、コードベース中で最も頻繁に使われる処理だ。

exp256_approxは8個の値に対するベクトル化されたexp(x)の近似を同時に計算する。アルゴリズムはexp(x)2^(x/ln2)に分解し、整数部分(IEEE 754指数部のビットシフトになる)と小数部分(4次のHorner多項式で近似)に分ける。精度は約20ビット。softmaxやSiLUでは指数関数の比を計算するため、相対誤差が相殺され、20ビットで十分だ。

rcp_nr_mm256_rcp_ps(12ビット精度)を使って1/xを計算し、Newton-Raphson法の1ステップで精度を倍にする。SiLUのシグモイド計算で_mm256_div_psを回避するために使う。除算は多くのマイクロアーキテクチャで乗算の3〜4倍遅い。

USE_AVX2が定義されていない場合、SIMD層全体が完全に消える。すべてのカーネルに素のCループによるスカラフォールバックがある。つまり、ARM、RISC-V、その他あらゆるアーキテクチャでもビルド・実行が可能だ。ただし遅くなる。

GEMMマイクロカーネル

コードベース全体で最も重要な関数はlinear_no_biasだ。Transformerの全層にある全線形射影の行列積を担っている。Q, K, V射影、出力射影、FFNの拡大射影、FFNの縮小射影。12層で各6回の行列積。この関数が遅ければ、すべてが遅い。

最初の実装は素朴な三重ループだった。正しく動いたが、致命的に遅い。

二番目の実装では、出力ニューロンをひとつずつAVX2のドット積で処理した。改善はしたが、まだ性能を取りこぼしていた。ボトルネックはメモリ帯域幅だ。出力ニューロンごとに重み行列の1行を読み込み、1回だけ使って捨てる。演算密度が低すぎる。

三番目の実装が最終版だ。2S×4Oのマイクロカーネルで、2つの系列位置と4つの出力ニューロンを同時に処理する。内側ループでは、4つの出力行からそれぞれ8個の重みをロードし、2つの系列行からそれぞれ8個の入力値をブロードキャストし、8つのアキュムレータレジスタに対して8回のFMA演算を実行する。

__m256 a00 = _mm256_setzero_ps(), a01 = a00, a02 = a00, a03 = a00;
__m256 a10 = _mm256_setzero_ps(), a11 = a10, a12 = a10, a13 = a10;

for (int k = 0; k < in; k += 8) {
    __m256 x0 = _mm256_load_ps(row0 + k);
    __m256 x1 = _mm256_load_ps(row1 + k);
    __m256 w0 = _mm256_load_ps(W0 + k);
    __m256 w1 = _mm256_load_ps(W1 + k);
    __m256 w2 = _mm256_load_ps(W2 + k);
    __m256 w3 = _mm256_load_ps(W3 + k);

    a00 = _mm256_fmadd_ps(x0, w0, a00);
    a01 = _mm256_fmadd_ps(x0, w1, a01);
    a02 = _mm256_fmadd_ps(x0, w2, a02);
    a03 = _mm256_fmadd_ps(x0, w3, a03);
    a10 = _mm256_fmadd_ps(x1, w0, a10);
    a11 = _mm256_fmadd_ps(x1, w1, a11);
    a12 = _mm256_fmadd_ps(x1, w2, a12);
    a13 = _mm256_fmadd_ps(x1, w3, a13);
}

アキュムレータ8本、入力ロード2回、重みロード4回。16本中14本のYMMレジスタを使用。残り2本はFMAパイプラインの一時領域として使われる。レジスタ圧は高いが、ちょうど収まる。

重要なのは、各重みロードが2つの系列行で共有される点だ。行をひとつずつ処理する場合と比べて、重みの帯域消費が半減する。演算密度が倍になり、カーネルはメモリ律速ではなく演算律速になる。

重み行に対するソフトウェアプリフェッチも追加した。現在位置から16要素(64バイト=1キャッシュライン)先を指定している。

_mm_prefetch((const char *)(W0 + k + 16), _MM_HINT_T0);
_mm_prefetch((const char *)(W1 + k + 16), _MM_HINT_T0);
_mm_prefetch((const char *)(W2 + k + 16), _MM_HINT_T0);
_mm_prefetch((const char *)(W3 + k + 16), _MM_HINT_T0);

次のキャッシュラインの読み込みを必要になる前に開始するようCPUに指示している。ハードウェアプリフェッチャが優秀な最近のCPUでは効果は限定的だが、コストはゼロであり、古いハードウェアでは効く。

並列化

シングルスレッド実装はシンプルだが、CPUの大部分を遊ばせてしまう。モデルにはデータ並列性が豊富にある。OpenMPを採用した。扱いが極めて簡単で、GCCがネイティブにサポートしている。

最初に並列化を導入したのはGEMMカーネルだ。素朴なアプローチは系列位置で並列化することだが、系列が短いとコアが遊んでしまう。10トークンのクエリはCLSとSEPを加えても12×768の行列にしかならない。6ペアの行を6スレッドで並列化すると、各スレッドにちょうど1ペアずつ割り当てられる。負荷分散の余地がない。

そこで2次元タイリングに再構成した。(系列ペア数 × 出力タイル数)の作業を1次元のインデックスに展開し、OpenMPに均等分配させる。

int pairs  = seq / 2;
int otiles = out / 4;
int ntiles = pairs * otiles;

#pragma omp parallel for schedule(static) if(ntiles > 16)
for (int t = 0; t < ntiles; t++) {
    int p  = t / otiles;
    int ot = t % otiles;
    // 行 (p*2, p*2+1) と出力 (ot*4 .. ot*4+3) に対する2S×4Oマイクロカーネル
}

if(ntiles > 16) のガード条件は重要だ。OpenMPにはスレッド生成と同期のオーバーヘッドがある。タイル総数が16未満の小さな入力では、シングルスレッド実行のほうが速い。この閾値は経験的に決定した。

二番目に並列化したのはアテンションヘッドだ。12個のヘッドは完全に独立しており、それぞれQ, K, Vの64次元のスライスを処理する。元の実装は12ヘッドを逐次処理していた。これを並列化したのが、長い系列に対する最大の高速化となった。

#pragma omp parallel for schedule(static) if(seq >= 4)
for (int h = 0; h < num_heads; h++) {
    float *local_q  = amalloc(seq * head_dim * sizeof(float));
    float *local_k  = amalloc(seq * head_dim * sizeof(float));
    float *local_v  = amalloc(seq * head_dim * sizeof(float));
    float *local_sc = amalloc(seq * seq * sizeof(float));

    // gather, K転置, Q*K^T, softmax, scores*V, scatter

    free(local_q);
    free(local_k);
    free(local_v);
    free(local_sc);
}

各スレッドは実際の系列長に基づいて確保された独自のスクラッチバッファを持つ。最初は事前確保したバッファを共有する方式を試したが、それではアクセスの直列化か、最大系列長8192トークン分の確保が必要になる。最大長では、アテンションスコア行列だけで1ヘッドあたり 8192 × 8192 × 4 = 256MB になる。12ヘッド分を事前確保すると、スクラッチ領域だけで3GB消費する。実際の系列長に基づく動的確保なら、メモリ使用量が入力サイズに比例する。

K転置のテクニック

Q×K転置の計算がアテンションのボトルネックだ。素朴な方法では、各スコアをクエリ行とキー行のドット積として計算する。AVX2では、Qから8個、Kから8個のfloatを読み込み、乗算し、水平加算で8つの部分積を1つのスカラに畳み込む。水平加算がコストの高い部分で、複数のシャッフルと加算命令を要する。

乗算の前にKを転置する。Kを [seq][head_dim] ではなく [head_dim][seq] に並べ替える。すると、スコアの1行の計算は通常の行列ベクトル積になる。転置されたKの列から8個の値を読み込み、対応するQの要素をブロードキャストして乗算し、累積する。最後まで水平加算が不要になる。

メモリアクセスパターンが、散発的なロード(K行ごとに1要素)から逐次的なロード(転置後のレイアウトで連続した要素)に変わる。CPUプリフェッチャは、ストライドアクセスよりシーケンシャルアクセスをはるかにうまく処理する。100トークンの系列では、この最適化だけでアテンション処理時間が約30%短縮された。

メモリアライメント

コードベース内のすべてのfloatバッファは32バイト境界でアラインされている。

static void *amalloc(size_t n)
{
    n = (n + 31) & ~(size_t)31;
    if (n == 0) n = 32;
    return aligned_alloc(32, n);
}

丸め処理により、確保サイズは常に32の倍数になる。これはaligned_allocの要件だ。ゼロチェックは未定義動作の回避のため(サイズ0を拒否する実装が存在する)。

アライメントされたメモリにより、_mm256_loadu_psの代わりに_mm256_load_psが使える。最近のCPUではハードウェアが非アラインロードを効率的に処理するため差はほぼない。しかし、一部の古いマイクロアーキテクチャでは、キャッシュライン境界をまたぐ非アラインロードにペナルティが発生する。アライメントのコストは実質ゼロ(確保サイズを切り上げるだけ)なので、やらない理由がない。

ベンチマーク

ベンチマークの公正性確保には、実装そのものとほぼ同じだけの労力を費やした。不公平なベンチマークは、ベンチマークがないよりも悪い。コードが実際より速い、あるいは遅いという誤った認識を生む。

公平性に関する最大の懸念はスレッド数だった。PyTorchはデフォルトでIntel MKLを使用し、コア数を検出してスレッドを生成する。C実装はOpenMPを使い、OMP_NUM_THREADSで同様のことを行う。両方を揃えなければ、意味のない数値になる。すべての実行でtorch.set_num_threadsOMP_NUM_THREADSに合わせた。

二番目の懸念はウォームアップだ。両方の実装にコールドスタートのコストがある。PyTorchにはJITコンパイル、C実装にはキャッシュの温まりがある。計測の前に5回のウォームアップイテレーションを実行し、20回の計測イテレーションの結果を採用した。

三番目の懸念は、何を計測すべきかという点だ。C実装はnomic_embedが文字列を受け取ってfloat配列を返すため、トークン化も計測時間に含まれる。公平を期すため、HuggingFaceもトークン化込みで計測した。加えて、HuggingFaceの推論のみ(トークン化済み)の計測も別途行い、Python側のオーバーヘッドと実際の計算の割合が分かるようにした。

スレッド数を揃えた結果(6スレッド)を示す。

入力 トークン数 nomic.c (ms) HuggingFace (ms) 高速化率
短いクエリ 8 30 64 2.1倍
中程度のクエリ 11 42 76 1.8倍
15-19 51-52 77-89 1.5-1.7倍
短い段落 56 107 123 1.2倍
長い段落 101 137 173 1.3倍
1ページ分 211 255 303 1.2倍

傾向は明確だ。入力が短いほど高速化率が大きい。PythonとPyTorchのオーバーヘッドが全体の処理時間に占める割合が大きくなるためだ。短いクエリでは、PyTorchはフレームワークのオーバーヘッドに実際の行列積以上の時間を費やしている。C実装のオーバーヘッドは実質ゼロだ。関数を呼べば、トークン化してモデルを実行し、結果を返す。

入力が長くなるにつれて実際の計算が支配的になり、高速化率は1.2〜1.3倍に収束する。この差は、手動チューニングしたAVX2カーネルとMKLのGEMMの差だ。MKLは極めて優秀だが、汎用的でもある。自分のカーネルはこのモデルが使う次元に特化している。任意の行列サイズに対応する必要も、異なるデータ型をサポートする必要もない。

学んだこと

手書きSIMDは地道だが機械的な作業だ。命令セットを理解すれば、あとはスカラアルゴリズムをベクトル演算に写像するだけだ。Intel Intrinsics Guideが唯一必要なリファレンスだ。真の難しさは命令そのものではなく、メモリアクセスパターンにある。FLOPs数が少なくてもメモリにシーケンシャルにアクセスするカーネルは、最適な演算をしていてもメモリにランダムアクセスするカーネルに勝る。敵は命令数ではなく、キャッシュミスだ。

トークナイザが最もフラストレーションの溜まる部分だった。Unicode正規化には想像もしなかったコーナーケースがある。結合文字、サロゲートペア、複数のUnicodeブロックにまたがるCJK範囲。すべてのエッジケースが重要だ。トークンはモデルへの入力であり、下流にエラー訂正は存在しない。トークンが1つ違えば、埋め込みも間違う。信頼できるまでに、HuggingFaceのトークナイザとの照合を数百の入力で行った。

OpenMPはこの種のワークロードに対して過小評価されている。複雑なスレッドライブラリに手を伸ばす人が多いが、適切なループに#pragma omp parallel forを一行入れるだけで理論上の速度向上の80%が得られる。2次元タイリングの工夫は全コアを稼働させるために重要だったが、並列化そのものはコード上たった一行だ。

モデルフォーマットは可能な限り単純にすべきだ。重みはフラットなバイナリファイルとして格納している。語彙(30,522エントリ、長さプレフィックス付き文字列)が先頭にあり、その後にすべての重みテンソルがレイヤー順で連結される。メタデータヘッダも、バージョン番号も、圧縮もない。コンバータはHuggingFaceのsafetensorsを読み取ってバイト列を書き出す80行のPythonスクリプトだ。モデルの読み込みはfread一発。単純なフォーマットはデバッグしやすく、検証しやすく、読み込みが速い。

API

全体で約1,100行のC言語コードを単一ファイルに収めている。公開インターフェースは4つの関数だ。

nomic_ctx *nomic_load(const char *model_path);
void       nomic_free(nomic_ctx *ctx);
float     *nomic_embed(nomic_ctx *ctx, const char *text, int dim);
float      nomic_similarity(const float *a, const float *b, int dim);

モデルを読み込む。テキストを埋め込む。埋め込みを比較する。モデルを解放する。dimパラメータでMatryoshkaの切り詰め次元を制御する。768でフル精度、256で容量3分の1かつ品質劣化は最小限、64で最大速度だが精度低下は許容する必要がある。返される埋め込みは次元に関わらずすべてL2正規化済みだ。

#include "nomic.h"
#include <stdio.h>
#include <stdlib.h>

int main(void) {
    nomic_ctx *ctx = nomic_load("nomic.nomicmodel");

    float *a = nomic_embed(ctx, "search_query: What is deep learning?", 768);
    float *b = nomic_embed(ctx, "search_document: Deep learning uses neural networks.", 768);
    float *c = nomic_embed(ctx, "search_document: The recipe calls for flour.", 768);

    printf("relevant:   %.4f\n", nomic_similarity(a, b, 768));
    printf("irrelevant: %.4f\n", nomic_similarity(a, c, 768));

    free(a); free(b); free(c);
    nomic_free(ctx);
}

gcc -O2 -DUSE_AVX2 -mavx2 -mfma -fopenmp でコンパイルし、-lm でリンクすれば、自己完結型の埋め込みエンジンが手に入る。インストールすべき依存関係はない。Pythonの設定も不要。Dockerコンテナを引っ張ってくる必要もない。Cファイルとヘッダとモデルバイナリだけだ。

55件のテストが、トークナイザのエッジケースからHuggingFaceリファレンス実装との埋め込み一致検証まで、パイプライン全体をカバーしている。