@@ -341,7 +341,11 @@ def __init__(
         self._logits_all = logits_all if draft_model is None else True
         self.context_params.embeddings = embedding  # TODO: Rename to embeddings
         self.context_params.offload_kqv = offload_kqv
-        self.context_params.flash_attn = flash_attn
+        self.context_params.flash_attn_type = (
+            llama_cpp.LLAMA_FLASH_ATTN_TYPE_ENABLED
+            if flash_attn
+            else llama_cpp.LLAMA_FLASH_ATTN_TYPE_DISABLED
+        )
 
         if op_offload is not None:
             self.context_params.op_offload = op_offload
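
The old boolean `flash_attn` context flag becomes llama.cpp's tri-state `flash_attn_type` setting. A minimal sketch of the mapping in both directions, assuming only the two constants used in this diff (`LLAMA_FLASH_ATTN_TYPE_ENABLED`, `LLAMA_FLASH_ATTN_TYPE_DISABLED`); any other value the enum may define (e.g. an auto mode) simply reads back as `False`:

```python
import llama_cpp

def flash_attn_type_from_bool(flash_attn: bool) -> int:
    # bool -> enum, as set on context_params in __init__ above
    return (
        llama_cpp.LLAMA_FLASH_ATTN_TYPE_ENABLED
        if flash_attn
        else llama_cpp.LLAMA_FLASH_ATTN_TYPE_DISABLED
    )

def flash_attn_bool_from_type(flash_attn_type: int) -> bool:
    # enum -> bool, mirroring the __getstate__ hunk further down:
    # only an explicit ENABLED counts as True
    return flash_attn_type == llama_cpp.LLAMA_FLASH_ATTN_TYPE_ENABLED
```
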
@@ -431,9 +435,9 @@ def free_lora_adapter():
 
             self._stack.callback(free_lora_adapter)
 
-            if llama_cpp.llama_set_adapter_lora(
-                self._ctx.ctx, self._lora_adapter, self.lora_scale
-            ):
+            adapters = (llama_cpp.llama_adapter_lora_p_ctypes * 1)(self._lora_adapter)
+            scales = (ctypes.c_float * 1)(self.lora_scale)
+            if llama_cpp.llama_set_adapters_lora(self._ctx.ctx, adapters, 1, scales):
                 raise RuntimeError(
                     f"Failed to set LoRA adapter from lora path: {self.lora_path}"
                 )
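
The per-adapter `llama_set_adapter_lora` call gives way to a batched API taking ctypes arrays of adapter handles and scales, applied in one shot. Assuming `llama_set_adapters_lora` keeps the `(ctx, adapters, n, scales)` signature shown above, a hypothetical helper that generalizes this to several adapters (handles obtained elsewhere, e.g. from `llama_adapter_lora_init`):

```python
import ctypes
import llama_cpp

def set_lora_adapters(ctx, pairs):
    """Apply a list of (adapter_handle, scale) pairs in a single call."""
    n = len(pairs)
    adapters = (llama_cpp.llama_adapter_lora_p_ctypes * n)(*(a for a, _ in pairs))
    scales = (ctypes.c_float * n)(*(float(s) for _, s in pairs))
    if llama_cpp.llama_set_adapters_lora(ctx, adapters, n, scales):
        raise RuntimeError("llama_set_adapters_lora failed")
```
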
@@ -726,7 +730,6 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
             sampler.add_grammar(self._model, grammar)
 
         if temp < 0.0:
-            sampler.add_softmax()
             sampler.add_dist(self._seed)
         elif temp == 0.0:
             sampler.add_greedy()
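
The explicit `add_softmax()` stage is dropped because the dist sampler normalizes the token distribution itself before drawing, so chaining a softmax ahead of it was redundant. A sketch of the resulting dispatch, assuming a `LlamaSampler`-style wrapper with the `add_*` methods used in this hunk (the `temp > 0` branch, elided from the diff, is an assumption):

```python
def build_sampler_chain(sampler, temp: float, seed: int) -> None:
    # Sketch only; `sampler` is assumed to be the internal LlamaSampler
    # wrapper whose add_* methods appear in the diff above.
    if temp < 0.0:
        # add_dist() normalizes probabilities internally, so no
        # add_softmax() is needed ahead of it anymore.
        sampler.add_dist(seed)
    elif temp == 0.0:
        sampler.add_greedy()  # deterministic argmax
    else:
        sampler.add_temp(temp)  # assumed temperature stage
        sampler.add_dist(seed)
```
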
@@ -1042,7 +1045,7 @@ def embed(
         data: Union[List[List[float]], List[List[List[float]]]] = []
 
         def decode_batch(seq_sizes: List[int]):
-            llama_cpp.llama_kv_self_clear(self._ctx.ctx)
+            self._ctx.kv_cache_clear()
             self._ctx.decode(self._batch)
             self._batch.reset()
 
@@ -1113,7 +1116,7 @@ def decode_batch(seq_sizes: List[int]):
 
         output = data[0] if isinstance(input, str) else data
 
-        llama_cpp.llama_kv_self_clear(self._ctx.ctx)
+        self._ctx.kv_cache_clear()
         self.reset()
 
         if return_count:
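
Both call sites in `embed` now clear the KV cache through the context wrapper rather than the raw C binding, so the underlying symbol (renamed across llama.cpp versions) is resolved in one place. A hypothetical sketch of such a wrapper method; the C entry point it forwards to is an assumption, not taken from this diff:

```python
import llama_cpp

class _LlamaContextSketch:
    """Illustrative slice of the _ctx wrapper used above."""

    def __init__(self, ctx):
        self.ctx = ctx  # raw llama_context pointer

    def kv_cache_clear(self) -> None:
        # Assumption: forwards to whichever cache-clearing function the
        # bundled llama.cpp exposes; callers stay insulated from renames.
        llama_cpp.llama_kv_self_clear(self.ctx)
```
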
@@ -2100,7 +2103,10 @@ def __getstate__(self):
             logits_all=self._logits_all,
             embedding=self.context_params.embeddings,
             offload_kqv=self.context_params.offload_kqv,
-            flash_attn=self.context_params.flash_attn,
+            flash_attn=(
+                self.context_params.flash_attn_type
+                == llama_cpp.LLAMA_FLASH_ATTN_TYPE_ENABLED
+            ),
             op_offload=self.context_params.op_offload,
             swa_full=self.context_params.swa_full,
             # Sampling Params
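
Because `__getstate__` folds the enum back to a boolean, the flash-attention setting round-trips through pickling unchanged. A hedged usage sketch (the model path is illustrative):

```python
import pickle
from llama_cpp import Llama

# flash_attn goes in as a bool, is stored as flash_attn_type in __init__,
# and comes back out of __getstate__ as a bool.
llm = Llama(model_path="./model.gguf", flash_attn=True)  # path is illustrative
llm2 = pickle.loads(pickle.dumps(llm))
```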