nomic.c
2026-01-27
Pure C Inference for Nomic Embed Text v1.5
I wrote a text embedding model from scratch in C. No PyTorch, no ONNX, no BLAS libraries. Just C11 and Intel intrinsics. The result is a single-file implementation of Nomic Embed Text v1.5 that runs 1.2 to 2.1x faster than HuggingFace Transformers on CPU.
What follows is the reasoning behind every major decision, every optimization that worked, and a few that didn't.
Why Pure C
Embedding models are the backbone of semantic search, RAG pipelines, clustering, classification. Every time you type a query into a search bar backed by vectors, an embedding model runs somewhere. Usually that means PyTorch. Usually that means pulling in a 2GB runtime, a Python interpreter, and hoping your dependency graph doesn't break.
I wanted something different. I wanted to call one function from any C program, get back a float array, and be done. No runtime. No allocator hidden behind three abstraction layers. No dynamic dispatch. Just #include "nomic.h" and link with -lm.
Nomic Embed v1.5 was the right target. It's a 137M parameter BERT encoder. Small enough that the weights fit in 522MB of float32. Big enough that naive code is slow and you actually have to think about performance. It also supports Matryoshka dimensions, which means you can truncate embeddings to 64, 128, 256, or 512 floats and still get useful results. That's a practical feature worth supporting.
The Model
The architecture is BERT, but not vanilla BERT. Nomic made several changes that affect the implementation.
Rotary Position Embeddings replace learned absolute position embeddings. Instead of adding a position vector to the token embeddings, RoPE rotates pairs of dimensions by an angle proportional to the position. The rotation frequencies use Dynamic NTK scaling, which adjusts the base frequency when the sequence length exceeds the training length. This means I need to compute sin/cos tables dynamically based on actual sequence length, not just look them up from a static table.
The feed-forward network uses SwiGLU instead of GELU. SwiGLU splits the intermediate representation into two halves, applies SiLU (sigmoid linear unit) to one half, and multiplies elementwise. This means the FFN projection goes from 768 to 3072 twice (gate and value), rather than once. More parameters, more FLOPs, but better quality per parameter.
The QKV projection is fused. Instead of three separate 768x768 linear layers for Q, K, and V, there's a single 768x2304 projection. No bias terms anywhere. This is a minor implementation detail but it simplifies the code.
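The sketch below shows how a fused row can be sliced. It assumes the 2304 outputs per token are laid out as [Q | K | V], with head h owning dims h*64 to h*64+63 in each block. The layout and the helper names are assumptions. Still, this is the kind of indexing the per-head gather step needs.

```c
#include <stddef.h>

#define HIDDEN   768
#define HEAD_DIM 64

/* Hypothetical layout: each token's fused output row is [Q | K | V],
 * 768 floats each, 2304 in total. */
enum { QKV_Q = 0, QKV_K = 1, QKV_V = 2 };

/* Start of the HEAD_DIM-wide slice of Q, K, or V for token s, head h. */
static const float *qkv_slice(const float *qkv, int s, int which, int h)
{
    return qkv + (size_t)s * 3 * HIDDEN + (size_t)which * HIDDEN
               + (size_t)h * HEAD_DIM;
}

/* Gather one head's Q, K, or V into a contiguous [seq][HEAD_DIM] buffer. */
static void gather_head(const float *qkv, int seq, int which, int h, float *dst)
{
    for (int s = 0; s < seq; s++) {
        const float *src = qkv_slice(qkv, s, which, h);
        for (int d = 0; d < HEAD_DIM; d++)
            dst[(size_t)s * HEAD_DIM + d] = src[d];
    }
}
```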
Post-norm residual connections instead of pre-norm. The layer norm comes after the residual add, not before the attention or FFN blocks. This changes the forward pass order slightly.
The final embedding pipeline is mean pooling over the sequence, followed by a final layer norm, truncation to the desired Matryoshka dimension, and L2 normalization. Every embedding that comes out of nomic_embed has unit norm, ready for cosine similarity via dot product.
The Tokenizer
The tokenizer was the hardest part of the project. Not because tokenization is conceptually difficult, but because BERT tokenization has accumulated years of edge cases that all have to match exactly.
The pipeline starts with Unicode normalization. Specifically NFD normalization, which decomposes accented characters into base character plus combining mark. Then accent stripping removes the combining marks entirely. So "café" becomes "cafe". This is important because the vocabulary was built this way, and if your tokenizer doesn't match, you get different token IDs, and different token IDs produce completely wrong embeddings.
After normalization, CJK characters get spaces inserted around them. This is a BERT-specific rule that treats every CJK unified ideograph as its own word. Then everything is lowercased, and the text is split on whitespace and punctuation boundaries simultaneously.
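Here is a sketch of the CJK spacing rule. It operates on text already decoded from UTF-8 into code points, which is an assumption about the intermediate representation. The ranges are the ones the original BERT tokenizer checks in its _is_chinese_char helper. They cover CJK ideographs only, not Japanese kana or Hangul.

```c
#include <stddef.h>
#include <stdint.h>

/* CJK unified ideograph ranges, as checked by BERT's tokenizer. */
static int is_cjk(uint32_t cp)
{
    return (cp >= 0x4E00  && cp <= 0x9FFF)  || (cp >= 0x3400  && cp <= 0x4DBF)  ||
           (cp >= 0x20000 && cp <= 0x2A6DF) || (cp >= 0x2A700 && cp <= 0x2B73F) ||
           (cp >= 0x2B740 && cp <= 0x2B81F) || (cp >= 0x2B820 && cp <= 0x2CEAF) ||
           (cp >= 0xF900  && cp <= 0xFAFF)  || (cp >= 0x2F800 && cp <= 0x2FA1F);
}

/* Insert a space before and after every CJK code point.
 * out must hold up to 3 * n code points; returns the output length. */
static size_t pad_cjk(const uint32_t *in, size_t n, uint32_t *out)
{
    size_t k = 0;
    for (size_t i = 0; i < n; i++) {
        if (is_cjk(in[i])) {
            out[k++] = ' ';
            out[k++] = in[i];
            out[k++] = ' ';
        } else {
            out[k++] = in[i];
        }
    }
    return k;
}
```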
Each resulting word goes through WordPiece splitting. The algorithm tries to match the longest prefix against the vocabulary. If the whole word matches, great. If not, take the longest matching prefix, emit it, and continue with the remainder, whose candidate pieces are looked up with a ## prefix. If at any point no prefix matches, the entire word becomes a single [UNK], even if earlier pieces already matched.
The vocabulary has 30522 entries. I store them as a sorted array and use binary search for lookups. I tried a hash table first, but the sorted array was actually faster for this vocabulary size because of cache locality. Each lookup hits a contiguous region of memory, while the hash table scattered accesses across the heap.
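Here is a minimal sketch of greedy longest-match-first WordPiece over a sorted (token, id) array. It uses the C standard library's bsearch for the binary search, and the type and function names are hypothetical. The array stores ids alongside strings because sorting scrambles the vocabulary's id order. The sketch works on bytes. A candidate that ends mid-character can never equal a valid UTF-8 vocabulary entry, so such a candidate only costs a wasted lookup. The 100-byte word limit approximates the HuggingFace BERT tokenizer's 100-character limit.

```c
#include <stdlib.h>
#include <string.h>

typedef struct { const char *tok; int id; } vocab_entry;

static int cmp_entry(const void *a, const void *b)
{
    return strcmp(((const vocab_entry *)a)->tok, ((const vocab_entry *)b)->tok);
}

/* Binary search in a vocab sorted by strcmp; returns the id or -1. */
static int vocab_find(const vocab_entry *v, size_t n, const char *s)
{
    vocab_entry key = { s, 0 };
    const vocab_entry *e = bsearch(&key, v, n, sizeof *v, cmp_entry);
    return e ? e->id : -1;
}

/* Greedy longest-match-first WordPiece for one normalized word.
 * out needs room for strlen(word) ids (at least 1). Returns the count.
 * If any piece fails to match, the whole word becomes one [UNK]. */
static int wordpiece(const vocab_entry *v, size_t n, const char *word,
                     int unk_id, int *out)
{
    char buf[128];
    size_t len = strlen(word);
    int count = 0;
    if (len > 100)                      /* overly long word -> [UNK] */
        goto unk;
    for (size_t start = 0; start < len; ) {
        int found = -1;
        size_t end = len;
        for (; end > start; end--) {    /* try the longest piece first */
            size_t p = 0;
            if (start > 0) { buf[p++] = '#'; buf[p++] = '#'; }
            memcpy(buf + p, word + start, end - start);
            buf[p + end - start] = '\0';
            if ((found = vocab_find(v, n, buf)) >= 0)
                break;
        }
        if (found < 0)
            goto unk;
        out[count++] = found;
        start = end;
    }
    return count;
unk:
    out[0] = unk_id;
    return 1;
}
```

Take a word like "unplayable" and a vocabulary that has "un" but no piece starting with "##p". The whole word maps to [UNK], not to "un" followed by [UNK].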
Getting tokenization exactly right took more time than the entire SIMD implementation. I wrote over 20 tokenizer-specific tests, covering things like mixed CJK and Latin text, strings with multiple consecutive punctuation marks, words that decompose into many subword tokens, and the empty string. One wrong token and the embedding is garbage. There's no graceful degradation.
SIMD Strategy
All the compute-heavy operations in the model are either matrix multiplies or element-wise vector operations. Both map naturally to SIMD.
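I targeted AVX2 with FMA because it's available on nearly every mainstream x86-64 CPU from the last decade (Intel since Haswell, AMD since Excavator). Each __m256 register holds 8 float32 values. Every hidden dimension in the model (768, 2304, 3072, and the 64-wide head slices) is divisible by 8. Loops over those dimensions therefore need no tail handling: every iteration processes exactly 8 floats, with no remainder and no masking. Loops over the sequence length, which can be any value, are the exception.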
The three utility functions that everything else builds on:
hsum_avx reduces 8 floats in a __m256 to a single scalar. It uses two _mm256_hadd_ps operations followed by extracting and adding the high and low 128-bit lanes. This is the most common operation in the codebase because every dot product ends with a horizontal sum.
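Below is a sketch of hsum_avx following that description, plus a small wrapper. It uses the GCC/Clang target attribute so it compiles without -mavx2. That is an assumption about the build, not necessarily how nomic.c does it. Callers should check __builtin_cpu_supports before calling. sum8 is a hypothetical helper.

```c
#include <immintrin.h>

/* Two hadds, then add the high 128-bit lane to the low one. */
__attribute__((target("avx2,fma")))
static float hsum_avx(__m256 v)
{
    v = _mm256_hadd_ps(v, v);   /* pairwise sums within each 128-bit lane */
    v = _mm256_hadd_ps(v, v);   /* each lane now holds its 4-element sum */
    __m128 lo = _mm256_castps256_ps128(v);
    __m128 hi = _mm256_extractf128_ps(v, 1);
    return _mm_cvtss_f32(_mm_add_ss(lo, hi));
}

/* Sum of 8 consecutive floats (unaligned pointer allowed). */
__attribute__((target("avx2,fma")))
static float sum8(const float *p)
{
    return hsum_avx(_mm256_loadu_ps(p));
}
```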
exp256_approx computes a vectorized approximation of exp(x) for 8 values simultaneously. The algorithm rewrites exp(x) as 2^(x/ln2) and splits the exponent into two parts. The integer part is written directly into the IEEE 754 exponent field with a bit shift. The fractional part is approximated by a degree-4 polynomial evaluated with Horner's scheme. The result has about 20 bits of accuracy. Softmax outputs are ratios of exponentials, so a small relative error that is similar across terms largely cancels in the normalization. For SiLU, a relative error around one part in a million is negligible. In both cases 20 bits is more than enough.
rcp_nr computes 1/x using _mm256_rcp_ps (which gives 12-bit accuracy) followed by one Newton-Raphson refinement step that doubles the precision. I use this in the SiLU sigmoid computation to avoid _mm256_div_ps, which is 3-4x slower than multiply on most microarchitectures.
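Here is a sketch of rcp_nr as described: a hardware estimate, then one refinement r' = r(2 - xr). The target attribute and the recip_array wrapper are assumptions made to keep the example self-contained. As with any AVX2 code, check CPU support before calling.

```c
#include <immintrin.h>

/* ~12-bit estimate refined by one Newton-Raphson step. */
__attribute__((target("avx2,fma")))
static __m256 rcp_nr(__m256 x)
{
    __m256 r = _mm256_rcp_ps(x);
    __m256 t = _mm256_fnmadd_ps(x, r, _mm256_set1_ps(2.0f)); /* 2 - x*r */
    return _mm256_mul_ps(r, t);
}

/* out[i] = 1 / in[i]; n must be a multiple of 8. */
__attribute__((target("avx2,fma")))
static void recip_array(const float *in, float *out, int n)
{
    for (int i = 0; i < n; i += 8)
        _mm256_storeu_ps(out + i, rcp_nr(_mm256_loadu_ps(in + i)));
}
```

In a sigmoid, the same pattern computes 1 / (1 + exp(-x)) with multiplies and an FMA instead of a division instruction.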
The entire SIMD layer compiles away to nothing when USE_AVX2 isn't defined. Every kernel has a scalar fallback that uses plain C loops. This means the code builds and runs correctly on ARM, RISC-V, or any other architecture. Just slower.
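The fallback pattern looks like this in miniature: one function, two bodies selected by the preprocessor. The dot helper is illustrative, not taken from nomic.c, and its AVX2 path assumes n is a multiple of 8.

```c
#ifdef USE_AVX2
#include <immintrin.h>
#endif

/* Dot product of two float arrays of length n. */
static float dot(const float *a, const float *b, int n)
{
#ifdef USE_AVX2
    /* Built with -DUSE_AVX2 -mavx2 -mfma: 8 lanes per FMA. */
    __m256 acc = _mm256_setzero_ps();
    for (int i = 0; i < n; i += 8)
        acc = _mm256_fmadd_ps(_mm256_loadu_ps(a + i), _mm256_loadu_ps(b + i), acc);
    __m128 s = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
    s = _mm_hadd_ps(s, s);
    s = _mm_hadd_ps(s, s);
    return _mm_cvtss_f32(s);
#else
    /* Portable scalar fallback: compiles on any architecture. */
    float acc = 0.0f;
    for (int i = 0; i < n; i++)
        acc += a[i] * b[i];
    return acc;
#endif
}
```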
The GEMM Micro-Kernel
The single most important function in the entire codebase is linear_no_bias. It does the matrix multiply for every linear projection in every transformer layer: the fused QKV projection, the attention output projection, the FFN's gate and value projections, and the FFN down-projection. Twelve layers, five multiplies each. If this function is slow, everything is slow.
My first implementation was the naive triple loop. It was correct and painfully slow.
My second implementation processed one output neuron at a time with an AVX2 dot product. Better, but still leaving performance on the table. The bottleneck was memory bandwidth. Each output neuron loads one row of the weight matrix, uses it once, and discards it. The arithmetic intensity is too low.
The third implementation is what shipped. It's a 2S x 4O micro-kernel that processes 2 sequence positions and 4 output neurons simultaneously. The inner loop loads 8 consecutive weights from each of the 4 output rows and 8 consecutive input values from each of the 2 sequence rows. It then performs 8 FMA operations into 8 accumulator registers, one per (row, output) pair.
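```c
// Inner loop of the 2S x 4O micro-kernel. row0/row1 point at two input rows
// (sequence positions), W0..W3 at four weight rows (output neurons). in is a
// multiple of 8, and every pointer must be 32-byte aligned for _mm256_load_ps.
__m256 a00 = _mm256_setzero_ps(), a01 = a00, a02 = a00, a03 = a00;
__m256 a10 = _mm256_setzero_ps(), a11 = a10, a12 = a10, a13 = a10;
for (int k = 0; k < in; k += 8) {
    __m256 x0 = _mm256_load_ps(row0 + k);   // 8 consecutive inputs, row 0
    __m256 x1 = _mm256_load_ps(row1 + k);   // 8 consecutive inputs, row 1
    __m256 w0 = _mm256_load_ps(W0 + k);     // each weight vector is reused
    __m256 w1 = _mm256_load_ps(W1 + k);     // for both sequence rows
    __m256 w2 = _mm256_load_ps(W2 + k);
    __m256 w3 = _mm256_load_ps(W3 + k);
    a00 = _mm256_fmadd_ps(x0, w0, a00);
    a01 = _mm256_fmadd_ps(x0, w1, a01);
    a02 = _mm256_fmadd_ps(x0, w2, a02);
    a03 = _mm256_fmadd_ps(x0, w3, a03);
    a10 = _mm256_fmadd_ps(x1, w0, a10);
    a11 = _mm256_fmadd_ps(x1, w1, a11);
    a12 = _mm256_fmadd_ps(x1, w2, a12);
    a13 = _mm256_fmadd_ps(x1, w3, a13);
}
// Each accumulator now holds 8 partial sums; hsum_avx reduces each one to a
// single output value.
```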
That's 8 accumulators, 2 input vectors, and 4 weight vectors live per iteration, which uses 14 of the 16 YMM registers. The two left over give the compiler room for temporaries without spilling anything to the stack. The register pressure is tight, but it fits.
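The key insight is register reuse. Each weight vector is used for both sequence rows, and each input vector is used for all four output rows. The kernel therefore issues 8 FMAs per 6 loads, where a one-row, one-output dot product issues 1 FMA per 2 loads. Weight traffic is halved, arithmetic intensity relative to weight loads doubles, and the kernel shifts from memory-bound toward compute-bound.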
I also added software prefetching on the weight rows, 16 floats (64 bytes, one cache line) ahead of the current position:
```c
_mm_prefetch((const char *)(W0 + k + 16), _MM_HINT_T0);
_mm_prefetch((const char *)(W1 + k + 16), _MM_HINT_T0);
_mm_prefetch((const char *)(W2 + k + 16), _MM_HINT_T0);
_mm_prefetch((const char *)(W3 + k + 16), _MM_HINT_T0);
```
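This tells the CPU to start fetching the next cache line before we need it. Prefetch hints never fault, so running past the end of a row is harmless. The benefit is modest on modern CPUs with good hardware prefetchers, but it costs only four instructions per iteration and helps on older hardware.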
Parallelism
A single-threaded implementation is simple but leaves most of the CPU idle. The model has plenty of data parallelism to exploit. I used OpenMP because it's dead simple and GCC supports it natively.
The first place I added parallelism was the GEMM kernel. The naive approach is to parallelize over sequence positions (in practice over pairs of them, since the micro-kernel consumes two rows at a time). That caps the available parallelism at seq/2 work items. A 10-token query produces a 12x768 matrix after adding CLS and SEP, which is only 6 row pairs. With 6 threads, each thread gets exactly one pair and there is no slack. A shorter query, or a machine with more cores, leaves threads with nothing to do.
I restructured the parallelism as 2D tiling. Flatten the work into a single index over (seq_pairs x output_tiles) and let OpenMP distribute it evenly:
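```c
int pairs = seq / 2;         // if seq is odd, the last row is left over and needs separate handling
int otiles = out / 4;        // out is 768, 2304, or 3072: always divisible by 4
int ntiles = pairs * otiles;
#pragma omp parallel for schedule(static) if(ntiles > 16)
for (int t = 0; t < ntiles; t++) {
    int p = t / otiles;      // which pair of sequence rows
    int ot = t % otiles;     // which group of 4 output neurons
    // 2Sx4O micro-kernel for rows (p*2, p*2+1) and outputs (ot*4 .. ot*4+3)
}
```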
The if(ntiles > 16) guard is important. Every OpenMP parallel region pays a fixed fork/join cost: waking the worker threads and synchronizing at the implicit barrier at the end of the loop. For tiny inputs with 16 tiles or fewer, that cost outweighs the work, and single-threaded execution is faster. This threshold was determined empirically.
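Here is a portable scalar sketch of the whole tiled projection. It assumes row-major X [seq][in], W stored one row per output neuron as in the excerpts, and out divisible by 4. The function name linear_no_bias_ref is hypothetical. It keeps the flattened tile loop and the 2-row by 4-output reuse pattern in plain C instead of intrinsics. It also handles the trailing row when seq is odd, a case the excerpts don't show.

```c
#include <stddef.h>

/* Y[s][o] = sum_k X[s][k] * W[o][k], no bias. */
static void linear_no_bias_ref(const float *X, const float *W, float *Y,
                               int seq, int in, int out)
{
    int pairs = seq / 2, otiles = out / 4, ntiles = pairs * otiles;

    /* Flattened 2D tile index; without -fopenmp the pragma is ignored. */
    #pragma omp parallel for schedule(static) if(ntiles > 16)
    for (int t = 0; t < ntiles; t++) {
        int p = t / otiles, ot = t % otiles;
        const float *x0 = X + (size_t)(2 * p) * in, *x1 = x0 + in;
        float acc[2][4] = {{0}};
        for (int k = 0; k < in; k++) {
            for (int j = 0; j < 4; j++) {
                float w = W[(size_t)(ot * 4 + j) * in + k]; /* loaded once ... */
                acc[0][j] += x0[k] * w;                     /* ... used for row 0 */
                acc[1][j] += x1[k] * w;                     /* ... and row 1 */
            }
        }
        for (int j = 0; j < 4; j++) {
            Y[(size_t)(2 * p) * out + ot * 4 + j]     = acc[0][j];
            Y[(size_t)(2 * p + 1) * out + ot * 4 + j] = acc[1][j];
        }
    }

    /* Odd seq: the last row has no partner; compute it one output at a time. */
    if (seq % 2) {
        const float *x = X + (size_t)(seq - 1) * in;
        for (int o = 0; o < out; o++) {
            float s = 0.0f;
            for (int k = 0; k < in; k++)
                s += x[k] * W[(size_t)o * in + k];
            Y[(size_t)(seq - 1) * out + o] = s;
        }
    }
}
```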
The second place I added parallelism was the attention heads. The model has 12 attention heads that are completely independent. Each head operates on a 64-dimensional slice of Q, K, and V. The original implementation processed all 12 sequentially. Making them parallel was the single biggest speedup for long sequences.
```c
#pragma omp parallel for schedule(static) if(seq >= 4)
for (int h = 0; h < num_heads; h++) {
    float *local_q = amalloc(seq * head_dim * sizeof(float));
    float *local_k = amalloc(seq * head_dim * sizeof(float));
    float *local_v = amalloc(seq * head_dim * sizeof(float));
    float *local_sc = amalloc(seq * seq * sizeof(float));
    // gather, K-transpose, Q*K^T, softmax, scores*V, scatter
    free(local_q);
    free(local_k);
    free(local_v);
    free(local_sc);
}
```
Each thread gets its own scratch buffers, allocated based on the actual sequence length. I initially tried sharing pre-allocated buffers, but that meant either serializing access or allocating for the maximum sequence length of 8192 tokens. At max length, the attention score matrix alone is 8192 * 8192 * 4 = 256MB per head. Allocating 12 of those upfront would consume 3GB just for scratch space. Dynamic allocation based on actual sequence length keeps memory usage proportional to input size.
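The arithmetic behind that trade-off, per head, for the four scratch buffers in the loop (head_dim = 64, float32, sequence length L):

```latex
\begin{aligned}
\text{scratch}(L) &= \underbrace{3 \cdot L \cdot 64 \cdot 4}_{Q,\,K,\,V\ \text{slices}} + \underbrace{L^2 \cdot 4}_{\text{scores}}\ \text{bytes} \\
\text{scratch}(100) &= 76{,}800 + 40{,}000 = 116{,}800\ \text{B} \approx 114\ \text{KiB} \\
\left. L^2 \cdot 4 \right|_{L = 8192} &= 268{,}435{,}456\ \text{B} = 256\ \text{MiB}, \qquad 12 \times 256\ \text{MiB} = 3\ \text{GiB}
\end{aligned}
```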
The K-Transpose Trick
Computing Q times K-transpose is the attention bottleneck. The naive approach computes each score as a dot product between a query row and a key row. With AVX2, that means loading 8 floats from Q, 8 from K, multiplying, and then doing a horizontal sum to reduce 8 partial products to one scalar. The horizontal sum is the expensive part. It takes multiple shuffle and add instructions.
I transpose K before the multiplication. Instead of storing K as [seq][head_dim], I rearrange it to [head_dim][seq]. Now computing one row of scores is a standard matrix-vector multiply. For each dimension d, broadcast Q[i][d], load 8 consecutive values from row d of the transposed K (the d-th component of 8 different keys), and accumulate with an FMA. Each of the 8 lanes accumulates a different score, so no horizontal sum is needed at all.
Without the transpose, this broadcast formulation would need the d-th element of 8 different K rows, a strided access that touches one element per row. With it, those values sit contiguously in memory. The CPU prefetcher handles that sequential pattern much better than strided access. On a 100-token sequence, this optimization alone cut attention time by about 30%.
Memory Alignment
Every float buffer in the codebase is allocated with 32-byte alignment:
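```c
#include <stdlib.h>   // aligned_alloc (C11), free

// Allocate at least n bytes aligned to 32, the width of a YMM register.
static void *amalloc(size_t n)
{
    n = (n + 31) & ~(size_t)31;   // round up to a multiple of 32
    if (n == 0) n = 32;           // never request a zero-size block
    return aligned_alloc(32, n);  // release with free()
}
```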
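The rounding makes the size a multiple of 32, which C11's aligned_alloc requires (later revisions of the standard relaxed this, but rounding up costs nothing). The zero check avoids implementation-defined behavior. For a size of 0, aligned_alloc may return either NULL or a unique pointer, and a NULL would be indistinguishable from an allocation failure.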
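Aligned memory lets me use _mm256_load_ps instead of _mm256_loadu_ps. load_ps faults on an address that isn't 32-byte aligned, so the alignment is what makes it legal. On modern CPUs the difference is negligible, because the hardware handles unaligned loads efficiently. The remaining cost is a load that splits across two cache lines, which a 32-byte-aligned load never does. That penalty is small on recent microarchitectures and larger on some older ones. Since alignment is free (just round up the allocation), there's no reason not to do it.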
Benchmarks
I spent almost as much time getting the benchmarks right as I did on the implementation. Unfair benchmarks are worse than no benchmarks. They mislead you into thinking your code is faster or slower than it actually is.
The main fairness concern was thread count. PyTorch uses Intel MKL by default, which detects the number of cores and spawns threads accordingly. My C implementation uses OpenMP, which typically defaults to one thread per available core and can be overridden with OMP_NUM_THREADS. If you don't control both, you get meaningless numbers. I set torch.set_num_threads to match OMP_NUM_THREADS for every run.
The second concern was warmup. Both implementations have cold-start costs. PyTorch does one-time setup on its first calls, such as thread pool startup and memory allocator warmup. My code starts with cold caches. I run 5 warmup iterations before timing 20 measured iterations.
The third concern was measuring the right thing. The C implementation includes tokenization in its timing because nomic_embed takes a string and returns a float array. To be fair, I measured HuggingFace with tokenization included too. I also measured HuggingFace inference-only (pre-tokenized) separately, so you can see how much of the time is Python overhead versus actual compute.
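Here is a minimal timing harness in that spirit: untimed warmup calls, then the mean over timed calls. bench_ms and count_calls are hypothetical, and count_calls is only a trivial stand-in workload. The sketch uses C11's timespec_get for portability. On POSIX systems, clock_gettime with CLOCK_MONOTONIC is the better choice because it never jumps.

```c
#include <time.h>

/* Call fn(arg) `warmup` times untimed, then `iters` times timed;
 * return the mean wall-clock milliseconds per timed call. */
static double bench_ms(void (*fn)(void *), void *arg, int warmup, int iters)
{
    for (int i = 0; i < warmup; i++)
        fn(arg);                        /* absorb cold caches and lazy setup */

    struct timespec t0, t1;
    timespec_get(&t0, TIME_UTC);
    for (int i = 0; i < iters; i++)
        fn(arg);
    timespec_get(&t1, TIME_UTC);

    double ms = (t1.tv_sec - t0.tv_sec) * 1e3 + (t1.tv_nsec - t0.tv_nsec) / 1e6;
    return ms / iters;
}

/* Trivial stand-in workload: counts how many times it ran. */
static void count_calls(void *arg)
{
    ++*(int *)arg;
}
```

For the C side of the comparison, fn would wrap one nomic_embed call plus the matching free. OMP_NUM_THREADS would be set to the same value passed to torch.set_num_threads.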
Here are the results at matched thread count (6 threads):
| Input | Tokens | nomic.c (ms) | HuggingFace (ms) | Speedup |
|---|---|---|---|---|
| Short query | 8 | 30 | 64 | 2.1x |
| Medium query | 11 | 42 | 76 | 1.8x |
| Sentence | 15-19 | 51-52 | 77-89 | 1.5-1.7x |
| Short paragraph | 56 | 107 | 123 | 1.2x |
| Long paragraph | 101 | 137 | 173 | 1.3x |
| Full page | 211 | 255 | 303 | 1.2x |
The pattern is clear. Short inputs show the biggest speedup because Python and PyTorch overhead is a larger fraction of the total time. For a short query, PyTorch spends more time in framework overhead than in actual matrix multiplication. The C implementation has essentially zero overhead. You call the function, it tokenizes, runs the model, and returns.
As inputs get longer, the actual compute dominates and the speedup converges to 1.2-1.3x. That remaining gap is the difference between hand-tuned AVX2 kernels and MKL's GEMM. MKL is extremely good, but it's also general-purpose. My kernels are specialized for the exact dimensions this model uses. I don't handle arbitrary matrix sizes, I don't support different data types, and I don't need to.
What I Learned
Writing SIMD by hand is tedious but mechanical. Once you understand the instruction set, it's just a matter of mapping the scalar algorithm to vector operations. The Intel Intrinsics Guide is the only reference you need. The real difficulty is not the instructions themselves but the memory access patterns. A kernel that does more FLOPs but accesses memory sequentially will often beat a kernel that does the minimum arithmetic but touches memory randomly. Cache misses are the enemy, not instruction count.
The tokenizer was the most frustrating part. Unicode normalization has corner cases I never imagined. Combining characters, surrogate pairs, CJK ranges that span multiple Unicode blocks. Every edge case matters because tokens are the input to the model and there's no error correction downstream. If you get one token wrong, the embedding is wrong. I verified against HuggingFace's tokenizer on hundreds of inputs before I trusted it.
OpenMP is underrated for this kind of workload. People reach for complex threading libraries when a simple #pragma omp parallel for on the right loop gives 80% of the theoretical speedup. The 2D tiling trick was important for keeping all cores busy, but the actual parallelism is just one line of code.
The model format should be as dumb as possible. I store weights as a flat binary file. The vocabulary goes first (30522 entries, length-prefixed strings), then every weight tensor concatenated in layer order. No metadata headers, no version numbers, no compression. The converter is 80 lines of Python that reads HuggingFace safetensors and writes bytes. Loading the model is a single fread. Simple formats are easy to debug, easy to verify, and fast to load.
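Here is a sketch of walking such a file after a single fread into memory. The exact string encoding isn't specified above, so this assumes a hypothetical 4-byte little-endian length prefix per vocabulary entry. parse_vocab and its signature are illustrative only.

```c
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Walk n_vocab length-prefixed entries at the start of buf, storing a
 * pointer and length per token (not NUL-terminated). Returns the byte
 * offset where the weights begin, or SIZE_MAX if buf is truncated. */
static size_t parse_vocab(const unsigned char *buf, size_t size, int n_vocab,
                          const char **tok, uint32_t *len)
{
    size_t off = 0;
    for (int i = 0; i < n_vocab; i++) {
        if (size - off < 4) return SIZE_MAX;
        uint32_t n = (uint32_t)buf[off] | (uint32_t)buf[off + 1] << 8 |
                     (uint32_t)buf[off + 2] << 16 | (uint32_t)buf[off + 3] << 24;
        off += 4;
        if (size - off < n) return SIZE_MAX;
        tok[i] = (const char *)buf + off;
        len[i] = n;
        off += n;
    }
    return off;
}
```

In this sketch the weights begin at the returned offset. After variable-length strings, that offset is generally not 32-byte aligned. The float data would have to be copied into an amalloc'd buffer, or the format padded, before any _mm256_load_ps touches it.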
The API
The whole thing is about 1100 lines of C in a single file. The public interface is four functions:
```c
nomic_ctx *nomic_load(const char *model_path);
void nomic_free(nomic_ctx *ctx);
float *nomic_embed(nomic_ctx *ctx, const char *text, int dim);
float nomic_similarity(const float *a, const float *b, int dim);
```
Load a model. Embed text. Compare embeddings. Free the model. The dim parameter controls Matryoshka truncation. Pass 768 for the full embedding or 256 for 3x less storage with minimal quality loss. Pass 64 for the smallest index and fastest similarity comparisons, if you can tolerate some accuracy degradation. Truncation happens after pooling, so the forward pass costs the same at every dimension. All returned embeddings are L2-normalized regardless of dimension.
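```c
#include "nomic.h"
#include <stdio.h>
#include <stdlib.h>

int main(void) {
    nomic_ctx *ctx = nomic_load("nomic.nomicmodel");
    if (!ctx) {  /* assumes nomic_load returns NULL on failure */
        fprintf(stderr, "failed to load model\n");
        return 1;
    }
    /* Nomic Embed expects task prefixes such as search_query: and search_document: */
    float *a = nomic_embed(ctx, "search_query: What is deep learning?", 768);
    float *b = nomic_embed(ctx, "search_document: Deep learning uses neural networks.", 768);
    float *c = nomic_embed(ctx, "search_document: The recipe calls for flour.", 768);
    printf("relevant: %.4f\n", nomic_similarity(a, b, 768));
    printf("irrelevant: %.4f\n", nomic_similarity(a, c, 768));
    free(a); free(b); free(c);  /* embeddings are heap arrays owned by the caller */
    nomic_free(ctx);
    return 0;
}
```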
Compile with gcc -O2 -DUSE_AVX2 -mavx2 -mfma -fopenmp, link -lm, and you have a self-contained embedding engine. No dependencies to install. No Python to configure. No Docker containers to pull. Just a C file, a header, and a model binary.
55 tests cover the full pipeline from tokenizer edge cases to end-to-end embedding verification against the HuggingFace reference implementation.