
[Feat] Expose per-step logit values through the model-run API #30

Description

@dlwlzzero

Background / Problem

The public inference API returns only decoded text to callers: either a full string (blocking) or per-token text deltas (streaming). Raw logits are computed internally but never surfaced:

  • Blocking run returns text only: api/quick_dot_ai_api.cpp:1401 (h.last_output = model.getOutput(0)).
  • Streaming run emits UTF-8 token deltas via CausalLmTokenCallback (api/quick_dot_ai_api.h:300).
  • Logits are produced by the LM head (nntrainer/Applications/CausalLM/models/causal_lm.cpp:212-226) and returned from incremental_inference() (causal_lm.cpp:582,619), but inside generate()
    (causal_lm.cpp:300-353) they are immediately consumed by sampling (argmax / temperature / top-k / top-p / softmax) and only the selected token ID is kept. The logit array itself is
    discarded.
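The loss described above can be illustrated with a simplified, hypothetical sampler. This is a minimal sketch, not the causal_lm.cpp code: `sampleInPlace`, `decodeStep`, and the `raw_out` out-buffer are illustrative names, and temperature scaling stands in for the full penalty/temperature/top-k/top-p chain. It shows that once the buffer has been mutated and only the ID is returned, the raw values are gone unless they are copied first.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical stand-in for the sampling step inside generate():
// scales the logits in place by 1/temperature, then returns the argmax id.
// After this call the buffer no longer holds the raw logits.
inline std::size_t sampleInPlace(std::vector<float> &logits, float temperature) {
  assert(!logits.empty() && temperature > 0.0f);
  for (float &x : logits) {
    x /= temperature;  // in-place mutation: raw values are overwritten
  }
  auto best = std::max_element(logits.begin(), logits.end());
  return static_cast<std::size_t>(std::distance(logits.begin(), best));
}

// One decode step that optionally snapshots the raw logits before sampling.
// `raw_out` is a hypothetical out-buffer; passing nullptr keeps today's behavior.
inline std::size_t decodeStep(std::vector<float> &logits, float temperature,
                              std::vector<float> *raw_out) {
  if (raw_out != nullptr) {
    *raw_out = logits;  // copy taken before any transform runs
  }
  return sampleInPlace(logits, temperature);
}
```

The snapshot costs one vocabulary-sized copy per step, which is why any capture path should be opt-in.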

There is currently no way for an API consumer to read raw logit values (for evaluation, debugging, custom sampling, distillation, confidence/uncertainty estimation, constrained-decoding research, etc.).

Goal

Add an API path that surfaces per-generation-step logits to the caller, without breaking existing text-only consumers.

Proposed scope

  • Define what is exposed per step: full vocab logits, or top-k (id, logit) pairs. Default to top-k to bound payload size; allow full-vocab opt-in.
  • Decide the capture point: pre-sampling raw logits (most useful) vs. post-penalty/post-temperature values. Recommendation: expose raw pre-penalty logits and document the order of
    transforms in generate().
  • Plumb logits out of generate()/registerOutputs() (causal_lm.cpp:231-353) via a new optional callback or out-buffer, gated by a flag so the default hot path is unaffected.
  • Add a C API entry (e.g. runModelHandleWithLogitsStreaming(...) alongside runModelHandleWithMessagesStreaming in api/quick_dot_ai_api.h) and a JNI/Kotlin wrapper
    (NativeChatSession.kt).
  • Document QNN-backend behavior: confirm whether the QNN graph output can return logits, whether those logits are quantized, and whether they must be dequantized before being exposed.
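The capture path proposed above can be sketched as follows. This is a minimal C++ sketch under stated assumptions: `LogitEntry`, `LogitStep`, `LogitsCaptureOptions`, `topK`, and `maybeEmitLogits` are hypothetical names, not existing nntrainer or quick_dot_ai APIs, and the hook is assumed to run inside generate() after sampling, receiving a copy of the raw logits taken before any penalty or temperature transform.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <numeric>
#include <vector>

// Hypothetical per-step payload: top-k (id, logit) pairs, plus the full
// vocabulary row only when the caller opts in.
struct LogitEntry {
  int32_t token_id;
  float logit;
};

struct LogitStep {
  std::size_t step;               // index of the generated token
  int32_t sampled_id;             // token actually emitted at this step
  std::vector<LogitEntry> top_k;  // sorted by descending logit
  std::vector<float> full_vocab;  // empty unless a full dump was requested
};

using LogitsCallback = std::function<void(const LogitStep &)>;

// Hypothetical capture options; a default-constructed value disables capture,
// so the existing hot path does no extra work.
struct LogitsCaptureOptions {
  bool enabled = false;
  std::size_t top_k = 20;
  bool full_vocab = false;
  LogitsCallback on_step;
};

// Returns the k largest logits as (id, logit) pairs, highest first.
inline std::vector<LogitEntry> topK(const std::vector<float> &logits, std::size_t k) {
  assert(logits.size() <= static_cast<std::size_t>(std::numeric_limits<int32_t>::max()));
  k = std::min(k, logits.size());
  std::vector<int32_t> ids(logits.size());
  std::iota(ids.begin(), ids.end(), 0);
  std::partial_sort(ids.begin(), ids.begin() + static_cast<std::ptrdiff_t>(k), ids.end(),
                    [&logits](int32_t a, int32_t b) {
                      return logits[static_cast<std::size_t>(a)] >
                             logits[static_cast<std::size_t>(b)];
                    });
  std::vector<LogitEntry> out;
  out.reserve(k);
  for (std::size_t i = 0; i < k; ++i) {
    out.push_back({ids[i], logits[static_cast<std::size_t>(ids[i])]});
  }
  return out;
}

// Called after sampling, with a copy of the raw logits taken before any transform.
inline void maybeEmitLogits(const LogitsCaptureOptions &opts, std::size_t step,
                            int32_t sampled_id, const std::vector<float> &raw) {
  if (!opts.enabled || !opts.on_step) {
    return;  // feature off: no extra copies, no callback
  }
  LogitStep s{step, sampled_id, topK(raw, opts.top_k), {}};
  if (opts.full_vocab) {
    s.full_vocab = raw;  // opt-in full-vocabulary dump
  }
  opts.on_step(s);
}
```

Selecting top-k with `std::partial_sort` costs roughly O(V log k) per step for a vocabulary of size V, and only k pairs cross the API boundary by default. A C entry point such as the proposed runModelHandleWithLogitsStreaming would presumably wrap this in a plain C function-pointer callback, since `std::function` cannot cross a C or JNI boundary.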

Acceptance criteria

  • A caller can obtain per-step logits (top-k by default) aligned with each generated token.
  • Existing text/streaming APIs are byte-for-byte unchanged when the new feature is off.
  • Works for at least one nntrainer model and one QNN model (or the QNN limitations are documented).
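One way to test the first criterion (logits aligned with each generated token) is a greedy-decoding consistency check. A minimal sketch, assuming greedy decoding with no repetition penalty and no exact ties, so that each emitted token must equal the argmax of the raw logits captured for the same step; `firstMisalignedStep` and its inputs (emitted token IDs plus one captured full-vocabulary row per step) are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <vector>

// Hypothetical acceptance check: returns the first step whose emitted token
// differs from the argmax of that step's captured raw logits, or -1 if every
// step lines up. A length mismatch is reported at the first unpaired step.
inline long firstMisalignedStep(const std::vector<int32_t> &tokens,
                                const std::vector<std::vector<float>> &raw_steps) {
  const std::size_t n = std::min(tokens.size(), raw_steps.size());
  for (std::size_t i = 0; i < n; ++i) {
    const auto &row = raw_steps[i];
    if (row.empty()) {
      return static_cast<long>(i);  // no logits captured for this step
    }
    auto best = std::max_element(row.begin(), row.end());
    if (std::distance(row.begin(), best) != tokens[i]) {
      return static_cast<long>(i);  // token and logits are off by a step or wrong
    }
  }
  if (tokens.size() != raw_steps.size()) {
    return static_cast<long>(n);  // one side has extra, unpaired steps
  }
  return -1;
}
```

Because dividing by a positive temperature does not change the argmax, the same check should also hold when only temperature scaling is applied before a greedy pick.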

Risks / notes

  • Vocabularies can be large, so full-vocab logits per step are heavy, especially across JNI. Mitigation: top-k by default, with an optional full dump.
  • QNN logits may be quantized; need a defined dequant/scale contract.
  • Be explicit about which stage's values are returned (raw vs penalized vs temperature-scaled), since generate() mutates the logit buffer in place.
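For the QNN dequantization contract, a minimal sketch assuming a standard per-tensor affine scheme, real = scale * (q - zero_point). The exact encoding the QNN graph reports (including how its offset field maps to `zero_point`) is an assumption here and is precisely what the contract needs to pin down; `QuantParams` and `dequantizeLogits` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical per-tensor quantization parameters reported by the backend.
struct QuantParams {
  float scale;         // must be positive
  int32_t zero_point;  // integer value that maps to 0.0f
};

// Affine dequantization: real = scale * (q - zero_point).
// Templated on the stored integer type (e.g. uint8_t or uint16_t).
template <typename QT>
std::vector<float> dequantizeLogits(const std::vector<QT> &q, QuantParams p) {
  assert(p.scale > 0.0f);
  std::vector<float> out;
  out.reserve(q.size());
  for (QT v : q) {
    out.push_back(p.scale * static_cast<float>(static_cast<int32_t>(v) - p.zero_point));
  }
  return out;
}
```

With a positive scale this mapping is monotonic, so top-k IDs could be selected on the quantized values and only the k survivors dequantized, which keeps both the conversion cost and the JNI payload small.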
