diff --git a/CHANGELOG.md b/CHANGELOG.md index 264e0160..80de7c1a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ from version 5.0.0 onward. Pre-fork releases (`1.x`–`4.2.0`) were authored by - Explicit `setMmprojAuto(boolean)` and `setMmprojOffload(boolean)` controls, including the upstream `--no-mmproj-auto` and `--no-mmproj-offload` flags. - Per-request KV controls: `InferenceParameters.withSlotId(int)` and `withCacheReuse(int)`. - Per-request DRY sampling to `InferenceParameters` (`dry_multiplier`/`dry_base`/`dry_allowed_length`/`dry_penalty_last_n`/`dry_sequence_breakers`). +- `ModelParameters.enableSwaFull()` (`--swa-full`): keep full-size SWA KV cache to enable cross-request prompt-prefix reuse. - Typed cache observability through `Usage.getCachedTokens()`, `Usage.getProcessedPromptTokens()`, `SlotMetrics`, and `ServerMetrics.getSlotMetrics()`. - Authenticated JSON `GET /metrics` and `GET /slots` endpoints on the embedded server. diff --git a/src/main/java/net/ladenthin/llama/args/ModelFlag.java b/src/main/java/net/ladenthin/llama/args/ModelFlag.java index f9ac8aee..17616a9d 100644 --- a/src/main/java/net/ladenthin/llama/args/ModelFlag.java +++ b/src/main/java/net/ladenthin/llama/args/ModelFlag.java @@ -22,6 +22,11 @@ public enum ModelFlag { /** Enable Flash Attention. */ FLASH_ATTN("--flash-attn"), + /** Keep the full-size sliding-window-attention (SWA) KV cache, enabling cross-request + * prompt-prefix reuse (pairs with --cache-reuse) at ~2x the SWA-layer KV RAM. Default off. + * Env: LLAMA_ARG_SWA_FULL. */ + SWA_FULL("--swa-full"), + /** Disable internal libllama performance timings. */ NO_PERF("--no-perf"), diff --git a/src/main/java/net/ladenthin/llama/parameters/ModelParameters.java b/src/main/java/net/ladenthin/llama/parameters/ModelParameters.java index 50864145..ce62131b 100644 --- a/src/main/java/net/ladenthin/llama/parameters/ModelParameters.java +++ b/src/main/java/net/ladenthin/llama/parameters/ModelParameters.java @@ -255,6 +255,17 @@ public ModelParameters enableFlashAttn() { return setFlag(ModelFlag.FLASH_ATTN); } + /** + * Use the full-size SWA KV cache so the sliding-window layers' KV is reusable across requests + * (restores prompt-prefix cache reuse with {@link #setCacheReuse(int)}); costs ~2x SWA-layer + * KV RAM. Off by default; only beneficial for multi-request sessions sharing a prompt prefix. + * + * @return this builder + */ + public ModelParameters enableSwaFull() { + return setFlag(ModelFlag.SWA_FULL); + } + /** * Disable internal libllama performance timings (default: false). * diff --git a/src/test/java/net/ladenthin/llama/args/ModelFlagTest.java b/src/test/java/net/ladenthin/llama/args/ModelFlagTest.java index 50401840..2621c0fe 100644 --- a/src/test/java/net/ladenthin/llama/args/ModelFlagTest.java +++ b/src/test/java/net/ladenthin/llama/args/ModelFlagTest.java @@ -19,6 +19,7 @@ public static Collection data() { return Arrays.asList(new Object[][] { {ModelFlag.NO_CONTEXT_SHIFT, "--no-context-shift"}, {ModelFlag.FLASH_ATTN, "--flash-attn"}, + {ModelFlag.SWA_FULL, "--swa-full"}, {ModelFlag.NO_PERF, "--no-perf"}, {ModelFlag.ESCAPE, "--escape"}, {ModelFlag.NO_ESCAPE, "--no-escape"}, @@ -66,7 +67,7 @@ public void testGetCliFlag(ModelFlag flag, String expectedCliFlag) { @Test public void testEnumCount() { - assertEquals(34, ModelFlag.values().length); + assertEquals(35, ModelFlag.values().length); } @ParameterizedTest(name = "{0} -> {1}") diff --git a/src/test/java/net/ladenthin/llama/parameters/ModelParametersExtendedTest.java b/src/test/java/net/ladenthin/llama/parameters/ModelParametersExtendedTest.java index 752c1031..bc4dc3aa 100644 --- a/src/test/java/net/ladenthin/llama/parameters/ModelParametersExtendedTest.java +++ b/src/test/java/net/ladenthin/llama/parameters/ModelParametersExtendedTest.java @@ -641,6 +641,18 @@ public void testEnableFlashAttn() { assertThat(p.parameters.get("--flash-attn"), is(nullValue())); } + @Test + public void testEnableSwaFull() { + ModelParameters p = new ModelParameters().enableSwaFull(); + assertThat(p.parameters, hasKey("--swa-full")); + assertThat(p.parameters.get("--swa-full"), is(nullValue())); + } + + @Test + public void testSwaFullNotEnabledByDefault() { + assertThat(new ModelParameters().parameters, not(hasKey("--swa-full"))); + } + @Test public void testDisablePerf() { ModelParameters p = new ModelParameters().disablePerf();