diff --git a/.editorconfig b/.editorconfig index b84e563..3d80c19 100644 --- a/.editorconfig +++ b/.editorconfig @@ -1,194 +1,31 @@ -# With more recent updates Visual Studio 2017 supports EditorConfig files out of the box -# Visual Studio Code needs an extension: https://github.com/editorconfig/editorconfig-vscode -# For emacs, vim, np++ and other editors, see here: https://github.com/editorconfig -############################### -# Core EditorConfig Options # -############################### root = true -# All files + [*] indent_style = space indent_size = 4 +end_of_line = lf charset = utf-8 trim_trailing_whitespace = true insert_final_newline = true -end_of_line = lf -max_line_length = off - -# YAML indentation -[*.{yml,yaml}] -indent_size = 2 -# XML indentation -[*.{csproj,xml}] -indent_size = 2 - -############################### -# .NET Coding Conventions # -############################### -[*.{cs,vb}] -# Organize usings -dotnet_sort_system_directives_first = true -# this. preferences +[*.{cs,csproj,props}] +indent_size = 4 +dotnet_style_qualification_for_event = false:silent dotnet_style_qualification_for_field = false:silent -dotnet_style_qualification_for_property = false:silent dotnet_style_qualification_for_method = false:silent -dotnet_style_qualification_for_event = false:silent -# Language keywords vs BCL types preferences +dotnet_style_qualification_for_property = false:silent dotnet_style_predefined_type_for_locals_parameters_members = true:silent dotnet_style_predefined_type_for_member_access = true:silent -# Parentheses preferences -dotnet_style_parentheses_in_arithmetic_binary_operators = always_for_clarity:silent -dotnet_style_parentheses_in_relational_binary_operators = always_for_clarity:silent -dotnet_style_parentheses_in_other_binary_operators = always_for_clarity:silent -dotnet_style_parentheses_in_other_operators = never_if_unnecessary:silent -# Modifier preferences -dotnet_style_require_accessibility_modifiers = 
for_non_interface_members:silent -dotnet_style_readonly_field = true:suggestion -# Expression-level preferences -dotnet_style_object_initializer = true:suggestion -dotnet_style_collection_initializer = true:suggestion -dotnet_style_explicit_tuple_names = true:suggestion -dotnet_style_null_propagation = true:suggestion -dotnet_style_coalesce_expression = true:suggestion -dotnet_style_prefer_is_null_check_over_reference_equality_method = true:silent -dotnet_style_prefer_inferred_tuple_names = true:suggestion -dotnet_style_prefer_inferred_anonymous_type_member_names = true:suggestion -dotnet_style_prefer_auto_properties = true:silent -dotnet_style_prefer_conditional_expression_over_assignment = true:silent -dotnet_style_prefer_conditional_expression_over_return = true:silent - -############################### -# Naming Conventions # -############################### -# Style Definitions (From Roslyn) - -# Non-private static fields are PascalCase -dotnet_naming_rule.non_private_static_fields_should_be_pascal_case.severity = suggestion -dotnet_naming_rule.non_private_static_fields_should_be_pascal_case.symbols = non_private_static_fields -dotnet_naming_rule.non_private_static_fields_should_be_pascal_case.style = non_private_static_field_style - -dotnet_naming_symbols.non_private_static_fields.applicable_kinds = field -dotnet_naming_symbols.non_private_static_fields.applicable_accessibilities = public, protected, internal, protected_internal, private_protected -dotnet_naming_symbols.non_private_static_fields.required_modifiers = static - -dotnet_naming_style.non_private_static_field_style.capitalization = pascal_case - -# Constants are PascalCase -dotnet_naming_rule.constants_should_be_pascal_case.severity = suggestion -dotnet_naming_rule.constants_should_be_pascal_case.symbols = constants -dotnet_naming_rule.constants_should_be_pascal_case.style = constant_style - -dotnet_naming_symbols.constants.applicable_kinds = field, local 
-dotnet_naming_symbols.constants.required_modifiers = const -dotnet_naming_style.constant_style.capitalization = pascal_case - -# Static fields are camelCase and start with s_ -dotnet_naming_rule.static_fields_should_be_camel_case.severity = suggestion -dotnet_naming_rule.static_fields_should_be_camel_case.symbols = static_fields -dotnet_naming_rule.static_fields_should_be_camel_case.style = static_field_style - -dotnet_naming_symbols.static_fields.applicable_kinds = field -dotnet_naming_symbols.static_fields.required_modifiers = static - -dotnet_naming_style.static_field_style.capitalization = camel_case -dotnet_naming_style.static_field_style.required_prefix = _ - -# Instance fields are camelCase and start with _ -dotnet_naming_rule.instance_fields_should_be_camel_case.severity = suggestion -dotnet_naming_rule.instance_fields_should_be_camel_case.symbols = instance_fields -dotnet_naming_rule.instance_fields_should_be_camel_case.style = instance_field_style - -dotnet_naming_symbols.instance_fields.applicable_kinds = field - -dotnet_naming_style.instance_field_style.capitalization = camel_case -dotnet_naming_style.instance_field_style.required_prefix = _ - -# Locals and parameters are camelCase -dotnet_naming_rule.locals_should_be_camel_case.severity = suggestion -dotnet_naming_rule.locals_should_be_camel_case.symbols = locals_and_parameters -dotnet_naming_rule.locals_should_be_camel_case.style = camel_case_style - -dotnet_naming_symbols.locals_and_parameters.applicable_kinds = parameter, local - -dotnet_naming_style.camel_case_style.capitalization = camel_case - -# Local functions are PascalCase -dotnet_naming_rule.local_functions_should_be_pascal_case.severity = suggestion -dotnet_naming_rule.local_functions_should_be_pascal_case.symbols = local_functions -dotnet_naming_rule.local_functions_should_be_pascal_case.style = local_function_style - -dotnet_naming_symbols.local_functions.applicable_kinds = local_function - 
-dotnet_naming_style.local_function_style.capitalization = pascal_case - -# By default, name items with PascalCase -dotnet_naming_rule.members_should_be_pascal_case.severity = suggestion -dotnet_naming_rule.members_should_be_pascal_case.symbols = all_members -dotnet_naming_rule.members_should_be_pascal_case.style = pascal_case_style - -dotnet_naming_symbols.all_members.applicable_kinds = * +[*.{xml,config,json}] +indent_size = 2 -dotnet_naming_style.pascal_case_style.capitalization = pascal_case +[*.{md,txt}] +indent_size = 4 +trim_trailing_whitespace = false -############################### -# C# Coding Conventions # -############################### -[*.cs] -# var preferences -csharp_style_var_for_built_in_types = true:silent -csharp_style_var_when_type_is_apparent = true:silent -csharp_style_var_elsewhere = true:silent -# Expression-bodied members -csharp_style_expression_bodied_methods = false:silent -csharp_style_expression_bodied_constructors = false:silent -csharp_style_expression_bodied_operators = false:silent -csharp_style_expression_bodied_properties = true:silent -csharp_style_expression_bodied_indexers = true:silent -csharp_style_expression_bodied_accessors = true:silent -# Pattern matching preferences -csharp_style_pattern_matching_over_is_with_cast_check = true:suggestion -csharp_style_pattern_matching_over_as_with_null_check = true:suggestion -# Null-checking preferences -csharp_style_throw_expression = true:suggestion -csharp_style_conditional_delegate_call = true:suggestion -# Modifier preferences -csharp_preferred_modifier_order = public,private,protected,internal,static,extern,new,virtual,abstract,sealed,override,readonly,unsafe,volatile,async:suggestion -# Expression-level preferences -csharp_prefer_braces = true:silent -csharp_style_deconstructed_variable_declaration = true:suggestion -csharp_prefer_simple_default_expression = true:suggestion -csharp_style_pattern_local_over_anonymous_function = true:suggestion 
-csharp_style_inlined_variable_declaration = true:suggestion +[Makefile] +indent_style = tab -############################### -# C# Formatting Rules # -############################### -# New line preferences -csharp_new_line_before_open_brace = all -csharp_new_line_before_else = true -csharp_new_line_before_catch = true -csharp_new_line_before_finally = true -csharp_new_line_before_members_in_object_initializers = true -csharp_new_line_before_members_in_anonymous_types = true -csharp_new_line_between_query_expression_clauses = true -# Indentation preferences -csharp_indent_case_contents = true -csharp_indent_switch_labels = true -csharp_indent_labels = flush_left -# Space preferences -csharp_space_after_cast = false -csharp_space_after_keywords_in_control_flow_statements = true -csharp_space_between_method_call_parameter_list_parentheses = false -csharp_space_between_method_declaration_parameter_list_parentheses = false -csharp_space_between_parentheses = false -csharp_space_before_colon_in_inheritance_clause = true -csharp_space_after_colon_in_inheritance_clause = true -csharp_space_around_binary_operators = before_and_after -csharp_space_between_method_declaration_empty_parameter_list_parentheses = false -csharp_space_between_method_call_name_and_opening_parenthesis = false -csharp_space_between_method_call_empty_parameter_list_parentheses = false -# Wrapping preferences -csharp_preserve_single_line_statements = true -csharp_preserve_single_line_blocks = true +[*.sh] +indent_size = 4 diff --git a/.github/rulesets/main.json b/.github/rulesets/main.json new file mode 100644 index 0000000..862334f --- /dev/null +++ b/.github/rulesets/main.json @@ -0,0 +1,65 @@ +[ + { + "name": "main - Release Branch", + "target": "branch", + "enforcement": "active", + "conditions": { + "ref_name": { + "include": ["refs/heads/main"], + "exclude": [] + } + }, + "rules": [ + { "type": "deletion" }, + { "type": "non_fast_forward" }, + { "type": "required_linear_history" }, + { + "type": 
"pull_request", + "parameters": { + "required_approving_review_count": 1, + "dismiss_stale_reviews_on_push": true, + "require_code_owner_review": false, + "require_last_push_approval": true + } + } + ], + "bypass_actors": [ + { + "actor_id": 0, + "actor_type": "RepositoryRole", + "bypass_mode": "always" + } + ] + }, + { + "name": "All branches - Tag Protection", + "target": "tag", + "enforcement": "active", + "conditions": { + "ref_name": { + "include": ["refs/tags/v*"], + "exclude": [] + } + }, + "rules": [ + { "type": "deletion" }, + { "type": "non_fast_forward" } + ] + }, + { + "name": "Release Tags - Signature", + "target": "tag", + "enforcement": "active", + "conditions": { + "ref_name": { + "include": ["refs/tags/v[0-9]*.[0-9]*.[0-9]*.[0-9]*"], + "exclude": [] + } + }, + "rules": [ + { + "type": "required_signatures" + } + ] + } +] diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml deleted file mode 100644 index f290747..0000000 --- a/.github/workflows/build.yaml +++ /dev/null @@ -1,18 +0,0 @@ -name: '🏗️ Build Plugin' - -on: - push: - branches: - - master - paths-ignore: - - '**/*.md' - pull_request: - branches: - - master - paths-ignore: - - '**/*.md' - workflow_dispatch: - -jobs: - call: - uses: jellyfin/jellyfin-meta-plugins/.github/workflows/build.yaml@master diff --git a/.github/workflows/changelog.yaml b/.github/workflows/changelog.yaml deleted file mode 100644 index 5b3c3be..0000000 --- a/.github/workflows/changelog.yaml +++ /dev/null @@ -1,20 +0,0 @@ -name: '📝 Create/Update Release Draft & Release Bump PR' - -on: - push: - branches: - - master - paths-ignore: - - build.yaml - workflow_dispatch: - repository_dispatch: - types: - - update-prep-command - -jobs: - call: - uses: jellyfin/jellyfin-meta-plugins/.github/workflows/changelog.yaml@master - with: - repository-name: jellyfin/jellyfin-plugin-template - secrets: - token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/command-dispatch.yaml 
b/.github/workflows/command-dispatch.yaml deleted file mode 100644 index 1b5e4ee..0000000 --- a/.github/workflows/command-dispatch.yaml +++ /dev/null @@ -1,13 +0,0 @@ -# Allows for the definition of PR and Issue /commands -name: '📟 Slash Command Dispatcher' - -on: - issue_comment: - types: - - created - -jobs: - call: - uses: jellyfin/jellyfin-meta-plugins/.github/workflows/command-dispatch.yaml@master - secrets: - token: . diff --git a/.github/workflows/command-rebase.yaml b/.github/workflows/command-rebase.yaml deleted file mode 100644 index 7847e20..0000000 --- a/.github/workflows/command-rebase.yaml +++ /dev/null @@ -1,16 +0,0 @@ -name: '🔀 PR Rebase Command' - -on: - repository_dispatch: - types: - - rebase-command - -jobs: - call: - uses: jellyfin/jellyfin-meta-plugins/.github/workflows/command-rebase.yaml@master - with: - rebase-head: ${{ github.event.client_payload.pull_request.head.label }} - repository-full-name: ${{ github.event.client_payload.github.payload.repository.full_name }} - comment-id: ${{ github.event.client_payload.github.payload.comment.id }} - secrets: - token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml deleted file mode 100644 index 80483cf..0000000 --- a/.github/workflows/publish.yaml +++ /dev/null @@ -1,18 +0,0 @@ -name: '🚀 Publish Plugin' - -on: - release: - types: - - released - workflow_dispatch: - -jobs: - call: - uses: jellyfin/jellyfin-meta-plugins/.github/workflows/publish.yaml@master - with: - version: ${{ github.event.release.tag_name }} - is-unstable: ${{ github.event.release.prerelease }} - secrets: - deploy-host: ${{ secrets.DEPLOY_HOST }} - deploy-user: ${{ secrets.DEPLOY_USER }} - deploy-key: ${{ secrets.DEPLOY_KEY }} diff --git a/.github/workflows/scan-codeql.yaml b/.github/workflows/scan-codeql.yaml deleted file mode 100644 index ca8b0b0..0000000 --- a/.github/workflows/scan-codeql.yaml +++ /dev/null @@ -1,20 +0,0 @@ -name: '🔬 Run CodeQL' - -on: - push: - 
branches: [ master ] - paths-ignore: - - '**/*.md' - pull_request: - branches: [ master ] - paths-ignore: - - '**/*.md' - schedule: - - cron: '24 2 * * 4' - workflow_dispatch: - -jobs: - call: - uses: jellyfin/jellyfin-meta-plugins/.github/workflows/scan-codeql.yaml@master - with: - repository-name: jellyfin/jellyfin-plugin-template diff --git a/.github/workflows/sync-labels.yaml b/.github/workflows/sync-labels.yaml deleted file mode 100644 index 5e06ae4..0000000 --- a/.github/workflows/sync-labels.yaml +++ /dev/null @@ -1,12 +0,0 @@ -name: '🏷️ Sync labels' - -on: - schedule: - - cron: '0 0 1 * *' - workflow_dispatch: - -jobs: - call: - uses: jellyfin/jellyfin-meta-plugins/.github/workflows/sync-labels.yaml@master - secrets: - token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index d90b14d..2a48d66 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -1,18 +1,27 @@ -name: '🧪 Test Plugin' +name: '🧪 Build & Test Plugin' on: push: - branches: - - master - paths-ignore: - - '**/*.md' + branches: [master] + paths-ignore: ['**/*.md'] pull_request: - branches: - - master - paths-ignore: - - '**/*.md' + branches: [master] + paths-ignore: ['**/*.md'] workflow_dispatch: jobs: - call: - uses: jellyfin/jellyfin-meta-plugins/.github/workflows/test.yaml@master + build: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v6 + + - name: Setup .NET + uses: actions/setup-dotnet@v5 + with: + dotnet-version: '9.0.x' + + - name: Restore dependencies + run: dotnet restore Jellyfin.Plugin.WhisperSubtitles.sln + + - name: Build + run: dotnet build Jellyfin.Plugin.WhisperSubtitles.sln --configuration Release --no-restore -warnaserror diff --git a/.gitignore b/.gitignore index 0b72c24..9d1adc2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,20 @@ +*.backup +*.py.backup +.venv bin/ obj/ +releases/ +wiki/ +whisper.cpp/ + +BenchmarkDotNet.Artifacts/ +/package/ +*.lscache + +# Visual Studio .vs/ + 
+# JetBrains Rider .idea/ -artifacts +/.cursor +.agents/ diff --git a/.vscode/settings.json b/.vscode/settings.json index 7fa6075..9a72214 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,19 +1,34 @@ { - // jellyfinDir : The directory of the cloned jellyfin server project - // This needs to be built once before it can be used - "jellyfinDir": "${workspaceFolder}/../jellyfin/Jellyfin.Server", - // jellyfinWebDir : The directory of the cloned jellyfin-web project - // This needs to be built once before it can be used - "jellyfinWebDir": "${workspaceFolder}/../jellyfin-web", - // jellyfinDataDir : the root data directory for a running jellyfin instance - // This is where jellyfin stores its configs, plugins, metadata etc - // This is platform specific by default, but on Windows defaults to - // ${env:LOCALAPPDATA}/jellyfin - // and on Linux, it defaults to - // ${env:XDG_DATA_HOME}/jellyfin - // However ${env:XDG_DATA_HOME} does not work in Visual Studio Code's development container! 
- "jellyfinWindowsDataDir": "${env:LOCALAPPDATA}/jellyfin", - "jellyfinLinuxDataDir": "$HOME/.local/share/jellyfin", - // The name of the plugin - "pluginName": "Jellyfin.Plugin.Template", -} \ No newline at end of file + "dotnet.defaultSolution": "Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles.sln", + "editor.formatOnSave": true, + "editor.rulers": [120], + "files.exclude": { + "**/bin/": true, + "**/obj/": true, + "**/.git": true, + "releases/": true, + "wiki/": true, + "whisper.cpp/": true + }, + "search.exclude": { + "**/bin/": true, + "**/obj/": true, + "releases/": true, + "wiki/": true, + "whisper.cpp/": true + }, + "files.associations": { + "*.props": "xml", + "*.targets": "xml" + }, + "[csharp]": { + "editor.defaultFormatter": "ms-dotnettools.csharp", + "editor.tabSize": 4 + }, + "[xml]": { + "editor.tabSize": 2 + }, + "[json]": { + "editor.tabSize": 2 + } +} diff --git a/Directory.Build.props b/Directory.Build.props index c702921..471a913 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -1,7 +1,14 @@ - - 0.0.0.0 - 0.0.0.0 - 0.0.0.0 - + + 3.0.0.0 + 3.0.0.0 + 3.0.0.0 + Jellyfin.Plugin.WhisperSubtitles + Whisper Plugin Contributors + Automatic subtitle generation for Jellyfin using OpenAI's Whisper + https://github.com/zakattack02/Whisper-Script + https://github.com/zakattack02/Whisper-Script + git + MIT + diff --git a/Jellyfin.Plugin.WhisperSubtitles.sln b/Jellyfin.Plugin.WhisperSubtitles.sln new file mode 100644 index 0000000..348fee4 --- /dev/null +++ b/Jellyfin.Plugin.WhisperSubtitles.sln @@ -0,0 +1,19 @@ +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.0.31919.166 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Jellyfin.Plugin.WhisperSubtitles", "Jellyfin.Plugin.WhisperSubtitles\Jellyfin.Plugin.WhisperSubtitles\Jellyfin.Plugin.WhisperSubtitles.csproj", "{B7E1E1E1-1E1E-1E1E-1E1E-1E1E1E1E1E1E}" +EndProject 
+ +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {B7E1E1E1-1E1E-1E1E-1E1E-1E1E1E1E1E1E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B7E1E1E1-1E1E-1E1E-1E1E-1E1E1E1E1E1E}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B7E1E1E1-1E1E-1E1E-1E1E-1E1E1E1E1E1E}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B7E1E1E1-1E1E-1E1E-1E1E-1E1E1E1E1E1E}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection +EndGlobal diff --git a/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Configuration/Logo.png b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Configuration/Logo.png new file mode 100644 index 0000000..08a4624 Binary files /dev/null and b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Configuration/Logo.png differ diff --git a/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Configuration/PluginConfiguration.cs b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Configuration/PluginConfiguration.cs new file mode 100644 index 0000000..7ae05ed --- /dev/null +++ b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Configuration/PluginConfiguration.cs @@ -0,0 +1,132 @@ +using System; +using System.Collections.Generic; +using MediaBrowser.Model.Plugins; + +namespace Jellyfin.Plugin.WhisperSubtitles.Configuration +{ + /// + /// Whisper model options. + /// + public enum WhisperModelType + { + /// + /// Tiny model (~10x speed, ~1GB VRAM, ~75MB download). + /// + Tiny, + + /// + /// Base model (~7x speed, ~1GB VRAM, ~140MB download). + /// + Base, + + /// + /// Small model (~4x speed, ~2GB VRAM, ~460MB download) - Recommended. + /// + Small, + + /// + /// Medium model (~2x speed, ~5GB VRAM, ~1.5GB download). + /// + Medium, + + /// + /// Turbo model (~8x speed, ~6GB VRAM, ~1.6GB download). 
+ /// + Turbo, + + /// + /// Large model (Best quality, ~10GB VRAM, ~3GB download). + /// + Large + } + + /// + /// Plugin configuration. + /// + public class PluginConfiguration : BasePluginConfiguration + { + /// + /// Initializes a new instance of the class. + /// + public PluginConfiguration() + { + WhisperModel = WhisperModelType.Small; + TargetLanguage = "en"; + AIIdentifier = "whisper"; + TranslateToEnglish = false; + WordTimestamps = false; + ProcessOnLibraryScan = false; + SkipExisting = true; + RegenerateAI = false; + UseGPUAcceleration = true; + EnableMainMenu = true; + LibrariesToProcess = new List(); + FoldersToExclude = new List(); + } + + /// + /// Gets or sets the Whisper model to use. + /// + public WhisperModelType WhisperModel { get; set; } + + /// + /// Gets or sets the target language for subtitles. + /// + public string TargetLanguage { get; set; } + + /// + /// Gets or sets the AI identifier to add to subtitle filenames. + /// + public string AIIdentifier { get; set; } + + /// + /// Gets or sets a value indicating whether to translate to English. + /// + public bool TranslateToEnglish { get; set; } + + /// + /// Gets or sets a value indicating whether to enable word-level timestamps. + /// + public bool WordTimestamps { get; set; } + + /// + /// Gets or sets a value indicating whether to process on library scan. + /// + public bool ProcessOnLibraryScan { get; set; } + + /// + /// Gets or sets a value indicating whether to skip existing subtitles. + /// + public bool SkipExisting { get; set; } + + /// + /// Gets or sets a value indicating whether to regenerate AI subtitles. + /// + public bool RegenerateAI { get; set; } + + /// + /// Gets or sets a value indicating whether to enable GPU acceleration. + /// + public bool UseGPUAcceleration { get; set; } + + /// + /// Gets or sets a value indicating whether to show the plugin in the main menu navigation. 
+ /// + public bool EnableMainMenu { get; set; } + + /// + /// Gets or sets the path to ffprobe. Leave empty for auto-detection. + /// + public string? FfprobePath { get; set; } + + /// + /// Gets or sets the list of library IDs to process. If empty, all libraries are processed. + /// + public List LibrariesToProcess { get; set; } + + /// + /// Gets or sets the list of folder paths to exclude from processing. + /// + public List FoldersToExclude { get; set; } + } +} diff --git a/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Configuration/configPage.html b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Configuration/configPage.html new file mode 100644 index 0000000..7b705f3 --- /dev/null +++ b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Configuration/configPage.html @@ -0,0 +1,615 @@ + + + + + Whisper Subtitles + + + + + + + + subtitles + Whisper Subtitles + + + Help + + + + + + + + + settings Model & Engine + + + Whisper Model + + Tiny (~10x speed, ~1GB VRAM, ~75MB) + Base (~7x speed, ~1GB VRAM, ~140MB) + Small (~4x speed, ~2GB VRAM, ~460MB) โ Recommended + Medium (~2x speed, ~5GB VRAM, ~1.5GB) + Turbo (~8x speed, ~6GB VRAM, ~1.6GB) + Large (Best quality, ~10GB VRAM, ~3GB) + + Larger models are more accurate but slower and use more VRAM. + + + + + + Download Model + + + + + + + Target Language + + Language code format (e.g., en, es, fr, de, ja, zh) + + + + AI Identifier + + Appends to file tag layout: movie.en.IDENTIFIER.srt. Leave blank to disable. + + + + + speed Acceleration + + + + + Enable CUDA (NVIDIA GPU) + + Uses the CUDA GPU binary instead of the CPU binary. Requires a deployed CUDA runtime binary and an NVIDIA GPU accessible inside the container (--gpus all). + + + + Runtime Hardware Status: detecting hardware profile... + + + CUDA Binary: not available + + + + FFprobe Path (optional) + + Path to ffprobe binary. Used to measure audio duration for chunked processing. Leave empty for auto-detect. 
+ + + + + + + auto_awesome Library Automation + + + + + Process on Library Scan + + Trigger execution loop hooks instantly upon catalog folder discovery passes. + + + + + + Skip Existing Subtitles + + Bypass parsing logic for assets that possess pre-existing subtitle structures. + + + + + + Regenerate AI Subtitles + + Force overwrite parsing routines if existing targets match internal AI hashes. + + + + + + Translate to English + + Forced conversion tracking maps multilingual content blocks direct to English matrices. + + + + + + Enable Word-Level Timestamps + + Improves granular frame syncing timelines, but lengthens compute run-times. + + + + + + Show Whisper Subtitles in Main Menu + + Toggle the Whisper Subtitles entry in the server's main navigation. Save and refresh the client (or clear cache) to apply. + + + + Libraries to Process + + Loading target server libraries... + + Isolate parsing runs to targeted libraries. Empty sets default to checking all media folders. + + + + Folders to Exclude + + Exclude local absolute paths (one folder rule entry per line). + + + + + + memory System Status + + + System Execution Binary Absent + The underlying runtime binary is packed into local directories but must register access configurations via the system cache space. + + + Deploy Runtime Binary + + + + Manual target fallback location path: copy from ~/.config/jellyfin/plugins/Whisper Subtitles_*/whisper/ straight to ~/.cache/whisper-cpp/ + + + + + โ Whisper core environment validation passed. Ready. + + + + Binary System Node Point: unknown + Active Thread Allocations: - available logical processor units + + + + Analyzing platform performance metrics... + + + + + info About Module + + + Version- + Repository Sourcegithub.com/zakattack02/Whisper-Script + License ProfileMIT Open Source Agreement + Transcription Corewhisper.cpp standalone optimization layer (ggerganov) + + + + Generates high-fidelity subtitle tracking sheets automatically. 
Harnesses advanced local transcription engines without needing to route external data packets out to third party provider cloud spaces. + + + + + โ Unsaved Workspace Configuration Changes + + Save Configuration + + + + + + + + + + + diff --git a/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Controllers/WhisperSubtitlesController.cs b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Controllers/WhisperSubtitlesController.cs new file mode 100644 index 0000000..b036ca1 --- /dev/null +++ b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Controllers/WhisperSubtitlesController.cs @@ -0,0 +1,341 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Jellyfin.Plugin.WhisperSubtitles.Configuration; +using Jellyfin.Plugin.WhisperSubtitles.Services; +using MediaBrowser.Controller.Entities; +using MediaBrowser.Controller.Library; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using Microsoft.Extensions.Logging; + +namespace Jellyfin.Plugin.WhisperSubtitles.Controllers +{ + [ApiController] + [Route("api/[controller]")] + [Produces("application/json")] + public class WhisperSubtitlesController : ControllerBase + { + private readonly ILogger _logger; + private readonly ILoggerFactory _loggerFactory; + private readonly ILibraryManager _libraryManager; + + public WhisperSubtitlesController( + ILogger logger, + ILoggerFactory loggerFactory, + ILibraryManager libraryManager) + { + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + _loggerFactory = loggerFactory ?? throw new ArgumentNullException(nameof(loggerFactory)); + _libraryManager = libraryManager ?? 
throw new ArgumentNullException(nameof(libraryManager)); + _logger.LogInformation("WhisperSubtitlesController initialized"); + } + + [HttpGet("Test")] + [AllowAnonymous] + public ActionResult Test() + { + _logger.LogInformation("Test endpoint called"); + return Ok("WhisperSubtitles controller is working!"); + } + + /// Returns item counts for the specified libraries. + [HttpGet("LibraryItemCounts")] + [AllowAnonymous] + public ActionResult LibraryItemCounts( + [FromQuery] string? libraryIds) + { + _logger.LogInformation("LibraryItemCounts called with ids: {Ids}", libraryIds); + + var ids = (libraryIds ?? string.Empty) + .Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) + .Select(id => Guid.TryParse(id, out var g) ? g : Guid.Empty) + .Where(g => g != Guid.Empty) + .ToHashSet(); + + var virtualFolders = _libraryManager.GetVirtualFolders(); + var libraries = new List(); + + foreach (var vf in virtualFolders) + { + if (!Guid.TryParse(vf.ItemId, out var folderId)) + continue; + + if (ids.Count > 0 && !ids.Contains(folderId)) + continue; + + var count = _libraryManager.GetCount(new InternalItemsQuery + { + ParentId = folderId, + Recursive = true, + IsVirtualItem = false + }); + + libraries.Add(new LibraryCountInfo + { + Id = vf.ItemId ?? string.Empty, + Name = vf.Name ?? "Unknown", + ItemCount = count + }); + } + + var total = libraries.Sum(l => l.ItemCount); + + var config = Plugin.Instance?.Configuration ?? new PluginConfiguration(); + var useGpu = config.UseGPUAcceleration; + var minutesPerItem = useGpu ? 2 : 35; + var estimatedMinutes = total * minutesPerItem; + + return Ok(new LibraryItemCountsResponse + { + TotalItemCount = total, + EstimatedMinutes = estimatedMinutes, + ProcessingMode = useGpu ? "GPU" : "CPU", + Libraries = libraries + }); + } + + /// Downloads (deploys from bundle) the requested Whisper model. 
+ [HttpPost("DownloadModel")] + [Authorize(Policy = "RequiresElevation")] + public async Task> DownloadModel( + [FromBody] ModelDownloadRequest request, + CancellationToken cancellationToken = default) + { + _logger.LogInformation("DownloadModel called: {@Request}", request); + + if (request is null || string.IsNullOrWhiteSpace(request.ModelName)) + { + return BadRequest(new ModelDownloadResponse + { + Success = false, + Message = "ModelName is required" + }); + } + + var validModels = new[] { "Tiny", "Base", "Small", "Medium", "Turbo", "Large" }; + if (!Array.Exists(validModels, + m => m.Equals(request.ModelName, StringComparison.OrdinalIgnoreCase))) + { + return BadRequest(new ModelDownloadResponse + { + Success = false, + Message = $"Invalid model '{request.ModelName}'. Valid: {string.Join(", ", validModels)}" + }); + } + + try + { + using var svc = new WhisperService(_loggerFactory.CreateLogger()); + + // Ensure binary is in place before downloading the model + if (!await EnsureBinaryAsync(svc, cancellationToken)) + { + return StatusCode(500, new ModelDownloadResponse + { + Success = false, + Message = "Failed to deploy whisper binary. Check server logs." + }); + } + + var ok = await svc.DownloadModelAsync( + request.ModelName.ToLowerInvariant(), cancellationToken); + + if (!ok) + { + return StatusCode(500, new ModelDownloadResponse + { + Success = false, + Message = $"Failed to download model '{request.ModelName}'" + }); + } + + return Ok(new ModelDownloadResponse + { + Success = true, + Message = $"Model '{request.ModelName}' downloaded successfully" + }); + } + catch (OperationCanceledException) + { + return StatusCode(499, new ModelDownloadResponse + { + Success = false, + Message = "Download cancelled" + }); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error downloading model {Model}", request.ModelName); + return StatusCode(500, new ModelDownloadResponse + { + Success = false, + Message = "Download failed. Check server logs." 
+ }); + } + } + + /// Checks if the whisper binary is available and ready to use. + [HttpGet("BinaryStatus")] + [AllowAnonymous] + public ActionResult BinaryStatus() + { + _logger.LogInformation("BinaryStatus called"); + + try + { + using var svc = new WhisperService(_loggerFactory.CreateLogger()); + + var isReady = svc.IsBinaryAvailable(); + + return Ok(new BinaryStatusResponse + { + IsReady = isReady, + BinaryPath = svc.BinaryPath, + GpuType = svc.DetectedGpuType, + CudaBinaryAvailable = svc.IsCudaBinaryAvailable + }); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error checking binary status"); + return Ok(new BinaryStatusResponse + { + IsReady = false, + Message = $"Error: {ex.Message}" + }); + } + } + + /// Deploys the bundled whisper binary from the plugin to the cache. + [HttpPost("InstallBinary")] + [Authorize(Policy = "RequiresElevation")] + public async Task> InstallBinary( + CancellationToken cancellationToken = default) + { + _logger.LogInformation("InstallBinary called"); + + try + { + using var svc = new WhisperService(_loggerFactory.CreateLogger()); + + if (!await EnsureBinaryAsync(svc, cancellationToken)) + { + return StatusCode(500, new BinaryInstallResponse + { + Success = false, + Message = "Failed to deploy whisper binary. Check server logs." 
+ }); + } + + return Ok(new BinaryInstallResponse + { + Success = true, + Message = "whisper binary deployed successfully", + BinaryPath = svc.BinaryPath, + GpuType = svc.DetectedGpuType, + CudaBinaryAvailable = svc.IsCudaBinaryAvailable + }); + } + catch (OperationCanceledException) + { + return StatusCode(499, new BinaryInstallResponse + { + Success = false, + Message = "Installation cancelled" + }); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error installing binary"); + return StatusCode(500, new BinaryInstallResponse + { + Success = false, + Message = $"Error: {ex.Message}" + }); + } + } + + // โโ Private helpers โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ + + /// + /// Calls the private EnsureBinaryAvailableAsync on WhisperService via reflection. + /// This indirection exists because IWhisperService intentionally doesn't expose it. + /// + private async Task EnsureBinaryAsync(WhisperService svc, CancellationToken ct) + { + var method = typeof(WhisperService).GetMethod( + "EnsureBinaryAvailableAsync", + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance, + null, + new[] { typeof(CancellationToken) }, + null); + + if (method is null) + { + _logger.LogError("EnsureBinaryAvailableAsync not found via reflection"); + return false; + } + + try + { + var task = method.Invoke(svc, new object[] { ct }) as Task; + return task is not null && await task; + } + catch (Exception ex) + { + _logger.LogError(ex, "Reflection call to EnsureBinaryAvailableAsync failed"); + return false; + } + } + } + + public class ModelDownloadRequest + { + [Required] + public string ModelName { get; set; } = string.Empty; + } + + public class ModelDownloadResponse + { + public bool Success { get; set; } + public string Message { get; set; } = string.Empty; + public string? 
ModelPath { get; set; } + } + + public class BinaryStatusResponse + { + public bool IsReady { get; set; } + public string Message { get; set; } = string.Empty; + public string? BinaryPath { get; set; } + public string? GpuType { get; set; } + public bool CudaBinaryAvailable { get; set; } + } + + public class BinaryInstallResponse + { + public bool Success { get; set; } + public string Message { get; set; } = string.Empty; + public string? BinaryPath { get; set; } + public string? GpuType { get; set; } + public bool CudaBinaryAvailable { get; set; } + } + + public class LibraryCountInfo + { + public string Id { get; set; } = string.Empty; + public string Name { get; set; } = string.Empty; + public int ItemCount { get; set; } + } + + public class LibraryItemCountsResponse + { + public int TotalItemCount { get; set; } + public int EstimatedMinutes { get; set; } + public string ProcessingMode { get; set; } = "CPU"; + public List Libraries { get; set; } = new(); + } +} \ No newline at end of file diff --git a/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles.csproj b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles.csproj new file mode 100644 index 0000000..496af75 --- /dev/null +++ b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles.csproj @@ -0,0 +1,60 @@ + + + + net9.0 + enable + enable + true + + bin\$(Configuration)\$(TargetFramework)\publish\ + + + + + + All + runtime + + + All + runtime + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Plugin.cs b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Plugin.cs new file mode 100644 index 0000000..5626d14 --- /dev/null +++ b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Plugin.cs @@ -0,0 +1,43 @@ +using System; +using 
System.Collections.Generic; +using Jellyfin.Plugin.WhisperSubtitles.Configuration; +using MediaBrowser.Common.Configuration; +using MediaBrowser.Common.Plugins; +using MediaBrowser.Model.Plugins; +using MediaBrowser.Model.Serialization; + +namespace Jellyfin.Plugin.WhisperSubtitles +{ + public class Plugin : BasePlugin, IHasWebPages + { + public Plugin(IApplicationPaths applicationPaths, IXmlSerializer xmlSerializer) + : base(applicationPaths, xmlSerializer) + { + Instance = this; + ApplicationPaths = applicationPaths; + } + + public static Plugin? Instance { get; private set; } + + public new IApplicationPaths ApplicationPaths { get; private set; } + + public override string Name => "Whisper Subtitles"; + + public override Guid Id => new Guid("a8b7c6d5-e4f3-4a5b-9c8d-7e6f5a4b3c2d"); + + public IEnumerable GetPages() + { + return new[] + { + new PluginPageInfo + { + Name = "Whisper Subtitles", + EmbeddedResourcePath = GetType().Namespace + ".Configuration.configPage.html", + MenuSection = "server", + MenuIcon = "subtitles", + EnableInMainMenu = Configuration.EnableMainMenu + } + }; + } + } +} diff --git a/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Services/ISubtitleDetectionService.cs b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Services/ISubtitleDetectionService.cs new file mode 100644 index 0000000..8cde6cc --- /dev/null +++ b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Services/ISubtitleDetectionService.cs @@ -0,0 +1,46 @@ +using System; +using System.Collections.Generic; + +namespace Jellyfin.Plugin.WhisperSubtitles.Services +{ + /// + /// Interface for subtitle detection service. + /// + public interface ISubtitleDetectionService + { + /// + /// Check if a video file has subtitles for a specific language. + /// + /// Path to the video file. + /// Language code (e.g., "en", "ja"). + /// AI identifier to look for (e.g., "whisper"). + /// True if subtitles exist, false otherwise. 
+ bool HasSubtitles(string videoPath, string language, string? aiIdentifier = null); + + /// + /// Check if a video file has AI-generated subtitles. + /// + /// Path to the video file. + /// Language code. + /// AI identifier to check for. + /// True if AI-generated subtitles exist, false otherwise. + bool HasAISubtitles(string videoPath, string language, string aiIdentifier); + + /// + /// Get all subtitle files for a video. + /// + /// Path to the video file. + /// List of subtitle file paths. + IEnumerable GetSubtitleFiles(string videoPath); + + /// + /// Generate the subtitle output path for a video file. + /// + /// Path to the video file. + /// Language code. + /// AI identifier to include in filename. + /// Subtitle format (e.g., "srt", "vtt"). + /// Full path for the subtitle file. + string GetSubtitlePath(string videoPath, string language, string aiIdentifier, string format = "srt"); + } +} diff --git a/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Services/IWhisperService.cs b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Services/IWhisperService.cs new file mode 100644 index 0000000..711b04c --- /dev/null +++ b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Services/IWhisperService.cs @@ -0,0 +1,42 @@ +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Jellyfin.Plugin.WhisperSubtitles.Services +{ + /// + /// Interface for Whisper subtitle service. + /// + public interface IWhisperService + { + /// + /// Download a Whisper model. + /// + Task DownloadModelAsync(string modelName, CancellationToken cancellationToken = default); + + /// Whether a CUDA binary is available in the cache. + bool IsCudaBinaryAvailable { get; } + + /// + /// Generate subtitles for a video file. + /// + /// Path to the video file. + /// Path where subtitle file should be saved. + /// Whisper model to use. + /// Target language code. + /// Whether to translate to English. 
+ /// Whether to include word-level timestamps. + /// Receives per-video progress 0.0โ1.0. + /// Cancellation token. + /// True if successful, false otherwise. + Task GenerateSubtitleAsync( + string videoPath, + string subtitlePath, + string modelName, + string language, + bool translate, + bool wordTimestamps, + IProgress? progress = null, + CancellationToken cancellationToken = default); + } +} diff --git a/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Services/SubtitleDetectionService.cs b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Services/SubtitleDetectionService.cs new file mode 100644 index 0000000..9cdbcfa --- /dev/null +++ b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Services/SubtitleDetectionService.cs @@ -0,0 +1,112 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Microsoft.Extensions.Logging; + +namespace Jellyfin.Plugin.WhisperSubtitles.Services +{ + /// + /// Service for detecting and managing subtitle files. + /// + public class SubtitleDetectionService : ISubtitleDetectionService + { + private readonly ILogger _logger; + + /// + /// Initializes a new instance of the class. + /// + /// Logger instance. + public SubtitleDetectionService(ILogger logger) + { + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + } + + /// + public bool HasSubtitles(string videoPath, string language, string? 
aiIdentifier = null) + { + if (string.IsNullOrEmpty(videoPath) || !File.Exists(videoPath)) + { + return false; + } + + var subtitleFiles = GetSubtitleFiles(videoPath).ToList(); + + if (aiIdentifier != null) + { + // Look for AI-specific subtitles + return subtitleFiles.Any(f => f.Contains(language) && f.Contains(aiIdentifier)); + } + + // Look for any subtitles in the language + return subtitleFiles.Any(f => f.Contains(language)); + } + + /// + public bool HasAISubtitles(string videoPath, string language, string aiIdentifier) + { + if (string.IsNullOrEmpty(videoPath) || string.IsNullOrEmpty(aiIdentifier)) + { + return false; + } + + return HasSubtitles(videoPath, language, aiIdentifier); + } + + /// + public IEnumerable GetSubtitleFiles(string videoPath) + { + if (string.IsNullOrEmpty(videoPath) || !File.Exists(videoPath)) + { + return Enumerable.Empty(); + } + + var directory = Path.GetDirectoryName(videoPath); + if (string.IsNullOrEmpty(directory)) + { + return Enumerable.Empty(); + } + + var fileName = Path.GetFileNameWithoutExtension(videoPath); + var subtitleExtensions = new[] { ".srt", ".vtt", ".ass", ".ssa", ".sub" }; + + try + { + var subtitleFiles = Directory.GetFiles(directory, fileName + "*") + .Where(f => subtitleExtensions.Contains(Path.GetExtension(f).ToLowerInvariant())) + .ToList(); + + return subtitleFiles; + } + catch (Exception ex) + { + _logger.LogError(ex, "Error getting subtitle files for {VideoPath}", videoPath); + return Enumerable.Empty(); + } + } + + /// + public string GetSubtitlePath(string videoPath, string language, string aiIdentifier, string format = "srt") + { + if (string.IsNullOrEmpty(videoPath)) + { + throw new ArgumentNullException(nameof(videoPath)); + } + + var directory = Path.GetDirectoryName(videoPath); + var fileNameWithoutExtension = Path.GetFileNameWithoutExtension(videoPath); + + if (string.IsNullOrEmpty(directory)) + { + directory = Directory.GetCurrentDirectory(); + } + + // Generate subtitle filename: 
video.en.whisper.srt + var subtitleFileName = string.IsNullOrEmpty(aiIdentifier) + ? $"{fileNameWithoutExtension}.{language}.{format}" + : $"{fileNameWithoutExtension}.{language}.{aiIdentifier}.{format}"; + + return Path.Combine(directory, subtitleFileName); + } + } +} diff --git a/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Services/WhisperBinaryManager.cs b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Services/WhisperBinaryManager.cs new file mode 100644 index 0000000..9108a86 --- /dev/null +++ b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Services/WhisperBinaryManager.cs @@ -0,0 +1,590 @@ +using System; +using System.Diagnostics; +using System.IO; +using System.IO.Compression; +using System.Net.Http; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; + +namespace Jellyfin.Plugin.WhisperSubtitles.Services +{ + /// + /// Manages the whisper.cpp binary: discovery, deployment from plugin bundle, and testing. + /// + public class WhisperBinaryManager : IDisposable + { + // The filenames the build script produces inside whisper/{platform}/. + private const string BundledBinaryName = "whisper-whisper-cli"; + private const string CudaBundledBinaryName = "whisper-whisper-cli-cuda"; + + private readonly ILogger _logger; + private readonly HttpClient _httpClient; + private readonly string _binaryPath; + private readonly string _cudaBinaryPath; + private readonly string _cudaLibDir; + private readonly string _downloadPath; + private readonly string? _jellyfinFFmpegPath; + private string? _detectedGPUType; + private bool _disposed; + + /// + /// Initialises a new instance of . + /// + public WhisperBinaryManager(ILogger logger) + { + _logger = logger ?? 
throw new ArgumentNullException(nameof(logger)); + _httpClient = new HttpClient(new SocketsHttpHandler + { + AutomaticDecompression = System.Net.DecompressionMethods.GZip + }) + { + Timeout = TimeSpan.FromMinutes(30) + }; + _httpClient.DefaultRequestHeaders.Add("User-Agent", "Jellyfin-Whisper-Plugin"); + + _jellyfinFFmpegPath = FindJellyfinFFmpeg(); + _ffprobePath = FindFfprobe(); + _detectedGPUType = DetectGPU(); + + // โโ Determine cache directory โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ + var cacheDir = Environment.GetEnvironmentVariable("JELLYFIN_CACHE_DIR"); + if (string.IsNullOrEmpty(cacheDir)) + { + var home = Environment.GetEnvironmentVariable("HOME") + ?? Path.GetTempPath(); + cacheDir = Path.Combine(home, ".cache"); + } + + var whisperDir = Path.Combine(cacheDir, "whisper-cpp"); + _downloadPath = whisperDir; + _binaryPath = Path.Combine(whisperDir, BundledBinaryName); + _cudaBinaryPath = Path.Combine(whisperDir, CudaBundledBinaryName); + _cudaLibDir = whisperDir; + + _logger.LogInformation("=== WhisperBinaryManager Init ==="); + _logger.LogInformation("JELLYFIN_CACHE_DIR env : {Env}", Environment.GetEnvironmentVariable("JELLYFIN_CACHE_DIR") ?? "(not set)"); + _logger.LogInformation("Resolved cache dir : {CacheDir}", cacheDir); + _logger.LogInformation("Whisper cache directory : {WhisperDir}", whisperDir); + _logger.LogInformation("Expected binary path : {BinaryPath}", _binaryPath); + _logger.LogInformation("Jellyfin FFmpeg : {Path}", _jellyfinFFmpegPath ?? "not found"); + _logger.LogInformation("Jellyfin FFprobe : {Path}", _ffprobePath ?? "not found"); + _logger.LogInformation("Detected GPU : {GPU}", _detectedGPUType ?? 
"none (CPU only)"); + _logger.LogInformation("====================================="); + + try + { + Directory.CreateDirectory(whisperDir); + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to create whisper cache directory: {Dir}", whisperDir); + } + } + + // โโ Public surface โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ + + /// Gets the path where the CPU binary is expected to live in the cache. + public string BinaryPath => _binaryPath; + + /// Gets the path where the CUDA binary is expected to live in the cache. + public string CudaBinaryPath => _cudaBinaryPath; + + /// Gets the directory containing the bundled CUDA .so libraries. + public string CudaLibDir => _cudaLibDir; + + /// True if the CUDA binary is present in the cache. + public bool IsCudaBinaryAvailable => File.Exists(_cudaBinaryPath); + + /// Gets the detected GPU type string, or null for CPU-only. + public string? DetectedGPUType => _detectedGPUType; + + /// Gets Jellyfin's bundled FFmpeg path, if found. + public string? JellyfinFFmpegPath => _jellyfinFFmpegPath; + + /// Returns true if the binary is present and executable in the cache. + public bool IsBinaryAvailable() + { + // Check for any of the known binary names (supports migration across versions) + var binaryNames = new[] { "whisper-whisper-cli", "whisper-cli", "main" }; + + _logger.LogInformation("Checking for whisper binary in: {Path}", _downloadPath); + + foreach (var binaryName in binaryNames) + { + var candidatePath = Path.Combine(_downloadPath, binaryName); + var exists = File.Exists(candidatePath); + _logger.LogDebug(" Checking {BinaryName}: {Path} โ {Exists}", + binaryName, candidatePath, exists ? 
"FOUND" : "not found"); + + if (exists) + { + EnsureExecutable(candidatePath); + _logger.LogInformation("โ Whisper binary found at {Path}", candidatePath); + return true; + } + } + + _logger.LogWarning("โ Whisper binary NOT found in {Path}", _downloadPath); + _logger.LogWarning(" Expected one of: {Names}", string.Join(", ", binaryNames)); + + // List what files actually exist in the directory for debugging + try + { + if (Directory.Exists(_downloadPath)) + { + var files = Directory.GetFiles(_downloadPath); + _logger.LogWarning(" Files in {Path}: {Files}", + _downloadPath, + files.Length > 0 ? string.Join(", ", files.Select(Path.GetFileName)) : "(empty)"); + } + else + { + _logger.LogWarning(" Directory does not exist: {Path}", _downloadPath); + } + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Could not list directory contents: {Path}", _downloadPath); + } + + return false; + } + + /// + /// Deploys the bundled binary from the plugin's installation directory into the cache. + /// This is the only "download" path โ we ship the binary, we just need to copy it. + /// Skips deployment if a binary already exists in the cache. + /// + public async Task DownloadBinaryAsync(CancellationToken cancellationToken = default) + { + try + { + // Skip deployment if any binary already exists + if (IsBinaryAvailable()) + { + _logger.LogInformation("Binary already available in cache, skipping deployment"); + return true; + } + + _logger.LogInformation("Deploying bundled whisper binary from plugin directory..."); + + var source = FindBundledBinary(); + if (source is null) + { + _logger.LogError( + "Bundled binary not found inside plugin folder. 
" + + "Expected '{Name}' inside whisper/{Platform}/ sub-directory.", + BundledBinaryName, GetPlatformString()); + return false; + } + + _logger.LogInformation("Copying {Source} โ {Dest}", source, _binaryPath); + File.Copy(source, _binaryPath, overwrite: true); + EnsureExecutable(_binaryPath); + + // Deploy CUDA binary and .so files if present in the bundle + var cudaSource = FindBundledCudaBinary(); + if (cudaSource is not null) + { + _logger.LogInformation("Copying CUDA binary {Source} โ {Dest}", cudaSource, _cudaBinaryPath); + File.Copy(cudaSource, _cudaBinaryPath, overwrite: true); + EnsureExecutable(_cudaBinaryPath); + + foreach (var lib in new[] { "libcudart.so.12", "libcublas.so.12", "libcublasLt.so.12" }) + { + var libSource = Path.Combine(Path.GetDirectoryName(cudaSource) ?? string.Empty, lib); + if (File.Exists(libSource)) + { + var libDest = Path.Combine(_cudaLibDir, lib); + _logger.LogInformation("Copying CUDA lib {Source} โ {Dest}", libSource, libDest); + File.Copy(libSource, libDest, overwrite: true); + } + } + } + else + { + _logger.LogInformation("No CUDA binary bundled โ GPU acceleration will not be available"); + } + + return await TestBinaryAsync(cancellationToken); + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to deploy bundled whisper binary"); + return false; + } + } + + /// Runs the binary with --help to verify it starts correctly. + public async Task TestBinaryAsync(CancellationToken cancellationToken = default) + { + if (!IsBinaryAvailable()) + return false; + + try + { + using var process = Process.Start(new ProcessStartInfo + { + FileName = _binaryPath, + Arguments = "--help", + RedirectStandardOutput = true, + RedirectStandardError = true, + UseShellExecute = false, + CreateNoWindow = true + })!; + + await process.WaitForExitAsync(cancellationToken); + + // whisper-cli --help exits with 0; some versions exit 1 but still print usage. + // Accept both โ what matters is it ran without an OS-level failure. 
+ var success = process.ExitCode == 0 || process.ExitCode == 1; + _logger.LogInformation("Binary test {Result} (exit code {Code})", + success ? "passed" : "FAILED", process.ExitCode); + return success; + } + catch (Exception ex) + { + _logger.LogError(ex, "Exception while testing whisper binary"); + return false; + } + } + + // โโ Private helpers โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ + + /// + /// Locates the bundled CPU binary inside the plugin's own installation folder. + /// + private string? FindBundledBinary() + { + var assemblyDir = Path.GetDirectoryName( + typeof(WhisperBinaryManager).Assembly.Location); + + if (string.IsNullOrEmpty(assemblyDir)) + { + _logger.LogError("Could not determine assembly directory"); + return null; + } + + var platform = GetPlatformString(); + var binaryNames = new[] { "whisper-whisper-cli", "whisper-cli", "main" }; + + foreach (var binaryName in binaryNames) + { + var candidate = Path.Combine(assemblyDir, "whisper", platform, binaryName); + if (File.Exists(candidate)) + { + _logger.LogInformation("Located bundled CPU binary: {Path}", candidate); + return candidate; + } + } + + foreach (var binaryName in binaryNames) + { + var candidate = Path.Combine(assemblyDir, platform, binaryName); + if (File.Exists(candidate)) + { + _logger.LogInformation("Located bundled CPU binary (fallback): {Path}", candidate); + return candidate; + } + } + + return null; + } + + /// Locates the bundled CUDA binary and .so files in the plugin folder. + private string? 
FindBundledCudaBinary() + { + var assemblyDir = Path.GetDirectoryName( + typeof(WhisperBinaryManager).Assembly.Location); + + if (string.IsNullOrEmpty(assemblyDir)) + return null; + + var platform = GetPlatformString(); + + var candidate = Path.Combine(assemblyDir, "whisper", platform, CudaBundledBinaryName); + if (File.Exists(candidate)) + { + _logger.LogInformation("Located bundled CUDA binary: {Path}", candidate); + return candidate; + } + + // Fallback: flat layout + candidate = Path.Combine(assemblyDir, platform, CudaBundledBinaryName); + if (File.Exists(candidate)) + { + _logger.LogInformation("Located bundled CUDA binary (fallback): {Path}", candidate); + return candidate; + } + + return null; + } + + private string GetPlatformString() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + return RuntimeInformation.ProcessArchitecture switch + { + Architecture.X64 => "linux-x64", + Architecture.Arm64 => "linux-arm64", + _ => "linux-x64" + }; + } + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + return "windows-x64"; + if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + return RuntimeInformation.ProcessArchitecture switch + { + Architecture.Arm64 => "macos-arm64", + _ => "macos-x64" + }; + } + return "linux-x64"; // safe default for Docker + } + + private static void EnsureExecutable(string filePath) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + return; + + try + { + var fi = new FileInfo(filePath); + // Set rwxr-xr-x + fi.UnixFileMode = + UnixFileMode.UserRead | UnixFileMode.UserWrite | UnixFileMode.UserExecute | + UnixFileMode.GroupRead | UnixFileMode.GroupExecute | + UnixFileMode.OtherRead | UnixFileMode.OtherExecute; + } + catch + { + // Fallback to chmod subprocess + try + { + using var p = Process.Start(new ProcessStartInfo + { + FileName = "chmod", + Arguments = $"+x \"{filePath}\"", + UseShellExecute = false, + CreateNoWindow = true + }); + p?.WaitForExit(); + } + catch { /* best-effort */ } + } + } + + 
// โโ GPU detection โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ + + private string? DetectGPU() + { + if (CheckNvidiaGPU()) return "cuda"; + if (CheckVulkanGPU()) return "vulkan"; + if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) return "metal"; + return null; + } + + private bool CheckNvidiaGPU() + { + try + { + using var p = Process.Start(new ProcessStartInfo + { + FileName = "nvidia-smi", + Arguments = "--query-gpu=name --format=csv,noheader", + RedirectStandardOutput = true, + RedirectStandardError = true, + UseShellExecute = false, + CreateNoWindow = true + }); + if (p is null) return false; + var output = p.StandardOutput.ReadToEnd(); + p.WaitForExit(); + if (p.ExitCode == 0 && !string.IsNullOrWhiteSpace(output)) + { + _logger.LogInformation("NVIDIA GPU detected: {GPU}", output.Trim()); + return true; + } + } + catch { /* nvidia-smi not present */ } + return false; + } + + private bool CheckVulkanGPU() + { + try + { + using var p = Process.Start(new ProcessStartInfo + { + FileName = "vulkaninfo", + Arguments = "--summary", + RedirectStandardOutput = true, + RedirectStandardError = true, + UseShellExecute = false, + CreateNoWindow = true + }); + if (p is null) return false; + var output = p.StandardOutput.ReadToEnd(); + p.WaitForExit(); + if (p.ExitCode == 0 && output.Contains("deviceName")) + { + _logger.LogInformation("Vulkan GPU detected"); + return true; + } + } + catch { /* vulkaninfo not present */ } + return false; + } + + // โโ FFmpeg discovery โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ + + private string? 
FindJellyfinFFmpeg() + { + var candidates = new[] + { + "/usr/lib/jellyfin-ffmpeg/ffmpeg", + "/usr/lib/jellyfin-ffmpeg5/ffmpeg", + "/usr/lib/jellyfin-ffmpeg6/ffmpeg", + "/jellyfin/ffmpeg", + "/config/ffmpeg/ffmpeg", + "ffmpeg" + }; + + foreach (var path in candidates) + { + try + { + if (File.Exists(path)) + { + _logger.LogInformation("Found Jellyfin FFmpeg: {Path}", path); + return path; + } + } + catch { /* permission errors on some paths */ } + } + + // Try `which ffmpeg` + try + { + using var p = Process.Start(new ProcessStartInfo + { + FileName = "which", + Arguments = "ffmpeg", + RedirectStandardOutput = true, + UseShellExecute = false, + CreateNoWindow = true + }); + if (p is not null) + { + var result = p.StandardOutput.ReadToEnd().Trim(); + p.WaitForExit(); + if (!string.IsNullOrEmpty(result) && File.Exists(result)) + { + _logger.LogInformation("Found FFmpeg via which: {Path}", result); + return result; + } + } + } + catch { /* which not available */ } + + _logger.LogWarning("FFmpeg not found; whisper.cpp will use built-in audio handling"); + return null; + } + + private string? _ffprobePath; + + /// Gets the resolved ffprobe path, or null if not found. + public string? FfprobePath => _ffprobePath; + + private string? FindFfprobe() + { + // 1. Check for ffprobe next to the ffmpeg binary we already found + if (_jellyfinFFmpegPath is not null) + { + var dir = Path.GetDirectoryName(_jellyfinFFmpegPath); + if (dir is not null) + { + var candidate = Path.Combine(dir, "ffprobe"); + if (File.Exists(candidate)) + { + _logger.LogInformation("Found FFprobe next to FFmpeg: {Path}", candidate); + return candidate; + } + } + } + + // 2. 
Known container paths + var candidates = new[] + { + "/usr/lib/jellyfin-ffmpeg/ffprobe", + "/usr/lib/jellyfin-ffmpeg5/ffprobe", + "/usr/lib/jellyfin-ffmpeg6/ffprobe", + "/jellyfin/ffprobe", + "/config/ffprobe/ffprobe", + "ffprobe" + }; + + foreach (var path in candidates) + { + try + { + if (File.Exists(path)) + { + _logger.LogInformation("Found FFprobe: {Path}", path); + return path; + } + } + catch { } + } + + // 3. Try `which ffprobe` + try + { + using var p = Process.Start(new ProcessStartInfo + { + FileName = "which", + Arguments = "ffprobe", + RedirectStandardOutput = true, + UseShellExecute = false, + CreateNoWindow = true + }); + if (p is not null) + { + var result = p.StandardOutput.ReadToEnd().Trim(); + p.WaitForExit(); + if (!string.IsNullOrEmpty(result) && File.Exists(result)) + { + _logger.LogInformation("Found FFprobe via which: {Path}", result); + return result; + } + } + } + catch { } + + _logger.LogWarning("FFprobe not found; chunk duration detection will be unavailable"); + return null; + } + + // โโ IDisposable โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (_disposed) return; + if (disposing) _httpClient?.Dispose(); + _disposed = true; + } + + ~WhisperBinaryManager() => Dispose(false); + } +} \ No newline at end of file diff --git a/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Services/WhisperService.cs b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Services/WhisperService.cs new file mode 100644 index 0000000..e9549f0 --- /dev/null +++ b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/Services/WhisperService.cs @@ -0,0 +1,737 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.IO; +using System.Linq; +using System.Text; +using System.Text.RegularExpressions; +using 
System.Threading;
using System.Threading.Tasks;
using System.Net.Http;
using Microsoft.Extensions.Logging;

namespace Jellyfin.Plugin.WhisperSubtitles.Services
{
    /// <summary>
    /// Generates subtitles using the bundled whisper.cpp binary.
    /// </summary>
    /// <remarks>
    /// The pipeline is: ensure the whisper binary and model exist, extract 16 kHz mono PCM
    /// audio with FFmpeg, run whisper-cli (in 30-minute chunks for long audio), and write
    /// the resulting SRT file next to the video.
    /// </remarks>
    public class WhisperService : IWhisperService, IDisposable
    {
        /// <summary>Maximum audio length handed to a single whisper-cli run (30 min, to stay under ~2GB RAM).</summary>
        private const int ChunkDurationMs = 30 * 60 * 1000;

        /// <summary>Base URL the ggml model files are downloaded from.</summary>
        private const string ModelBaseUrl = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main";

        /// <summary>Separator between the start and end time on an SRT timing line.</summary>
        private const string SrtArrow = " --> ";

        /// <summary>
        /// Known model names (lower-case) mapped to their ggml file name. This single table
        /// backs both the local cache path and the download URL so the two cannot drift apart.
        /// </summary>
        private static readonly Dictionary<string, string> KnownModelFiles =
            new Dictionary<string, string>(StringComparer.Ordinal)
            {
                ["tiny"] = "ggml-tiny.bin",
                ["tiny.en"] = "ggml-tiny.en.bin",
                ["base"] = "ggml-base.bin",
                ["base.en"] = "ggml-base.en.bin",
                ["small"] = "ggml-small.bin",
                ["small.en"] = "ggml-small.en.bin",
                ["medium"] = "ggml-medium.bin",
                ["medium.en"] = "ggml-medium.en.bin",
                ["large"] = "ggml-large-v3.bin",
                ["large-v1"] = "ggml-large-v1.bin",
                ["large-v2"] = "ggml-large-v2.bin",
                ["large-v3"] = "ggml-large-v3.bin",
                ["turbo"] = "ggml-large-v3-turbo.bin",
            };

        private readonly ILogger<WhisperService> _logger;
        private readonly HttpClient _httpClient;
        private readonly string _modelPath;
        private readonly WhisperBinaryManager _binaryManager;
        private bool _disposed;

        /// <summary>
        /// Creates the service and prepares the model cache directory.
        /// </summary>
        /// <param name="logger">Logger used by the service and its binary manager.</param>
        /// <exception cref="ArgumentNullException"><paramref name="logger"/> is null.</exception>
        public WhisperService(ILogger<WhisperService> logger)
        {
            _logger = logger ?? throw new ArgumentNullException(nameof(logger));
            _httpClient = new HttpClient();

            // No pluginPath parameter: WhisperBinaryManager is constructed from the logger alone.
            _binaryManager = new WhisperBinaryManager(logger);

            // Model storage: $JELLYFIN_CACHE_DIR/whisper, else $HOME/.cache/whisper, else <temp>/.cache/whisper.
            var cacheDir = Environment.GetEnvironmentVariable("JELLYFIN_CACHE_DIR");
            if (string.IsNullOrEmpty(cacheDir))
            {
                var home = Environment.GetEnvironmentVariable("HOME") ?? Path.GetTempPath();
                cacheDir = Path.Combine(home, ".cache");
            }

            _modelPath = Path.Combine(cacheDir, "whisper");

            // Best effort: a failure here is logged and surfaces later when a model download is attempted.
            try
            {
                Directory.CreateDirectory(_modelPath);
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "Failed to create model dir: {Dir}", _modelPath);
            }

            _logger.LogInformation("WhisperService ready. Model path: {Path}", _modelPath);

            if (!_binaryManager.IsBinaryAvailable())
                _logger.LogWarning("Whisper binary not in cache — will deploy from plugin bundle on first use.");
        }

        /// <summary>Path of the CPU whisper binary as reported by the binary manager.</summary>
        public string BinaryPath => _binaryManager.BinaryPath;

        /// <summary>GPU type detected by the binary manager, or null when none was detected.</summary>
        public string? DetectedGpuType => _binaryManager.DetectedGPUType;

        /// <summary>Whether the CUDA build of the whisper binary is available.</summary>
        public bool IsCudaBinaryAvailable => _binaryManager.IsCudaBinaryAvailable;

        /// <summary>Returns true if the whisper binary is available and ready to use.</summary>
        public bool IsBinaryAvailable() => _binaryManager.IsBinaryAvailable();

        // -- IWhisperService ---------------------------------------------------

        /// <summary>
        /// Downloads the ggml model for <paramref name="modelName"/> into the model cache
        /// unless it is already present.
        /// </summary>
        /// <param name="modelName">Model name such as "base" or "large-v3" (case-insensitive).</param>
        /// <param name="cancellationToken">Cancels the download.</param>
        /// <returns>
        /// True when the model file is available afterwards; false on an unknown model name,
        /// a failed or incomplete download, or cancellation.
        /// </returns>
        public async Task<bool> DownloadModelAsync(
            string modelName,
            CancellationToken cancellationToken = default)
        {
            string? partialFile = null;
            try
            {
                var modelFile = GetModelPath(modelName);
                if (File.Exists(modelFile))
                {
                    _logger.LogInformation("Model '{Model}' already cached at {Path}", modelName, modelFile);
                    return true;
                }

                var url = GetModelDownloadUrl(modelName);
                if (url is null)
                {
                    _logger.LogError("Unknown model name: {Model}", modelName);
                    return false;
                }

                _logger.LogInformation("Downloading model '{Model}' from {Url}", modelName, url);

                using var response = await _httpClient.GetAsync(
                    url, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
                response.EnsureSuccessStatusCode();

                var totalBytes = response.Content.Headers.ContentLength ?? 0;
                _logger.LogInformation("Model size: {MB} MB", totalBytes / 1024 / 1024);

                // Stream into a uniquely named temporary file and only move it into place once
                // complete. Writing directly to the final path would leave a truncated model behind
                // after a failure or cancellation, and File.Exists would then treat it as cached.
                partialFile = $"{modelFile}.{Guid.NewGuid():N}.part";
                long totalRead = 0;

                await using (var src = await response.Content.ReadAsStreamAsync(cancellationToken))
                await using (var dest = new FileStream(
                    partialFile, FileMode.Create, FileAccess.Write, FileShare.None, 65536, useAsync: true))
                {
                    var buffer = new byte[65536];
                    int bytesRead;
                    var lastLog = DateTime.UtcNow;

                    while ((bytesRead = await src.ReadAsync(buffer, 0, buffer.Length, cancellationToken)) > 0)
                    {
                        await dest.WriteAsync(buffer, 0, bytesRead, cancellationToken);
                        totalRead += bytesRead;

                        // Progress is logged at most every 30 seconds to avoid flooding the log.
                        if ((DateTime.UtcNow - lastLog).TotalSeconds >= 30)
                        {
                            var pct = totalBytes > 0 ? totalRead * 100.0 / totalBytes : 0;
                            _logger.LogInformation("Download: {Pct:F1}% ({Read}/{Total} MB)",
                                pct, totalRead / 1024 / 1024, totalBytes / 1024 / 1024);
                            lastLog = DateTime.UtcNow;
                        }
                    }
                }

                if (totalBytes > 0 && totalRead != totalBytes)
                {
                    _logger.LogError(
                        "Model '{Model}' download incomplete: received {Read} of {Total} bytes",
                        modelName, totalRead, totalBytes);
                    return false;
                }

                File.Move(partialFile, modelFile, overwrite: true);

                _logger.LogInformation("Model '{Model}' downloaded to {Path}", modelName, modelFile);
                return true;
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "Failed to download model '{Model}'", modelName);
                return false;
            }
            finally
            {
                // After a successful move this is a no-op; otherwise it removes the partial download.
                TryDeleteFile(partialFile);
            }
        }

        /// <summary>
        /// Generates an SRT subtitle file for a video.
        /// </summary>
        /// <param name="videoPath">Path of the source video.</param>
        /// <param name="outputPath">Path of the SRT file to create.</param>
        /// <param name="modelName">Whisper model name; downloaded on demand.</param>
        /// <param name="language">Language code passed to whisper-cli via <c>-l</c>.</param>
        /// <param name="translate">When true, passes <c>-tr</c> to whisper-cli.</param>
        /// <param name="wordTimestamps">When true, passes <c>-ml 1</c> to whisper-cli.</param>
        /// <param name="progress">Optional progress sink; receives values in the range 0..1.</param>
        /// <param name="cancellationToken">Cancels the operation and kills any running child process.</param>
        /// <returns>True when the subtitle file was written; false on any failure or cancellation.</returns>
        public async Task<bool> GenerateSubtitleAsync(
            string videoPath,
            string outputPath,
            string modelName,
            string language,
            bool translate,
            bool wordTimestamps,
            IProgress<double>? progress = null,
            CancellationToken cancellationToken = default)
        {
            var config = Plugin.Instance?.Configuration ?? new Configuration.PluginConfiguration();

            // 1. Ensure binary is ready
            if (!await EnsureBinaryAvailableAsync(cancellationToken))
            {
                _logger.LogError("Whisper binary not available — cannot generate subtitles");
                return false;
            }

            if (!File.Exists(videoPath))
            {
                _logger.LogError("Video not found: {Path}", videoPath);
                return false;
            }

            // 2. Ensure model is downloaded
            var modelFile = await EnsureModelDownloadedAsync(modelName, cancellationToken);
            if (modelFile is null)
            {
                _logger.LogError("Could not obtain model file for '{Model}'", modelName);
                return false;
            }

            var outputDir = Path.GetDirectoryName(outputPath);
            if (string.IsNullOrEmpty(outputDir))
                outputDir = Directory.GetCurrentDirectory();
            var outputStem = Path.GetFileNameWithoutExtension(outputPath);

            // 3. Extract audio from video into a uniquely named temp WAV
            var tempWav = Path.Combine(
                Path.GetTempPath(),
                $"whisper_{Guid.NewGuid():N}.wav");

            try
            {
                if (!await ExtractAudioAsync(videoPath, tempWav, cancellationToken))
                {
                    _logger.LogError("Failed to extract audio from {Video}", videoPath);
                    return false;
                }

                // Determine which binary to use
                var gpuType = config.UseGPUAcceleration ? _binaryManager.DetectedGPUType : null;
                string binaryPath;

                if (gpuType == "cuda" && _binaryManager.IsCudaBinaryAvailable)
                {
                    binaryPath = _binaryManager.CudaBinaryPath;
                    _logger.LogInformation("Using CUDA binary at {Path}", binaryPath);
                }
                else
                {
                    binaryPath = _binaryManager.BinaryPath;
                    if (config.UseGPUAcceleration && gpuType is not null)
                    {
                        _logger.LogWarning("CUDA binary not available, falling back to CPU binary");
                        gpuType = null;
                    }
                }

                var acceleration = gpuType is not null ? $"GPU ({gpuType})" : "CPU";

                _logger.LogInformation(
                    "Generating subtitles: video={Video}, model={Model}, lang={Lang}, translate={T}, accel={A}",
                    videoPath, modelName, language, translate, acceleration);

                // A duration of 0 means "unknown"; the audio is then processed in one pass.
                var wavDurationMs = await GetWavDurationMsAsync(tempWav, cancellationToken);

                if (wavDurationMs <= ChunkDurationMs)
                {
                    if (!await RunWhisperCli(
                            binaryPath, modelFile, tempWav, outputDir, outputStem,
                            language, translate, wordTimestamps, gpuType, cancellationToken))
                        return false;

                    progress?.Report(1.0);
                }
                else if (!await TranscribeInChunksAsync(
                             binaryPath, modelFile, tempWav, outputDir, outputStem, outputPath,
                             language, translate, wordTimestamps, gpuType, progress, cancellationToken))
                {
                    return false;
                }

                if (!File.Exists(outputPath))
                {
                    _logger.LogError("Subtitle file not created: {Path}", outputPath);
                    return false;
                }

                _logger.LogInformation("Subtitles written: {Path} ({Bytes} bytes)",
                    outputPath, new FileInfo(outputPath).Length);
                return true;
            }
            catch (OperationCanceledException)
            {
                _logger.LogWarning("Subtitle generation cancelled for {Video}", videoPath);
                return false;
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "Unexpected error generating subtitles for {Video}", videoPath);
                return false;
            }
            finally
            {
                TryDeleteFile(tempWav);
            }
        }

        /// <summary>
        /// Splits the audio into chunks, transcribes each chunk and merges the per-chunk SRT
        /// files into <paramref name="outputPath"/>, renumbering cues and shifting timestamps
        /// so they are relative to the start of the full audio.
        /// </summary>
        /// <returns>True when the merged subtitle file was written.</returns>
        private async Task<bool> TranscribeInChunksAsync(
            string binaryPath, string modelFile, string wavPath, string outputDir, string outputStem,
            string outputPath, string language, bool translate, bool wordTimestamps, string? gpuType,
            IProgress<double>? progress, CancellationToken cancellationToken)
        {
            var chunks = await SplitWavAsync(wavPath, ChunkDurationMs, cancellationToken);
            if (chunks.Count == 0)
            {
                _logger.LogError("Audio splitting failed — aborting subtitle generation");
                return false;
            }

            try
            {
                var mergedSrt = new StringBuilder();
                int segmentOffset = 0;

                // whisper-cli timestamps are relative to the start of each chunk file, so the
                // accumulated duration of all previous chunks has to be added while merging.
                long timeOffsetMs = 0;

                for (int i = 0; i < chunks.Count; i++)
                {
                    var chunkStem = $"{outputStem}.part{i:D3}";
                    var chunkSrtPath = Path.Combine(outputDir, $"{chunkStem}.srt");

                    try
                    {
                        if (!await RunWhisperCli(
                                binaryPath, modelFile, chunks[i], outputDir, chunkStem,
                                language, translate, wordTimestamps, gpuType, cancellationToken))
                            return false;

                        segmentOffset = MergeSrtInto(chunkSrtPath, mergedSrt, segmentOffset, timeOffsetMs);
                    }
                    finally
                    {
                        TryDeleteFile(chunkSrtPath);
                    }

                    if (i < chunks.Count - 1)
                    {
                        // Prefer the measured chunk length; fall back to the nominal length if ffprobe fails.
                        var chunkMs = await GetWavDurationMsAsync(chunks[i], cancellationToken);
                        timeOffsetMs += chunkMs > 0 ? chunkMs : ChunkDurationMs;
                    }

                    progress?.Report((double)(i + 1) / chunks.Count);
                }

                await File.WriteAllTextAsync(outputPath, mergedSrt.ToString(), cancellationToken);
                return true;
            }
            finally
            {
                foreach (var chunkPath in chunks)
                    TryDeleteFile(chunkPath);
            }
        }

        /// <summary>
        /// Runs whisper-cli on one WAV file and waits for <c>{outputStem}.srt</c> to appear in
        /// <paramref name="outputDir"/>.
        /// </summary>
        /// <returns>True when the process exited with code 0 and the SRT file exists.</returns>
        /// <exception cref="OperationCanceledException">
        /// <paramref name="cancellationToken"/> was cancelled; the child process is killed first.
        /// </exception>
        private async Task<bool> RunWhisperCli(
            string binaryPath, string modelFile, string wavPath, string outputDir, string outputStem,
            string language, bool translate, bool wordTimestamps, string? gpuType,
            CancellationToken cancellationToken)
        {
            var args = BuildArguments(
                modelFile, wavPath, outputDir, outputStem,
                language, translate, wordTimestamps, gpuType);

            _logger.LogInformation("Command: {Binary} {Args}", binaryPath, string.Join(' ', args));

            var psi = new ProcessStartInfo
            {
                FileName = binaryPath,
                RedirectStandardOutput = true,
                RedirectStandardError = true,
                UseShellExecute = false,
                CreateNoWindow = true,
                WorkingDirectory = outputDir
            };

            // ArgumentList escapes every argument individually, so paths containing quotes or
            // spaces cannot break the command line.
            foreach (var arg in args)
                psi.ArgumentList.Add(arg);

            // Let the CUDA binary find its bundled .so files. The bundled directory is prepended
            // rather than replacing the inherited value, which may already hold driver paths.
            if (binaryPath == _binaryManager.CudaBinaryPath)
            {
                var inherited = Environment.GetEnvironmentVariable("LD_LIBRARY_PATH");
                psi.EnvironmentVariables["LD_LIBRARY_PATH"] = string.IsNullOrEmpty(inherited)
                    ? _binaryManager.CudaLibDir
                    : $"{_binaryManager.CudaLibDir}{Path.PathSeparator}{inherited}";
            }

            using var process = Process.Start(psi);
            if (process is null)
            {
                _logger.LogError("Failed to start whisper-cli process");
                return false;
            }

            var stdout = new StringBuilder();
            var stderr = new StringBuilder();

            using var outDone = new ManualResetEventSlim(false);
            using var errDone = new ManualResetEventSlim(false);

            // A null Data value marks end-of-stream for the redirected pipe.
            process.OutputDataReceived += (_, e) =>
            {
                if (e.Data is null) { outDone.Set(); return; }
                stdout.AppendLine(e.Data);
                _logger.LogDebug("whisper: {L}", e.Data);
            };

            process.ErrorDataReceived += (_, e) =>
            {
                if (e.Data is null) { errDone.Set(); return; }
                stderr.AppendLine(e.Data);
                if (e.Data.Contains("CUDA") || e.Data.Contains("GPU") ||
                    e.Data.Contains("progress") || e.Data.Contains("processing"))
                    _logger.LogInformation("whisper: {L}", e.Data);
                else
                    _logger.LogDebug("whisper: {L}", e.Data);
            };

            process.BeginOutputReadLine();
            process.BeginErrorReadLine();

            using var whisperTimeout = new CancellationTokenSource(TimeSpan.FromMinutes(60));
            using var whisperCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, whisperTimeout.Token);
            try
            {
                await process.WaitForExitAsync(whisperCts.Token);
            }
            catch (OperationCanceledException)
            {
                // Kill on both timeout and caller cancellation so no orphaned process keeps running.
                KillProcessTree(process);

                if (cancellationToken.IsCancellationRequested)
                    throw;

                _logger.LogError("whisper-cli timed out after 60 minutes, killing process");
                return false;
            }

            // Give the asynchronous readers a moment to deliver the remaining output.
            outDone.Wait(5_000);
            errDone.Wait(5_000);

            if (process.ExitCode != 0)
            {
                _logger.LogError(
                    "whisper-cli exited {Code}.\nstderr: {Err}\nstdout: {Out}",
                    process.ExitCode,
                    stderr.Length > 0 ? stderr.ToString() : "(empty)",
                    stdout.Length > 0 ? stdout.ToString() : "(empty)");
                return false;
            }

            var outputPath = Path.Combine(outputDir, $"{outputStem}.srt");
            if (!File.Exists(outputPath))
            {
                _logger.LogError(
                    "Subtitle file not created: {Path}\nExit code: {Code}\nstderr: {Err}\nstdout: {Out}",
                    outputPath,
                    process.ExitCode,
                    stderr.Length > 0 ? stderr.ToString() : "(empty)",
                    stdout.Length > 0 ? stdout.ToString() : "(empty)");
                return false;
            }

            return true;
        }

        /// <summary>
        /// Asks ffprobe for the duration of an audio file.
        /// </summary>
        /// <returns>Duration in milliseconds, or 0 when it could not be determined.</returns>
        /// <exception cref="OperationCanceledException"><paramref name="ct"/> was cancelled.</exception>
        private async Task<int> GetWavDurationMsAsync(string wavPath, CancellationToken ct)
        {
            var configured = Plugin.Instance?.Configuration.FfprobePath;
            var ffprobe = !string.IsNullOrEmpty(configured)
                ? configured
                : _binaryManager.FfprobePath ?? "ffprobe";

            try
            {
                var psi = new ProcessStartInfo { FileName = ffprobe };
                AddArguments(
                    psi,
                    "-v", "error",
                    "-show_entries", "format=duration",
                    "-of", "default=noprint_wrappers=1:nokey=1",
                    wavPath);

                var result = await RunToolAsync(psi, TimeSpan.FromSeconds(30), ct);
                if (result.TimedOut)
                {
                    _logger.LogWarning("ffprobe timed out");
                    return 0;
                }

                if (double.TryParse(result.StdOut.Trim(), NumberStyles.Any, CultureInfo.InvariantCulture, out var seconds)
                    && seconds > 0
                    && seconds < int.MaxValue / 1000.0)
                {
                    return (int)(seconds * 1000);
                }
            }
            catch (OperationCanceledException) when (ct.IsCancellationRequested)
            {
                throw;
            }
            catch (Exception ex)
            {
                _logger.LogWarning(ex, "Failed to detect audio duration via {Ffprobe}", ffprobe);
            }

            return 0;
        }

        /// <summary>
        /// Splits a WAV file into consecutive chunks of <paramref name="chunkMs"/> milliseconds
        /// using FFmpeg's segment muxer.
        /// </summary>
        /// <returns>
        /// The chunk files in playback order; the original file alone when FFmpeg produced no
        /// chunks; an empty list when splitting failed.
        /// </returns>
        /// <exception cref="OperationCanceledException"><paramref name="ct"/> was cancelled.</exception>
        private async Task<List<string>> SplitWavAsync(string wavPath, int chunkMs, CancellationToken ct)
        {
            var ffmpeg = _binaryManager.JellyfinFFmpegPath ?? "ffmpeg";
            var chunkSec = chunkMs / 1000.0;
            var dir = Path.GetDirectoryName(wavPath) ?? Path.GetTempPath();
            var prefix = $"whisper_chunk_{Guid.NewGuid():N}_";
            var pattern = Path.Combine(dir, $"{prefix}%03d.wav");

            _logger.LogInformation("Splitting audio ({Chunk}s chunks): {Wav} → {Pattern}",
                chunkSec, wavPath, pattern);

            var psi = new ProcessStartInfo { FileName = ffmpeg };
            AddArguments(
                psi,
                "-y", "-loglevel", "error",
                "-i", wavPath,
                "-f", "segment",
                // Invariant culture: a decimal comma would not be understood by FFmpeg.
                "-segment_time", chunkSec.ToString(CultureInfo.InvariantCulture),
                "-c:a", "pcm_s16le", "-ar", "16000", "-ac", "1",
                pattern);

            bool succeeded = false;
            try
            {
                var result = await RunToolAsync(psi, TimeSpan.FromMinutes(30), ct);
                if (result.TimedOut)
                {
                    _logger.LogError("FFmpeg split timed out after 30 minutes");
                    return new List<string>();
                }

                if (result.ExitCode != 0)
                {
                    _logger.LogError("FFmpeg split failed ({Code}): {Err}", result.ExitCode, result.StdErr);
                    return new List<string>();
                }

                // Zero-padded indices make ordinal file-name order equal to playback order.
                var files = Directory.GetFiles(dir, $"{prefix}*.wav")
                    .OrderBy(f => f, StringComparer.Ordinal)
                    .ToList();

                if (files.Count == 0)
                {
                    _logger.LogWarning("FFmpeg split produced no files, using original");
                    return new List<string> { wavPath };
                }

                _logger.LogInformation("Audio split into {Count} chunk(s)", files.Count);
                succeeded = true;
                return files;
            }
            catch (OperationCanceledException) when (ct.IsCancellationRequested)
            {
                throw;
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "Failed to split audio {Wav} into chunks in {Dir}", wavPath, dir);
                return new List<string>();
            }
            finally
            {
                // Do not leave partially written chunks behind when the split did not succeed.
                if (!succeeded)
                    DeleteFilesByPrefix(dir, prefix);
            }
        }

        /// <summary>
        /// Appends the cues of one chunk SRT file to <paramref name="merged"/>, adding
        /// <paramref name="offset"/> to every cue number and <paramref name="timeOffsetMs"/>
        /// to every cue timestamp.
        /// </summary>
        /// <param name="chunkSrtPath">SRT file produced for one chunk.</param>
        /// <param name="merged">Destination buffer holding the merged subtitle text.</param>
        /// <param name="offset">Number of cues already present in <paramref name="merged"/>.</param>
        /// <param name="timeOffsetMs">Start of this chunk within the full audio, in milliseconds.</param>
        /// <returns>The cue-number offset to use for the next chunk.</returns>
        private static int MergeSrtInto(string chunkSrtPath, StringBuilder merged, int offset, long timeOffsetMs = 0)
        {
            var lines = File.ReadAllLines(chunkSrtPath);
            bool expectNumber = true;   // start of file: expect a cue number
            bool expectTiming = false;  // the line right after a cue number is the timing line
            int localMax = 0;

            foreach (var line in lines)
            {
                if (expectNumber
                    && int.TryParse(line.Trim(), NumberStyles.None, CultureInfo.InvariantCulture, out var n))
                {
                    merged.AppendLine((n + offset).ToString(CultureInfo.InvariantCulture));
                    if (n > localMax) localMax = n;
                    expectNumber = false;
                    expectTiming = true;
                    continue;
                }

                // Only the line directly after the cue number is rewritten, so subtitle text
                // that happens to contain "-->" is never touched.
                merged.AppendLine(expectTiming && timeOffsetMs != 0
                    ? ShiftSrtTimingLine(line, timeOffsetMs)
                    : line);
                expectTiming = false;

                if (line.Length == 0)
                    expectNumber = true; // blank line: the next line is a cue number
            }

            // Keep cues from consecutive chunks separated by a blank line.
            if (lines.Length > 0 && lines[^1] != "")
                merged.AppendLine();

            return offset + localMax;
        }

        /// <summary>
        /// Adds <paramref name="offsetMs"/> to both timestamps of an SRT timing line
        /// ("HH:MM:SS,mmm --> HH:MM:SS,mmm"). Lines that do not parse are returned unchanged.
        /// </summary>
        private static string ShiftSrtTimingLine(string line, long offsetMs)
        {
            var arrowAt = line.IndexOf(SrtArrow, StringComparison.Ordinal);
            if (arrowAt < 0)
                return line;

            var startText = line.Substring(0, arrowAt);
            var endText = line.Substring(arrowAt + SrtArrow.Length);

            // Anything after the end timestamp (e.g. positioning hints) is preserved verbatim.
            var suffix = string.Empty;
            var spaceAt = endText.IndexOf(' ');
            if (spaceAt >= 0)
            {
                suffix = endText.Substring(spaceAt);
                endText = endText.Substring(0, spaceAt);
            }

            if (!TryParseSrtTime(startText, out var startMs) || !TryParseSrtTime(endText, out var endMs))
                return line;

            return FormatSrtTime(startMs + offsetMs) + SrtArrow + FormatSrtTime(endMs + offsetMs) + suffix;
        }

        /// <summary>Parses "HH:MM:SS,mmm" (or with '.' before the milliseconds) into milliseconds.</summary>
        private static bool TryParseSrtTime(string text, out long milliseconds)
        {
            milliseconds = 0;

            var parts = text.Trim().Split(':', ',', '.');
            if (parts.Length != 4)
                return false;

            if (!int.TryParse(parts[0], NumberStyles.None, CultureInfo.InvariantCulture, out var hours)
                || !int.TryParse(parts[1], NumberStyles.None, CultureInfo.InvariantCulture, out var minutes)
                || !int.TryParse(parts[2], NumberStyles.None, CultureInfo.InvariantCulture, out var seconds)
                || !int.TryParse(parts[3], NumberStyles.None, CultureInfo.InvariantCulture, out var millis))
            {
                return false;
            }

            milliseconds = (((hours * 60L) + minutes) * 60L + seconds) * 1000L + millis;
            return true;
        }

        /// <summary>Formats milliseconds as an SRT timestamp "HH:MM:SS,mmm".</summary>
        private static string FormatSrtTime(long milliseconds)
        {
            var hours = milliseconds / 3_600_000;
            var minutes = milliseconds / 60_000 % 60;
            var seconds = milliseconds / 1_000 % 60;
            var millis = milliseconds % 1_000;

            return string.Format(
                CultureInfo.InvariantCulture,
                "{0:00}:{1:00}:{2:00},{3:000}",
                hours, minutes, seconds, millis);
        }

        // -- Private helpers ---------------------------------------------------

        /// <summary>
        /// Builds the argument list for whisper-cli. Each element is one argument; escaping is
        /// left to <see cref="ProcessStartInfo.ArgumentList"/>.
        /// </summary>
        /// <remarks>
        /// <paramref name="outputDir"/> is not passed on the command line: the output stem is
        /// relative and the process working directory is set to the output directory.
        /// </remarks>
        private List<string> BuildArguments(
            string modelFile,
            string videoPath,
            string outputDir,
            string outputStem,
            string language,
            bool translate,
            bool wordTimestamps,
            string? gpuType)
        {
            var args = new List<string>
            {
                // Model and input
                "-m", modelFile,
                "-f", videoPath,

                // Language
                "-l", language,

                // Output: SRT format, stem only (no extension)
                "-osrt",
                "-of", outputStem,
            };

            if (translate)
                args.Add("-tr");

            if (wordTimestamps)
            {
                args.Add("-ml");
                args.Add("1");
            }

            // Thread cap
            var threads = Math.Min(Environment.ProcessorCount, 16);
            args.Add("-t");
            args.Add(threads.ToString(CultureInfo.InvariantCulture));

            // GPU offload (whisper-cli only supports -dev N, not -ngl)
            if (gpuType is not null)
            {
                args.Add("-dev");
                args.Add("0");
                _logger.LogInformation("{GPU} acceleration enabled", gpuType);
            }
            else
            {
                // whisper-cli defaults to use_gpu=true; explicitly disable it
                args.Add("-ng");
                _logger.LogInformation("CPU-only processing");
            }

            return args;
        }

        /// <summary>
        /// Makes sure the whisper binary is deployed and passes its self-test.
        /// </summary>
        /// <returns>True when the binary is usable.</returns>
        private async Task<bool> EnsureBinaryAvailableAsync(CancellationToken cancellationToken)
        {
            if (_binaryManager.IsBinaryAvailable())
                return true;

            _logger.LogInformation("Binary not in cache — deploying from plugin bundle...");
            var deployed = await _binaryManager.DownloadBinaryAsync(cancellationToken);

            if (!deployed)
            {
                _logger.LogError("Failed to deploy whisper binary from plugin bundle");
                return false;
            }

            var tested = await _binaryManager.TestBinaryAsync(cancellationToken);
            if (!tested)
                _logger.LogError("Deployed binary failed self-test");

            return tested;
        }

        /// <summary>
        /// Extracts the audio track of a video as 16 kHz mono 16-bit PCM WAV.
        /// </summary>
        /// <returns>True when FFmpeg succeeded and the WAV file exists.</returns>
        /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was cancelled.</exception>
        private async Task<bool> ExtractAudioAsync(
            string videoPath, string wavPath, CancellationToken cancellationToken)
        {
            var ffmpeg = _binaryManager.JellyfinFFmpegPath ?? "ffmpeg";

            _logger.LogInformation("Extracting audio: {Video} → {Wav}", videoPath, wavPath);

            try
            {
                var psi = new ProcessStartInfo { FileName = ffmpeg };
                AddArguments(
                    psi,
                    "-y", "-loglevel", "error",
                    "-i", videoPath,
                    "-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le",
                    "-f", "wav",
                    wavPath);

                var result = await RunToolAsync(psi, TimeSpan.FromMinutes(30), cancellationToken);
                if (result.TimedOut)
                {
                    _logger.LogError("FFmpeg audio extraction timed out after 30 minutes");
                    return false;
                }

                if (result.ExitCode != 0)
                {
                    _logger.LogError("FFmpeg exited {Code}: {Err}", result.ExitCode, result.StdErr);
                    return false;
                }

                if (!File.Exists(wavPath))
                {
                    _logger.LogError("FFmpeg did not produce output: {Wav}", wavPath);
                    return false;
                }

                _logger.LogInformation("Audio extracted: {Wav} ({Bytes} bytes)",
                    wavPath, new FileInfo(wavPath).Length);
                return true;
            }
            catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
            {
                throw;
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "Exception during audio extraction from {Video}", videoPath);
                return false;
            }
        }

        /// <summary>
        /// Runs a helper tool (ffmpeg/ffprobe) to completion, capturing stdout and stderr.
        /// </summary>
        /// <param name="psi">Start info with file name and arguments set; redirection is configured here.</param>
        /// <param name="timeout">Maximum run time before the process is killed.</param>
        /// <param name="cancellationToken">Caller cancellation; the process is killed before rethrowing.</param>
        /// <returns>
        /// <c>TimedOut</c> is true when <paramref name="timeout"/> elapsed; otherwise the exit
        /// code and the complete captured output.
        /// </returns>
        /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was cancelled.</exception>
        /// <exception cref="InvalidOperationException">The process could not be started.</exception>
        private async Task<(bool TimedOut, int ExitCode, string StdOut, string StdErr)> RunToolAsync(
            ProcessStartInfo psi, TimeSpan timeout, CancellationToken cancellationToken)
        {
            psi.RedirectStandardOutput = true;
            psi.RedirectStandardError = true;
            psi.UseShellExecute = false;
            psi.CreateNoWindow = true;

            using var process = Process.Start(psi)
                ?? throw new InvalidOperationException($"Failed to start process '{psi.FileName}'");

            using var timeoutCts = new CancellationTokenSource(timeout);
            using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCts.Token);

            // Drain both pipes concurrently so the child can never block on a full pipe buffer.
            var stdoutTask = process.StandardOutput.ReadToEndAsync(linkedCts.Token);
            var stderrTask = process.StandardError.ReadToEndAsync(linkedCts.Token);

            try
            {
                await process.WaitForExitAsync(linkedCts.Token);
                var stdout = await stdoutTask;
                var stderr = await stderrTask;
                return (false, process.ExitCode, stdout, stderr);
            }
            catch (OperationCanceledException)
            {
                KillProcessTree(process);

                // Observe the reader tasks so their cancellation does not surface as unobserved exceptions.
                try
                {
                    await Task.WhenAll(stdoutTask, stderrTask);
                }
                catch (Exception)
                {
                    // Expected: the readers were cancelled or the pipes were closed by the kill.
                }

                if (cancellationToken.IsCancellationRequested)
                    throw;

                return (true, -1, string.Empty, string.Empty);
            }
        }

        /// <summary>Appends each value as a separate, automatically escaped process argument.</summary>
        private static void AddArguments(ProcessStartInfo psi, params string[] arguments)
        {
            foreach (var argument in arguments)
                psi.ArgumentList.Add(argument);
        }

        /// <summary>Kills a process and its children; failures are logged at debug level only.</summary>
        private void KillProcessTree(Process process)
        {
            try
            {
                if (!process.HasExited)
                    process.Kill(entireProcessTree: true);
            }
            catch (Exception ex)
            {
                _logger.LogDebug(ex, "Failed to kill child process");
            }
        }

        /// <summary>Deletes a file if it exists; failures are logged at debug level only.</summary>
        private void TryDeleteFile(string? path)
        {
            if (string.IsNullOrEmpty(path))
                return;

            try
            {
                if (File.Exists(path))
                    File.Delete(path);
            }
            catch (Exception ex)
            {
                _logger.LogDebug(ex, "Could not delete temporary file {Path}", path);
            }
        }

        /// <summary>Deletes all <c>{prefix}*.wav</c> files in a directory (best effort).</summary>
        private void DeleteFilesByPrefix(string directory, string prefix)
        {
            try
            {
                foreach (var file in Directory.GetFiles(directory, $"{prefix}*.wav"))
                    TryDeleteFile(file);
            }
            catch (Exception ex)
            {
                _logger.LogDebug(ex, "Could not clean up chunk files in {Dir}", directory);
            }
        }

        /// <summary>
        /// Returns the cached model path, downloading the model first when necessary.
        /// </summary>
        /// <returns>The model file path, or null when the model could not be obtained.</returns>
        private async Task<string?> EnsureModelDownloadedAsync(
            string modelName, CancellationToken cancellationToken)
        {
            var path = GetModelPath(modelName);
            if (File.Exists(path))
                return path;

            var ok = await DownloadModelAsync(modelName, cancellationToken);
            return ok ? path : null;
        }

        /// <summary>Returns true when the model file for <paramref name="modelName"/> is in the cache.</summary>
        public bool IsModelAvailable(string modelName) => File.Exists(GetModelPath(modelName));

        /// <summary>
        /// Maps a model name to its file path inside the model cache directory.
        /// Unknown names map to <c>ggml-{name}.bin</c>.
        /// </summary>
        private string GetModelPath(string modelName)
        {
            if (!KnownModelFiles.TryGetValue(modelName.ToLowerInvariant(), out var file))
            {
                // GetFileName strips directory components so a crafted name cannot escape the cache directory.
                file = Path.GetFileName($"ggml-{modelName}.bin");
            }

            return Path.Combine(_modelPath, file);
        }

        /// <summary>
        /// Returns the download URL for a known model name, or null for an unknown name.
        /// </summary>
        private static string? GetModelDownloadUrl(string modelName)
        {
            return KnownModelFiles.TryGetValue(modelName.ToLowerInvariant(), out var file)
                ? $"{ModelBaseUrl}/{file}"
                : null;
        }

        // -- IDisposable -------------------------------------------------------

        /// <summary>Releases the HTTP client and the binary manager.</summary>
        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Releases managed resources when <paramref name="disposing"/> is true. Safe to call more than once.
        /// </summary>
        protected virtual void Dispose(bool disposing)
        {
            if (_disposed) return;
            if (disposing)
            {
                _httpClient?.Dispose();
                _binaryManager?.Dispose();
            }
            _disposed = true;
        }

        ~WhisperService() => Dispose(false);
    }
}
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Jellyfin.Plugin.WhisperSubtitles.Configuration;
using Jellyfin.Plugin.WhisperSubtitles.Services;
using MediaBrowser.Controller.Library;
using MediaBrowser.Model.Tasks;
using Microsoft.Extensions.Logging;

namespace Jellyfin.Plugin.WhisperSubtitles.Tasks
{
    /// <summary>
    /// Runs after a library scan and optionally generates subtitles for new media.
    /// </summary>
    public class WhisperPostScanTask : ILibraryPostScanTask
    {
        private readonly ILibraryManager _libraryManager;
        private readonly ILogger<WhisperPostScanTask> _logger;
        private readonly ILoggerFactory _loggerFactory;
        private readonly IWhisperService _whisperService;
        private readonly ISubtitleDetectionService _subtitleDetectionService;

        /// <summary>
        /// Creates the task and the services it uses.
        /// </summary>
        /// <param name="libraryManager">Library manager used to enumerate media items.</param>
        /// <param name="logger">Logger for this task.</param>
        /// <param name="loggerFactory">Factory used to create loggers for the owned services.</param>
        public WhisperPostScanTask(
            ILibraryManager libraryManager,
            ILogger<WhisperPostScanTask> logger,
            ILoggerFactory loggerFactory)
        {
            _libraryManager = libraryManager;
            _logger = logger;
            _loggerFactory = loggerFactory;

            _whisperService = new WhisperService(_loggerFactory.CreateLogger<WhisperService>());
            _subtitleDetectionService = new SubtitleDetectionService(_loggerFactory.CreateLogger<SubtitleDetectionService>());
        }

        /// <summary>Display name of the task.</summary>
        public string Name => "Whisper Post-Scan Processor";

        /// <summary>Stable key identifying the task.</summary>
        public string Key => "WhisperPostScan";

        /// <summary>Human-readable description of the task.</summary>
        public string Description => "Generates subtitles for newly scanned media when enabled in plugin configuration.";

        /// <summary>
        /// Generates subtitles for every movie, episode and video that does not have them yet.
        /// Does nothing unless post-scan processing is enabled in the plugin configuration.
        /// </summary>
        /// <param name="progress">Receives the completion percentage (0-100).</param>
        /// <param name="cancellationToken">Stops processing before the next item.</param>
        public async Task Run(IProgress<double> progress, CancellationToken cancellationToken)
        {
            var config = Plugin.Instance?.Configuration ?? new PluginConfiguration();

            if (!config.ProcessOnLibraryScan)
            {
                _logger.LogDebug("Post-scan processing disabled in configuration");
                return;
            }

            _logger.LogInformation("Whisper post-scan starting...");

            // Include Video so Home Video libraries are covered
            var items = _libraryManager.GetItemList(new MediaBrowser.Controller.Entities.InternalItemsQuery
            {
                IncludeItemTypes = new[]
                {
                    Jellyfin.Data.Enums.BaseItemKind.Movie,
                    Jellyfin.Data.Enums.BaseItemKind.Episode,
                    Jellyfin.Data.Enums.BaseItemKind.Video
                },
                Recursive = true
            });

            var total = items.Count;
            var visited = 0;
            var failures = 0;

            foreach (var item in items)
            {
                if (cancellationToken.IsCancellationRequested)
                {
                    _logger.LogInformation("Post-scan task cancelled");
                    break;
                }

                try
                {
                    var path = item.Path;
                    if (string.IsNullOrEmpty(path) || !File.Exists(path))
                        continue;

                    if (_subtitleDetectionService.HasSubtitles(
                            path, config.TargetLanguage, config.AIIdentifier) && !config.RegenerateAI)
                    {
                        _logger.LogDebug("Skipping {Path} — subtitle already present", path);
                        continue;
                    }

                    var subtitlePath = _subtitleDetectionService.GetSubtitlePath(
                        path, config.TargetLanguage, config.AIIdentifier, "srt");

                    _logger.LogInformation("Post-scan: generating subtitles for {Path}", path);

                    var generated = await _whisperService.GenerateSubtitleAsync(
                        path, subtitlePath,
                        config.WhisperModel.ToString(),
                        config.TargetLanguage,
                        config.TranslateToEnglish,
                        config.WordTimestamps,
                        null,
                        cancellationToken);

                    // The service reports failure through its return value; surface it here,
                    // but do not count a cancelled run as a failure.
                    if (!generated && !cancellationToken.IsCancellationRequested)
                    {
                        failures++;
                        _logger.LogWarning("Post-scan: subtitle generation failed for {Path}", path);
                    }
                }
                catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
                {
                    // Cancellation is not an error; the check at the top of the loop ends the run.
                    _logger.LogInformation("Post-scan task cancelled");
                    break;
                }
                catch (Exception ex)
                {
                    failures++;
                    _logger.LogError(ex, "Error during post-scan subtitle generation for {Item}", item.Path);
                }
                finally
                {
                    // Runs for skipped, processed and failed items alike so the bar always advances.
                    visited++;
                    if (total > 0)
                        progress?.Report((double)visited / total * 100);
                }
            }

            if (failures > 0)
                _logger.LogWarning("Whisper post-scan finished with {Failures} failed item(s)", failures);

            _logger.LogInformation("Whisper post-scan complete.");
        }
    }
}
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Jellyfin.Data.Enums;
using Jellyfin.Plugin.WhisperSubtitles.Configuration;
using Jellyfin.Plugin.WhisperSubtitles.Services;
using MediaBrowser.Controller.Entities;
using MediaBrowser.Controller.Library;
using MediaBrowser.Model.Globalization;
using MediaBrowser.Model.Tasks;
using Microsoft.Extensions.Logging;

namespace Jellyfin.Plugin.WhisperSubtitles.Tasks
{
    /// <summary>
    /// Scheduled task that generates subtitles for videos selected in the plugin config.
    /// </summary>
    public class WhisperSubtitleTask : IScheduledTask
    {
        private readonly ILibraryManager _libraryManager;
        private readonly ILogger<WhisperSubtitleTask> _logger;
        private readonly ILocalizationManager _localization;
        private readonly ILoggerFactory _loggerFactory;
        private readonly IWhisperService _whisperService;
        private readonly ISubtitleDetectionService _subtitleDetectionService;

        /// <summary>
        /// Creates the task and the services it uses.
        /// </summary>
        /// <param name="libraryManager">Library manager used to enumerate media items.</param>
        /// <param name="logger">Logger for this task.</param>
        /// <param name="localization">Used to localize the task category.</param>
        /// <param name="loggerFactory">Factory used to create loggers for the owned services.</param>
        public WhisperSubtitleTask(
            ILibraryManager libraryManager,
            ILogger<WhisperSubtitleTask> logger,
            ILocalizationManager localization,
            ILoggerFactory loggerFactory)
        {
            _libraryManager = libraryManager;
            _logger = logger;
            _localization = localization;
            _loggerFactory = loggerFactory;

            _whisperService = new WhisperService(_loggerFactory.CreateLogger<WhisperService>());
            _subtitleDetectionService = new SubtitleDetectionService(_loggerFactory.CreateLogger<SubtitleDetectionService>());
        }

        /// <summary>Display name of the task.</summary>
        public string Name => "Generate Whisper Subtitles";

        /// <summary>Stable key identifying the task.</summary>
        public string Key => "WhisperSubtitleGeneration";

        /// <summary>Human-readable description of the task.</summary>
        public string Description => "Generates AI-powered subtitles for videos using OpenAI Whisper";

        /// <summary>Localized category the task is listed under.</summary>
        public string Category => _localization.GetLocalizedString("TasksLibraryCategory");

        /// <summary>The task has no default triggers; it runs only when started or scheduled by the user.</summary>
        public IEnumerable<TaskTriggerInfo> GetDefaultTriggers() =>
            Array.Empty<TaskTriggerInfo>();

        /// <summary>
        /// Generates subtitles for every selected video that does not need to be skipped.
        /// </summary>
        /// <param name="progress">Receives the completion percentage (0-100).</param>
        /// <param name="cancellationToken">Stops processing before the next video.</param>
        /// <exception cref="OperationCanceledException">
        /// The task was cancelled. It is thrown after the summary is logged so the host can
        /// record the run as cancelled instead of completed.
        /// </exception>
        public async Task ExecuteAsync(IProgress<double> progress, CancellationToken cancellationToken)
        {
            var config = Plugin.Instance?.Configuration ?? new PluginConfiguration();

            _logger.LogInformation(
                "Whisper task starting. Model={M}, Language={L}, Translate={T}, Identifier={I}",
                config.WhisperModel, config.TargetLanguage, config.TranslateToEnglish, config.AIIdentifier);

            var videos = GetVideoItems(config);

            if (videos.Count == 0)
            {
                _logger.LogInformation("No videos to process — check library selection in plugin settings.");
                progress?.Report(100);
                return;
            }

            _logger.LogInformation("Processing {Count} video(s)", videos.Count);

            var videoCount = videos.Count;
            int processed = 0, skipped = 0, errors = 0;

            // Share of the overall percentage contributed by a single video.
            double videoWeight = 100.0 / videoCount;

            foreach (var video in videos)
            {
                if (cancellationToken.IsCancellationRequested)
                {
                    _logger.LogInformation("Task cancelled by user");
                    break;
                }

                try
                {
                    var videoPath = video.Path;

                    if (ShouldSkip(videoPath, config))
                    {
                        _logger.LogDebug("Skipping (has subtitles): {Path}", videoPath);
                        skipped++;
                        processed++;
                        progress?.Report((double)processed / videoCount * 100);
                        continue;
                    }

                    var subtitlePath = _subtitleDetectionService.GetSubtitlePath(
                        videoPath, config.TargetLanguage, config.AIIdentifier, "srt");

                    _logger.LogInformation("Generating: {Path}", videoPath);

                    // The service reports 0..1 for the current video; map that into the overall range.
                    // A synchronous reporter is used because Progress<T> posts callbacks asynchronously,
                    // which could deliver a per-video update after a later overall update.
                    double baseProgress = (double)processed / videoCount * 100;
                    var videoProgress = new InlineProgress(
                        p => progress?.Report(baseProgress + p * videoWeight));

                    var ok = await _whisperService.GenerateSubtitleAsync(
                        videoPath, subtitlePath,
                        config.WhisperModel.ToString(),
                        config.TargetLanguage,
                        config.TranslateToEnglish,
                        config.WordTimestamps,
                        videoProgress,
                        cancellationToken);

                    if (!ok)
                    {
                        _logger.LogError("Failed to generate subtitles: {Path}", videoPath);
                        errors++;
                    }
                }
                catch (Exception ex)
                {
                    _logger.LogError(ex, "Error processing: {Path}", video.Path);
                    errors++;
                }

                processed++;
                progress?.Report((double)processed / videoCount * 100);
            }

            _logger.LogInformation(
                "Task complete. Generated={G}, Skipped={S}, Errors={E}",
                processed - skipped - errors, skipped, errors);

            // Returning normally after a cancellation would leave the host unable to tell the run was cancelled.
            cancellationToken.ThrowIfCancellationRequested();
        }

        // -- Private helpers ---------------------------------------------------

        /// <summary>
        /// Returns the list of videos to process, filtered by configured libraries.
        /// Includes Movie, Episode, AND Video (home videos / generic video files).
        /// Items without a path or whose file no longer exists are dropped.
        /// </summary>
        private List<BaseItem> GetVideoItems(PluginConfiguration config)
        {
            var enabledLibraries = config.LibrariesToProcess ?? new List<string>();

            _logger.LogInformation(
                "Configured library IDs ({Count}): [{Ids}]",
                enabledLibraries.Count,
                string.Join(", ", enabledLibraries));

            // Include Video so "Home Videos" libraries are not silently ignored.
            var query = new InternalItemsQuery
            {
                IncludeItemTypes = new[] { BaseItemKind.Movie, BaseItemKind.Episode, BaseItemKind.Video },
                IsVirtualItem = false,
                Recursive = true
            };

            var allItems = _libraryManager.GetItemList(query);

            _logger.LogInformation("Library query returned {Count} item(s) before filtering", allItems.Count);

            // If no libraries are configured, process everything.
            if (!enabledLibraries.Any())
            {
                _logger.LogInformation("No library filter configured — processing all {Count} item(s)", allItems.Count);
                return allItems
                    .Where(i => !string.IsNullOrEmpty(i.Path) && File.Exists(i.Path))
                    .ToList();
            }

            // The config page stores ItemIds from getVirtualFolders() which are hex strings
            // WITHOUT hyphens (e.g. "ff6fbd42ce07adfc36b566506eba4f82").
            // Guid.ToString() produces hyphenated format (e.g. "ff6fbd42-ce07-adfc-36b5-66506eba4f82").
            // Normalize both sides to hyphenless lowercase so they always match.
            var normalizedEnabled = new HashSet<string>(
                enabledLibraries
                    .Where(id => !string.IsNullOrWhiteSpace(id))
                    .Select(NormalizeId),
                StringComparer.OrdinalIgnoreCase);

            _logger.LogDebug("Normalized enabled IDs: [{Ids}]", string.Join(", ", normalizedEnabled));

            var filtered = new List<BaseItem>();

            foreach (var item in allItems)
            {
                if (string.IsNullOrEmpty(item.Path) || !File.Exists(item.Path))
                    continue;

                try
                {
                    // Collection folders (libraries) that contain this item.
                    var parentFolders = _libraryManager.GetCollectionFolders(item);

                    // Check whether any of the item's parent libraries are in our allow-list.
                    // Normalize the folder ID to match the config format.
                    if (parentFolders.Any(f => normalizedEnabled.Contains(NormalizeId(f.Id.ToString()))))
                    {
                        filtered.Add(item);
                        _logger.LogDebug("Item '{Name}' matched library filter", item.Name);
                    }
                    else if (parentFolders.Count > 0)
                    {
                        _logger.LogDebug(
                            "Skipping '{Name}' — parent libraries: [{Folders}]",
                            item.Name,
                            string.Join(", ", parentFolders.Select(f => NormalizeId(f.Id.ToString()))));
                    }
                    else
                    {
                        _logger.LogDebug("Item '{Name}' has no parent collection folders", item.Name);
                    }
                }
                catch (Exception ex)
                {
                    _logger.LogWarning(ex, "Error resolving library for '{Name}' — skipping", item.Name);
                }
            }

            _logger.LogInformation(
                "{Filtered} of {Total} item(s) match the configured libraries",
                filtered.Count, allItems.Count);

            return filtered;
        }

        /// <summary>
        /// Strips hyphens and lowercases a GUID/ItemId string so that
        /// "ff6fbd42ce07adfc36b566506eba4f82" and "ff6fbd42-ce07-adfc-36b5-66506eba4f82"
        /// both normalize to the same value.
        /// </summary>
        private static string NormalizeId(string id) =>
            id.Replace("-", string.Empty).ToLowerInvariant();

        /// <summary>
        /// Decides whether a video is skipped: it already has AI subtitles and regeneration is
        /// off, or existing subtitles are to be respected and any subtitle for the target
        /// language is present.
        /// </summary>
        private bool ShouldSkip(string videoPath, PluginConfiguration config)
        {
            var hasAI = _subtitleDetectionService.HasAISubtitles(
                videoPath, config.TargetLanguage, config.AIIdentifier);

            if (hasAI && !config.RegenerateAI)
                return true;

            if (config.SkipExisting && !hasAI)
            {
                if (_subtitleDetectionService.HasSubtitles(videoPath, config.TargetLanguage))
                    return true;
            }

            return false;
        }

        /// <summary>
        /// <see cref="IProgress{T}"/> implementation that invokes its callback synchronously on
        /// the reporting thread, preserving the order of reports.
        /// </summary>
        private sealed class InlineProgress : IProgress<double>
        {
            private readonly Action<double> _onReport;

            public InlineProgress(Action<double> onReport) => _onReport = onReport;

            public void Report(double value) => _onReport(value);
        }
    }
}
a/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/bin/Debug/net9.0/Jellyfin.Plugin.WhisperSubtitles.dll b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/bin/Debug/net9.0/Jellyfin.Plugin.WhisperSubtitles.dll new file mode 100644 index 0000000..727f322 Binary files /dev/null and b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/bin/Debug/net9.0/Jellyfin.Plugin.WhisperSubtitles.dll differ diff --git a/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/bin/Debug/net9.0/Jellyfin.Plugin.WhisperSubtitles.pdb b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/bin/Debug/net9.0/Jellyfin.Plugin.WhisperSubtitles.pdb new file mode 100644 index 0000000..6e121e4 Binary files /dev/null and b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/bin/Debug/net9.0/Jellyfin.Plugin.WhisperSubtitles.pdb differ diff --git a/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/bin/Debug/net9.0/Whisper.net.dll b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/bin/Debug/net9.0/Whisper.net.dll new file mode 100755 index 0000000..178f382 Binary files /dev/null and b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/bin/Debug/net9.0/Whisper.net.dll differ diff --git a/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/bin/Debug/net9.0/ggml-metal.metal b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/bin/Debug/net9.0/ggml-metal.metal new file mode 100755 index 0000000..1c0ca5a --- /dev/null +++ b/Jellyfin.Plugin.WhisperSubtitles/Jellyfin.Plugin.WhisperSubtitles/bin/Debug/net9.0/ggml-metal.metal @@ -0,0 +1,6823 @@ +#define GGML_COMMON_DECL_METAL +#define GGML_COMMON_IMPL_METAL +#if defined(GGML_METAL_EMBED_LIBRARY) +__embed_ggml-common.h__ +#else +#include "ggml-common.h" +#endif +#include "ggml-metal-impl.h" + +#include + +using namespace metal; + +#define MAX(x, y) ((x) > (y) ? 
(x) : (y)) +#define MIN(x, y) ((x) < (y) ? (x) : (y)) +#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } + +#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 + +// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf +// +// cmd: +// .../usr/bin/metal -dM -E -c ggml/src/ggml-metal/ggml-metal.metal +// .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/ggml-metal.metal +// +#if __METAL_VERSION__ < 310 && defined(GGML_METAL_USE_BF16) +#undef GGML_METAL_USE_BF16 +#endif + +#if defined(GGML_METAL_USE_BF16) +typedef matrix bfloat4x4; +#endif + +constexpr constant static float kvalues_iq4nl_f[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +// NOTE: this is not dequantizing - we are simply fitting the template +template +void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { + reg = (type4x4)(*src); +} + +template +void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { + reg = (type4x4)(*src); +} + +template +void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) { + reg = (type4)(*(src)); +} + +#if defined(GGML_METAL_USE_BF16) +template +void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) { + reg = (type4x4)(*src); +} + +template +void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) { + reg = (type4)(*(src)); +} +#endif + +template +void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 1); + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float md = -8.h * xb->d; + const ushort mask0 = il ? 
0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + float4x4 reg_f; + + for (int i = 0; i < 8; i++) { + reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md; + reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md; + } + + reg = (type4x4) reg_f; +} + +template +void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 1); + const float d1 = (il/4) ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float md = -8.h * xb->d; + const ushort mask0 = (il/4) ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + for (int i = 0; i < 2; i++) { + reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md; + reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md; + } +} + +template +void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 2); + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float m = xb->m; + const ushort mask0 = il ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + float4x4 reg_f; + + for (int i = 0; i < 8; i++) { + reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m; + reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m; + } + + reg = (type4x4) reg_f; +} + +template +void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 2); + const float d1 = (il/4) ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float m = xb->m; + const ushort mask0 = (il/4) ? 
0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + for (int i = 0; i < 2; i++) { + reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m; + reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m; + } +} + +template +void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 3); + const float d = xb->d; + const float md = -16.h * xb->d; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + float4x4 reg_f; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg_f[i/2][2*(i%2) + 0] = d * x0 + md; + reg_f[i/2][2*(i%2) + 1] = d * x1 + md; + } + + reg = (type4x4) reg_f; +} + +template +void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 3); + const float d = xb->d; + const float md = -16.h * xb->d; + const ushort mask = (il/4) ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = (il/4) ? 4 : 0; + + const int gh_mv = (il/4) ? 12 : 0; + const int gh_bk = (il/4) ? 
0 : 4; + + for (int ii = 0; ii < 2; ii++) { + int i = 2*(il%4) + ii; + + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[2*ii + 0] = d * x0 + md; + reg[2*ii + 1] = d * x1 + md; + } +} + +template +void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 4); + const float d = xb->d; + const float m = xb->m; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + float4x4 reg_f; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg_f[i/2][2*(i%2) + 0] = d * x0 + m; + reg_f[i/2][2*(i%2) + 1] = d * x1 + m; + } + + reg = (type4x4) reg_f; +} + +template +void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 4); + const float d = xb->d; + const float m = xb->m; + const ushort mask = (il/4) ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = (il/4) ? 4 : 0; + + const int gh_mv = (il/4) ? 12 : 0; + const int gh_bk = (il/4) ? 
0 : 4; + + for (int ii = 0; ii < 2; ii++) { + int i = 2*(il%4) + ii; + + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[2*ii + 0] = d * x0 + m; + reg[2*ii + 1] = d * x1 + m; + } +} + +template +void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { + device const int8_t * qs = ((device const int8_t *)xb->qs); + const float d = xb->d; + + float4x4 reg_f; + + for (int i = 0; i < 16; i++) { + reg_f[i/4][i%4] = (qs[i + 16*il] * d); + } + + reg = (type4x4) reg_f; +} + +template +void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) { + device const int8_t * qs = ((device const int8_t *)xb->qs); + const float d = xb->d; + + for (int i = 0; i < 4; i++) { + reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d); + } +} + +template +void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { + const float d = xb->d; + const float min = xb->dmin; + device const uint8_t * q = (device const uint8_t *)xb->qs; + float dl, ml; + uint8_t sc = xb->scales[il]; + + q = q + 32*(il/8) + 16*(il&1); + il = (il/2)%4; + + half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 
12 : 3); + dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4); + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - ml; + } +} + +template +void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { + const half d_all = xb->d; + device const uint8_t * q = (device const uint8_t *)xb->qs; + device const uint8_t * h = (device const uint8_t *)xb->hmask; + device const int8_t * scales = (device const int8_t *)xb->scales; + + q = q + 32 * (il/8) + 16 * (il&1); + h = h + 16 * (il&1); + uint8_t m = 1 << (il/2); + uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \ + ((il/4)>0 ? 12 : 3); + uint16_t kmask2 = il/8 ? 0xF0 : 0x0F; + uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; + int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) + : (scale_2&kmask2) | ((scale_1&kmask1) << 4); + float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f); + const float ml = 4.f * dl; + + il = (il/2) & 3; + const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + dl *= coef; + + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); + } +} + +static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) { + return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)} + : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))}; +} + +template +void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) { + device const uchar * q = xb->qs; + + short is = (il/4) * 2; + q = q + (il/4) * 32 + 16 * (il&1); + il = il & 3; + const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); + const float d = il < 2 ? xb->d : xb->d / 16.h; + const float min = xb->dmin; + const float dl = d * sc[0]; + const float ml = min * sc[1]; + + const ushort mask = il < 2 ? 
0x0F : 0xF0; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - ml; + } +} + +template +void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) { + device const uint8_t * q = xb->qs; + device const uint8_t * qh = xb->qh; + + short is = (il/4) * 2; + q = q + 32 * (il/4) + 16 * (il&1); + qh = qh + 16 * (il&1); + uint8_t ul = 1 << (il/2); + il = il & 3; + const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); + const float d = il < 2 ? xb->d : xb->d / 16.f; + const float min = xb->dmin; + const float dl = d * sc[0]; + const float ml = min * sc[1]; + + const ushort mask = il<2 ? 0x0F : 0xF0; + const float qh_val = il<2 ? 16.f : 256.f; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml; + } +} + +template +void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) { + const half d_all = xb->d; + device const uint16_t * ql = (device const uint16_t *)xb->ql; + device const uint16_t * qh = (device const uint16_t *)xb->qh; + device const int8_t * scales = (device const int8_t *)xb->scales; + + ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1); + qh = qh + 16*(il/8) + 8*(il&1); + float sc = scales[(il%2) + 2 * ((il/2))]; + il = (il/2) & 3; + + const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303); + const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F; + const float ml = d_all * sc * 32.f; + const float dl0 = d_all * sc; + const float dl1 = dl0 / 256.f; + const float dl2 = dl0 / (256.f * 256.f); + const float dl3 = dl0 / (256.f * 256.f * 256.f); + const uint8_t shr_h = il>2 ? 2 : 0; + const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4); + const uint8_t shr_l = il>1 ? 
4 : 0; + for (int i = 0; i < 4; ++i) { + const uint32_t low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2; + const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1; + const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l); + reg[i][0] = dl0 * ((half)(q & 0xFF)) - ml; + reg[i][1] = dl1 * ((float)(q & 0xFF00)) - ml; + reg[i][2] = dl2 * ((float)(q & 0xFF0000)) - ml; + reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml; + } +} + +template +void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's. + device const uint16_t * q2 = xb->qs + 4*ib32; + const uint32_t aux32_g = q2[0] | (q2[1] << 16); + const uint32_t aux32_s = q2[2] | (q2[3] << 16); + thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g; + const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f; + constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]); + uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127]; + for (int i = 0; i < 8; ++i) { + reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } + grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]); + signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127]; + for (int i = 0; i < 8; ++i) { + reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } +} + +template +void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. 
il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint16_t * q2 = xb->qs + 4*ib32; + const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; + constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511)); + uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9]; + for (int i = 0; i < 8; ++i) { + reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } + grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511)); + signs = ksigns_iq2xs[q2[2*il+1] >> 9]; + for (int i = 0; i < 8; ++i) { + reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } +} + +template +void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * q3 = xb->qs + 8*ib32; + device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32; + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f; + constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]); + constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]); + uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127]; + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); + reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f); + } + grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]); + grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]); + signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127]; + for (int i = 0; i < 4; ++i) { + reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); + reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? 
-1.f : 1.f); + } +} + +template +void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * qs = xb->qs + 8*ib32; + device const uint8_t * signs = xb->signs + 4*ib32 + 2*il; + const uint8_t qh = xb->qh[ib32] >> 4*il; + const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)); + constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256))); + constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); + reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); + } + grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256))); + grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256))); + for (int i = 0; i < 4; ++i) { + reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); + reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); + } +} + +template +void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. 
il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint8_t * signs = qs + QK_K/8; + const uint8_t qh = xb->qh[ib32] >> 4*il; + const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; + constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300))); + constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300))); + for (int i = 0; i < 8; ++i) { + reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]); + reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]); + } +} + +template +void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + const float d = xb->d; + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint16_t * qh = xb->qh; + const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1); + const float ml = dl * (qh[ib32] & 0x8000 ? 
-1 - IQ1S_DELTA : -1 + IQ1S_DELTA); + const uint16_t h = qh[ib32] >> 6*il; + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * (grid1[i] & 0xf) + ml; + reg[1][i] = dl * (grid1[i] >> 4) + ml; + reg[2][i] = dl * (grid2[i] & 0xf) + ml; + reg[3][i] = dl * (grid2[i] >> 4) + ml; + } +} + +template +void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + device const uint16_t * sc = (device const uint16_t *)xb->scales; + + iq1m_scale_t scale; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + const float d = scale.f16; + + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint8_t * qh = xb->qh + 2*ib32 + il; + + const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1); + const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + const float ml2 = dl * (qh[0] & 0x80 ? 
-1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * (grid1[i] & 0xf) + ml1; + reg[1][i] = dl * (grid1[i] >> 4) + ml1; + reg[2][i] = dl * (grid2[i] & 0xf) + ml2; + reg[3][i] = dl * (grid2[i] >> 4) + ml2; + } +} + +template +void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) { + device const uint16_t * q4 = (device const uint16_t *)xb->qs; + const float d = xb->d; + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + for (int i = 0; i < 4; ++i) { + aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f; + reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; + reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; + reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; + reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; + } +} + +template +void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) { + device const uint16_t * q4 = (device const uint16_t *)xb->qs; + const float d = xb->d; + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f; + reg[0] = d * kvalues_iq4nl_f[q8[0]]; + reg[1] = d * kvalues_iq4nl_f[q8[1]]; + reg[2] = d * kvalues_iq4nl_f[q8[2]]; + reg[3] = d * kvalues_iq4nl_f[q8[3]]; +} + +template +void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. 
il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32; + const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4); + const float d = (float)xb->d * (ls - 32); + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + for (int i = 0; i < 4; ++i) { + aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f; + reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; + reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; + reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; + reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; + } +} + +enum ggml_sort_order { + GGML_SORT_ORDER_ASC, + GGML_SORT_ORDER_DESC, +}; + +// general-purpose kernel for addition, subtraction, multiplication and division of two tensors +// pros: works for non-contiguous tensors, supports broadcast across all dims +// cons: not very efficient +kernel void kernel_add( + constant ggml_metal_kargs_bin & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; + + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; + + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10)); + } +} + +kernel void kernel_sub( + constant ggml_metal_kargs_bin & args, + device const char * src0, + device const 
char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; + + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; + + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10)); + } +} + +kernel void kernel_mul( + constant ggml_metal_kargs_bin & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; + + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; + + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); + } +} + +kernel void kernel_div( + constant ggml_metal_kargs_bin & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 
tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; + + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; + + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); + } +} + +template +kernel void kernel_repeat( + constant ggml_metal_kargs_repeat & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; + + const int i03 = i3%args.ne03; + const int i02 = i2%args.ne02; + const int i01 = i1%args.ne01; + + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i00 = i0%args.ne00; + *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00)); + } +} + +typedef decltype(kernel_repeat) kernel_repeat_t; + +template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat; + +// 
assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_add_row( + constant ggml_metal_kargs_bin & args, + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; + dst[tpig] = src0[tpig] + src1[tpig % nb]; +} + +kernel void kernel_sub_row( + constant ggml_metal_kargs_bin & args, + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; + dst[tpig] = src0[tpig] - src1[tpig % nb]; +} + +kernel void kernel_mul_row( + constant ggml_metal_kargs_bin & args, + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; + dst[tpig] = src0[tpig] * src1[tpig % nb]; +} + +kernel void kernel_div_row( + constant ggml_metal_kargs_bin & args, + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; + dst[tpig] = src0[tpig] / src1[tpig % nb]; +} + +kernel void kernel_scale( + device const float * src0, + device float * dst, + constant float & scale, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * scale; +} + +kernel void kernel_scale_4( + device const float4 * src0, + device float4 * dst, + constant float & scale, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * scale; +} + +kernel void kernel_clamp( + device const float * src0, + device float * dst, + constant float & min, + constant float & max, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? 
max : src0[tpig]); +} + +kernel void kernel_relu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = max(0.0f, src0[tpig]); +} + +kernel void kernel_sigmoid( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); +} + +kernel void kernel_tanh( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = precise::tanh(x); +} + +constant float GELU_COEF_A = 0.044715f; +constant float GELU_QUICK_COEF = -1.702f; +constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +kernel void kernel_gelu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + // BEWARE !!! + // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! 
+ // This was observed with Falcon 7B and 40B models + // + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_quick( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_gelu_quick_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_silu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = x / (1.0f + exp(-x)); +} + +kernel void kernel_silu_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + dst[tpig] = x / (1.0f + exp(-x)); +} + +kernel void kernel_elu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = (x > 0.0f) ? 
x : (exp(x) - 1.0f); +} + +kernel void kernel_sqr( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src0[tpig]; +} + +kernel void kernel_sqrt( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sqrt(src0[tpig]); +} + +kernel void kernel_sin( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sin(src0[tpig]); +} + +kernel void kernel_cos( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = cos(src0[tpig]); +} + +kernel void kernel_sum_rows( + device const float * src0, + device float * dst, + constant ggml_metal_kargs_sum_rows & args, + uint3 tpig[[thread_position_in_grid]]) { + int64_t i3 = tpig.z; + int64_t i2 = tpig.y; + int64_t i1 = tpig.x; + + if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { + return; + } + + device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); + device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); + + float row_sum = 0; + + for (int64_t i0 = 0; i0 < args.ne00; i0++) { + row_sum += src_row[i0]; + } + + dst_row[0] = row_sum; +} + +template +kernel void kernel_soft_max( + device const char * src0, + device const char * src1, + device char * dst, + constant ggml_metal_kargs_soft_max & args, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (args.ne02*args.ne01); + const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01; + const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); + + device const float * 
psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr; + device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); + + float slope = 1.0f; + + // ALiBi + if (args.max_bias > 0.0f) { + const int64_t h = i02; + + const float base = h < args.n_head_log2 ? args.m0 : args.m1; + const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float lmax = -INFINITY; + + for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + + // find the max value in the block + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float lsum = 0.0f; + for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? 
slope*pmask[i00] : 0.0f)) - max_val); + lsum += exp_psrc0; + pdst[i00] = exp_psrc0; + } + + // This barrier fixes a failing test + // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + pdst[i00] *= inv_sum; + } +} + +template +kernel void kernel_soft_max_4( + device const char * src0, + device const char * src1, + device char * dst, + constant ggml_metal_kargs_soft_max & args, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (args.ne02*args.ne01); + const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01; + const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); + + device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr; + device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; + + float slope = 1.0f; + + if (args.max_bias > 0.0f) { + const int64_t h = i02; + + const float base = h < args.n_head_log2 ? args.m0 : args.m1; + const int exp = h < args.n_head_log2 ? 
h + 1 : 2*(h - args.n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float4 lmax4 = -INFINITY; + + for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); + } + + const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); + + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + + const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; + + // This barrier fixes a failing test + // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + pdst4[i00] *= inv_sum; + } +} + +typedef decltype(kernel_soft_max) kernel_soft_max_t; +typedef decltype(kernel_soft_max_4) kernel_soft_max_4_t; + +template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; 
+template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; + +kernel void kernel_diag_mask_inf( + device const float * src0, + device float * dst, + constant ggml_metal_kargs_diag_mask_inf & args, + uint3 tpig[[thread_position_in_grid]]) { + const int64_t i02 = tpig[2]; + const int64_t i01 = tpig[1]; + const int64_t i00 = tpig[0]; + + if (i00 > args.n_past + i01) { + dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = -INFINITY; + } else { + dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = src0[i02*args.ne01*args.ne00 + i01*args.ne00 + i00]; + } +} + +kernel void kernel_diag_mask_inf_8( + device const float4 * src0, + device float4 * dst, + constant ggml_metal_kargs_diag_mask_inf & args, + uint3 tpig[[thread_position_in_grid]]) { + + const int64_t i = 2*tpig[0]; + + dst[i+0] = src0[i+0]; + dst[i+1] = src0[i+1]; + int64_t i4 = 4*i; + const int64_t i02 = i4/(args.ne00*args.ne01); i4 -= i02*args.ne00*args.ne01; + const int64_t i01 = i4/(args.ne00); i4 -= i01*args.ne00; + const int64_t i00 = i4; + for (int k = 3; k >= 0; --k) { + if (i00 + 4 + k <= args.n_past + i01) { + break; + } + dst[i+1][k] = -INFINITY; + if (i00 + k > args.n_past + i01) { + dst[i][k] = -INFINITY; + } + } +} + +// ref: ggml.c:ggml_compute_forward_ssm_conv_f32 +kernel void kernel_ssm_conv_f32( + device const void * src0, + device const void * src1, + device float * dst, + constant ggml_metal_kargs_ssm_conv & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t ir = tgpig.x; + const int64_t i2 = tgpig.y; + const int64_t i3 = tgpig.z; + + const int64_t nc = args.ne10; + //const int64_t ncs = args.ne00; + //const int64_t nr = args.ne01; + //const int64_t n_t = args.ne1; + //const int64_t n_s = args.ne2; + + device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); + device const float * c = 
(device const float *) ((device const char *) src1 + ir*args.nb11); + device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); + + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + sumf += s[i0] * c[i0]; + } + + x[0] = sumf; +} + +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32 +kernel void kernel_ssm_scan_f32( + device const void * src0, + device const void * src1, + device const void * src2, + device const void * src3, + device const void * src4, + device const void * src5, + device float * dst, + constant ggml_metal_kargs_ssm_scan & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t ir = tgpig.x; + const int64_t i3 = tgpig.y; + + const int64_t nc = args.d_state; + // const int64_t nr = args.d_inner; + const int64_t n_t = args.n_seq_tokens; + // const int64_t n_s = args.n_seqs; + + for (int64_t i2 = 0; i2 < n_t; ++i2) { + device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02); + device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); + device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); + device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42); + device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52); + device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides + device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13); + + if (i2 > 0) { + s0 = s; + } + + // i1 == 0 + float dt_soft_plus = dt[0] <= 
20.0f ? log(1.0f + exp(dt[0])) : dt[0]; + float x_dt = x[0] * dt_soft_plus; + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + int64_t i = i0; + float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt); + sumf += state * C[i0]; + s[i] = state; + } + + y[0] = sumf; + } +} + +kernel void kernel_rwkv_wkv6_f32( + device const float * k, + device const float * v, + device const float * r, + device const float * tf, + device const float * td, + device const float * state_in, + device float * dst, + constant uint & B, + constant uint & T, + constant uint & C, + constant uint & H, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const uint head_size = 64; // TODO: support head_size = 128 + const uint batch_id = tgpig.x / H; + const uint head_id = tgpig.x % H; + const uint tid = tpitg.x; + + if (batch_id >= B || head_id >= H) { + return; + } + + const uint state_size = C * head_size; + const uint n_seq_tokens = T / B; + + threadgroup float _k[head_size]; + threadgroup float _r[head_size]; + threadgroup float _tf[head_size]; + threadgroup float _td[head_size]; + + float state[head_size]; + + for (uint i = 0; i < head_size; i++) { + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + _tf[tid] = tf[head_id * head_size + tid]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; + const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; + + for (uint t = start_t; t < end_t; t += C) { + threadgroup_barrier(mem_flags::mem_threadgroup); + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float v_val = v[t]; + float y = 0.0; + + for (uint j = 0; j < head_size; j += 4) { + float4 k_vec = 
float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); + float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]); + float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); + + float4 kv = k_vec * v_val; + + float4 temp = tf_vec * kv + s_vec; + y += dot(r_vec, temp); + + s_vec = s_vec * td_vec + kv; + state[j] = s_vec[0]; + state[j+1] = s_vec[1]; + state[j+2] = s_vec[2]; + state[j+3] = s_vec[3]; + } + + dst[t] = y; + } + + for (uint i = 0; i < head_size; i++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid] = state[i]; + } +} + +kernel void kernel_rwkv_wkv7_f32( + device const float * r, + device const float * w, + device const float * k, + device const float * v, + device const float * a, + device const float * b, + device const float * state_in, + device float * dst, + constant uint & B, + constant uint & T, + constant uint & C, + constant uint & H, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const uint head_size = 64; // TODO: support head_size = 128 + const uint batch_id = tgpig.x / H; + const uint head_id = tgpig.x % H; + const uint tid = tpitg.x; + + if (batch_id >= B || head_id >= H) { + return; + } + + const uint state_size = C * head_size; + const uint n_seq_tokens = T / B; + + threadgroup float _r[head_size]; + threadgroup float _w[head_size]; + threadgroup float _k[head_size]; + threadgroup float _a[head_size]; + threadgroup float _b[head_size]; + + float state[head_size]; + + for (uint i = 0; i < head_size; i++) { + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + + tid * head_size + i]; + } + + const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; + const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; + 
+ for (uint t = start_t; t < end_t; t += C) { + threadgroup_barrier(mem_flags::mem_threadgroup); + _r[tid] = r[t]; + _w[tid] = w[t]; + _k[tid] = k[t]; + _a[tid] = a[t]; + _b[tid] = b[t]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float v_val = v[t]; + float y = 0.0, sa = 0.0; + + float4 sa_vec(0.0); + + for (uint j = 0; j < head_size; j += 4) { + float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]); + float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); + sa_vec += a_vec * s_vec; + } + sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3]; + + for (uint j = 0; j < head_size; j += 4) { + float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]); + float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]); + float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); + + float4 kv = k_vec * v_val; + + s_vec = s_vec * w_vec + kv + sa * b_vec; + y += dot(s_vec, r_vec); + + state[j] = s_vec[0]; + state[j+1] = s_vec[1]; + state[j+2] = s_vec[2]; + state[j+3] = s_vec[3]; + } + + dst[t] = y; + } + + for (uint i = 0; i < head_size; i++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + tid * head_size + i] = state[i]; + } +} + +kernel void kernel_argmax( + device const void * x, + device int32_t * dst, + constant int64_t & ncols, + constant uint64_t & nb01, + threadgroup float * shared_maxval [[threadgroup(0)]], + threadgroup int32_t * shared_argmax [[threadgroup(1)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * x_row = (device const float *) ((device const char *) x + tgpig * nb01); + + float lmax = -INFINITY; + int32_t larg = -1; + + for (int i00 = tpitg; i00 < ncols; i00 += ntg) { + if 
(x_row[i00] > lmax) { + lmax = x_row[i00]; + larg = i00; + } + } + + // find the argmax value in the block + float max_val = simd_max(lmax); + int32_t arg_val = simd_max(select(-1, larg, lmax == max_val)); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + shared_maxval[tiisg] = -INFINITY; + shared_argmax[tiisg] = -1; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shared_maxval[sgitg] = max_val; + shared_argmax[sgitg] = arg_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = shared_maxval[tiisg]; + arg_val = shared_argmax[tiisg]; + + float max_val_reduced = simd_max(max_val); + int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced)); + + dst[tgpig] = arg_val_reduced; + + return; + } + + dst[tgpig] = arg_val; +} + +kernel void kernel_norm( + constant ggml_metal_kargs_norm & args, + device const char * src0, + device char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort ntg[[threads_per_threadgroup]]) { + if (sgitg == 0) { + shmem_f32[tiisg] = 0.0f; + } + + device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + + float4 sumf4(0.0f); + + float sumf = 0.0f; + + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + sumf4 += x[i00]; + } + sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3]; + sumf = simd_sum(sumf); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float mean = sumf/args.ne00; + + device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + + sumf = 0.0f; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + y[i00] = x[i00] - mean; + sumf += dot(y[i00], y[i00]); + 
} + sumf = simd_sum(sumf); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float variance = sumf/args.ne00; + + const float scale = 1.0f/sqrt(variance + args.eps); + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + y[i00] = y[i00] * scale; + } +} + +kernel void kernel_rms_norm( + constant ggml_metal_kargs_rms_norm & args, + device const char * src0, + device char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort ntg[[threads_per_threadgroup]]) { + if (sgitg == 0) { + shmem_f32[tiisg] = 0.0f; + } + + device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + + float sumf = 0.0f; + + // parallel sum + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + sumf += dot(x[i00], x[i00]); + } + sumf = simd_sum(sumf); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float mean = sumf/args.ne00; + const float scale = 1.0f/sqrt(mean + args.eps); + + device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + y[i00] = x[i00] * scale; + } +} + +kernel void kernel_l2_norm( + constant ggml_metal_kargs_l2_norm & args, + device const char * src0, + device char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort 
ntg[[threads_per_threadgroup]]) { + if (sgitg == 0) { + shmem_f32[tiisg] = 0.0f; + } + + device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + + float sumf = 0.0f; + + // parallel sum + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + sumf += dot(x[i00], x[i00]); + } + sumf = simd_sum(sumf); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float scale = 1.0f/sqrt(max(sumf, args.eps)); + + device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + y[i00] = x[i00] * scale; + } +} + +kernel void kernel_group_norm( + device const float * src0, + device float * dst, + constant ggml_metal_kargs_group_norm & args, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t ne = args.ne00*args.ne01*args.ne02; + const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.n_groups - 1) / args.n_groups); + + int start = tgpig * gs; + int end = start + gs; + + start += tpitg; + + if (end >= ne) { + end = ne; + } + + float tmp = 0.0f; // partial sum for thread in warp + + for (int j = start; j < end; j += ntg) { + tmp += src0[j]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + tmp = simd_sum(tmp); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = tmp; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + tmp = buf[tiisg]; + tmp = simd_sum(tmp); + } + + const float mean = tmp / gs; + tmp = 0.0f; + + for (int j = start; j < end; j += ntg) { + float xi = src0[j] - 
mean; + dst[j] = xi; + tmp += xi * xi; + } + + tmp = simd_sum(tmp); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = tmp; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + tmp = buf[tiisg]; + tmp = simd_sum(tmp); + } + + const float variance = tmp / gs; + const float scale = 1.0f/sqrt(variance + args.eps); + for (int j = start; j < end; j += ntg) { + dst[j] *= scale; + } +} + +// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; + + device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2); + + for (int i = 0; i < 8; i += 2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F); + acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0); + acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000); + } + + return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]); +} + +// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + float m = qb_curr->m; + + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; + + device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2); + + for (int i = 0; i 
< 8; i+=2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F); + acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0); + acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000); + } + + return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; +} + +// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)); + acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)); + acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + } + + return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]); +} + +// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_1/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + float m = qb_curr->m; + + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); + const uint32_t qh = 
*((device const uint32_t *)qb_curr->qh); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)); + acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)); + acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + } + + return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; +} + +template +void mul_vec_q_n_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const int nb = args.ne00/QK4_0; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + // pointers to src0 rows + device const block_q_type * ax[nr0]; + for (int row = 0; row < nr0; ++row) { + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); + } + + float yl[16]; // src1 vector cache + float sumf[nr0] = {0.f}; + + const short ix = (tiisg/2); + const short il = (tiisg%2)*8; + + device const float * yb = y + ix*QK4_0 + il; + + // each thread in a SIMD group deals with half a block. 
+ for (int ib = ix; ib < nb; ib += nw/2) { + float sumy[2] = { 0.f, 0.f }; + +#pragma unroll + for (short i = 0; i < 8; i += 2) { + sumy[0] += yb[i + 0] + yb[i + 1]; + yl[i + 0] = yb[i + 0]; + yl[i + 1] = yb[i + 1]/256.f; + + sumy[1] += yb[i + 16] + yb[i + 17]; + yl[i + 8] = yb[i + 16]/16.f; + yl[i + 9] = yb[i + 17]/4096.f; + } + +#pragma unroll + for (short row = 0; row < nr0; row++) { + sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); + } + + yb += QK4_0 * 16; + } + + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + + for (int row = 0; row < nr0; ++row) { + const float tot = simd_sum(sumf[row]); + + if (tiisg == 0 && first_row + row < args.ne01) { + dst_f32[first_row + row] = tot; + } + } +} + +kernel void kernel_mul_mv_q4_0_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +kernel void kernel_mul_mv_q4_1_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +kernel void kernel_mul_mv_q5_0_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +kernel void kernel_mul_mv_q5_1_f32( + constant ggml_metal_kargs_mul_mv & args, + 
device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +#define NB_Q8_0 8 + +template +void kernel_mul_mv_q8_0_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const int nb = args.ne00/QK8_0; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + // pointers to src0 rows + device const block_q8_0 * ax[nr0]; + for (int row = 0; row < nr0; ++row) { + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); + } + + float yl[NB_Q8_0]; + float sumf[nr0] = { 0.f }; + + const short ix = tiisg/4; + const short il = tiisg%4; + + device const float * yb = y + ix*QK8_0 + il*NB_Q8_0; + + // each thread in a SIMD group deals with NB_Q8_0 quants at a time + for (int ib = ix; ib < nb; ib += nw/4) { + for (short i = 0; i < NB_Q8_0; ++i) { + yl[i] = yb[i]; + } + + for (short row = 0; row < nr0; row++) { + device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; + float sumq = 0.f; + for (short iq = 0; iq < NB_Q8_0; ++iq) { + sumq += qs[iq] * yl[iq]; + } + sumf[row] += sumq*ax[row][ib].d; + } + + yb += nw*NB_Q8_0; + } 
+ + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < nr0; ++row) { + const float tot = simd_sum(sumf[row]); + + if (tiisg == 0 && first_row + row < args.ne01) { + dst_f32[first_row + row] = tot; + } + } +} + +[[host_name("kernel_mul_mv_q8_0_f32")]] +kernel void kernel_mul_mv_q8_0_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +// mat-vec kernel processing in chunks of float4 +// chpb - chunks per quantization block +template +void kernel_mul_mv_ext_q4_f32_impl( + constant ggml_metal_kargs_mul_mv_ext & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short chpt = 4; // chunks per thread + + //const short nxpsg = (32); + const short nypsg = (32/nxpsg); + + const short tx = tiisg%nxpsg; + const short ty = tiisg/nxpsg; + + const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty; + const int i11 = tgpig.y*r1ptg; + const int i1m = tgpig.z; + + const int i12 = i1m%args.ne12; + const int i13 = i1m/args.ne12; + + const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; + + device const float4 * y4[r1ptg]; + + for (int ir1 = 0; ir1 < r1ptg; ++ir1) { + y4[ir1] = (i11 + ir1 < args.ne11) ? 
(device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1; + } + + float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f }; + + short cch = tx%chpb; // current chunk index + + for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) { + float4 lx[chpt]; + +#pragma unroll(chpt) + for (short ch = 0; ch < chpt; ++ch) { + deq_t4(xq, cch, lx[ch]); + + cch += nxpsg; + if (cch >= chpb) { + xq += cch/chpb; + cch %= chpb; + } + } + +#pragma unroll(chpt) + for (short ch = 0; ch < chpt; ++ch) { +#pragma unroll(r1ptg) + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]); + + } + } + +#pragma unroll(r1ptg) + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + y4[ir1] += chpt*nxpsg; + } + } + + // reduce only the threads in each row + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + if (nxpsg >= 32) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 16); + } + if (nxpsg >= 16) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 8); + } + if (nxpsg >= 8) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 4); + } + if (nxpsg >= 4) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 2); + } + if (nxpsg >= 2) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 1); + } + + //sumf[ir1] = simd_sum(sumf[ir1]); + } + + if (tx == 0) { + for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) { + device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0; + + if (i01 < args.ne01) { + dst_f32[i01] = sumf[ir1]; + } + } + } +} + +// mat-vec kernel processing in chunks of float4x4 +template +void kernel_mul_mv_ext_q4x4_f32_impl( + constant ggml_metal_kargs_mul_mv_ext & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short chpt = 1; + + //const short nxpsg = (32); + const short nypsg = (32/nxpsg); + + const short tx = 
tiisg%nxpsg; + const short ty = tiisg/nxpsg; + + const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty; + const int i11 = tgpig.y*r1ptg; + const int i1m = tgpig.z; + + const int i12 = i1m%args.ne12; + const int i13 = i1m/args.ne12; + + const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; + + device const float4x4 * y4x4[r1ptg]; + + for (int ir1 = 0; ir1 < r1ptg; ++ir1) { + y4x4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4x4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4x4 *) src1; + } + + float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f }; + + short cch = tx%chpb; + + for (int ich = tx; 16*ich < args.ne00; ich += chpt*nxpsg) { + float4x4 lx[chpt]; + +#pragma unroll(chpt) + for (short ch = 0; ch < chpt; ++ch) { + deq_t4x4(xq, cch, lx[ch]); + + cch += nxpsg; + if (cch >= chpb) { + xq += cch/chpb; + cch %= chpb; + } + } + +#pragma unroll(chpt) + for (short ch = 0; ch < chpt; ++ch) { +#pragma unroll(r1ptg) + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + sumf[ir1] += + dot(lx[ch][0], y4x4[ir1][ch*nxpsg][0]) + + dot(lx[ch][1], y4x4[ir1][ch*nxpsg][1]) + + dot(lx[ch][2], y4x4[ir1][ch*nxpsg][2]) + + dot(lx[ch][3], y4x4[ir1][ch*nxpsg][3]); + + } + } + +#pragma unroll(r1ptg) + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + y4x4[ir1] += chpt*nxpsg; + } + } + + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + if (nxpsg >= 32) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 16); + } + if (nxpsg >= 16) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 8); + } + if (nxpsg >= 8) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 4); + } + if (nxpsg >= 4) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 2); + } + if (nxpsg >= 2) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 1); + } + + //sumf[ir1] = simd_sum(sumf[ir1]); + } + + if (tx 
== 0) { + for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) { + device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0; + + if (i01 < args.ne01) { + dst_f32[i01] = sumf[ir1]; + } + } + } +} + +// dispatchers needed for compile-time nxpsg +// epb - elements per quantization block +template +kernel void kernel_mul_mv_ext_q4_f32_disp( + constant ggml_metal_kargs_mul_mv_ext & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + switch (args.nxpsg) { + case 4: kernel_mul_mv_ext_q4_f32_impl<4, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 8: kernel_mul_mv_ext_q4_f32_impl<8, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 16: kernel_mul_mv_ext_q4_f32_impl<16, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 32: kernel_mul_mv_ext_q4_f32_impl<32, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + } +} + +template +kernel void kernel_mul_mv_ext_q4x4_f32_disp( + constant ggml_metal_kargs_mul_mv_ext & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + switch (args.nxpsg) { + case 4: kernel_mul_mv_ext_q4x4_f32_impl<4, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 8: kernel_mul_mv_ext_q4x4_f32_impl<8, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 16: kernel_mul_mv_ext_q4x4_f32_impl<16, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 32: kernel_mul_mv_ext_q4x4_f32_impl<32, r1ptg, q_t, epb/16, 
deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + } +} + +typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t; +typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>) mul_mv_ext_q4x4_f32_t; + +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>; +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>; +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>; +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>; + +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>; +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>; +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>; +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_0, 32, dequantize_q4_0_t4>; + +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_1, 32, dequantize_q4_1_t4>; +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_1, 32, dequantize_q4_1_t4>; +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_1, 32, dequantize_q4_1_t4>; +template 
[[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_1, 32, dequantize_q4_1_t4>; + +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_0, 32, dequantize_q5_0_t4>; +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_0, 32, dequantize_q5_0_t4>; +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_0, 32, dequantize_q5_0_t4>; +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_0, 32, dequantize_q5_0_t4>; + +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_1, 32, dequantize_q5_1_t4>; +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_1, 32, dequantize_q5_1_t4>; +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_1, 32, dequantize_q5_1_t4>; +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_1, 32, dequantize_q5_1_t4>; + +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q8_0, 32, dequantize_q8_0_t4>; +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q8_0, 32, dequantize_q8_0_t4>; +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0, 32, dequantize_q8_0_t4>; +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0, 
32, dequantize_q8_0_t4>; + +template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>; +template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>; +template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>; +template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_iq4_nl, 32, dequantize_iq4_nl_t4>; + +template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>; +template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q4_K, 256, dequantize_q4_K>; +template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q4_K, 256, dequantize_q4_K>; +template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q4_K, 256, dequantize_q4_K>; + +template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q5_K, 256, dequantize_q5_K>; +template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q5_K, 256, dequantize_q5_K>; +template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q5_K, 256, dequantize_q5_K>; +template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q5_K, 256, dequantize_q5_K>; + +template 
[[host_name("kernel_mul_mv_ext_q6_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_K, 256, dequantize_q6_K>; +template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q6_K, 256, dequantize_q6_K>; +template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>; +template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>; + +#define N_MV_T_T 4 + +template +void kernel_mul_mv_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig, + ushort tiisg) { + const int r0 = tgpig.x; + const int rb = tgpig.y*N_MV_T_T; + const int im = tgpig.z; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + device const T0 * x = (device const T0 *) (src0 + offset0); + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; + + if (args.ne00 < 128) { + for (int row = 0; row < N_MV_T_T; ++row) { + int r1 = rb + row; + if (r1 >= args.ne11) { + break; + } + + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const T1 * y = (device const T1 *) (src1 + offset1); + + float sumf = 0; + for (int i = tiisg; i < args.ne00; i += 32) { + sumf += (T0) x[i] * (T1) y[i]; + } + + float sum_all = simd_sum(sumf); + if (tiisg == 0) { + dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; + } + } + } else { + device const T04 * x4 = (device const T04 *) x; + for (int row = 0; row < N_MV_T_T; ++row) { + int r1 = rb + row; + if (r1 >= args.ne11) { + break; + } + + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const T1 * y = (device const T1 *) (src1 + 
offset1); + device const T14 * y4 = (device const T14 *) y; + + float sumf = 0; + for (int i = tiisg; i < args.ne00/4; i += 32) { + sumf += dot((float4) x4[i], (float4) y4[i]); + } + + float sum_all = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]); + dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; + } + } + } +} + +template +kernel void kernel_mul_mv( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_impl( + args, + src0, + src1, + dst, + tgpig, + tiisg); +} + +typedef decltype(kernel_mul_mv) mul_mv_t; + +template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv; +template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv; +template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv; +template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv; +#endif + +template +kernel void kernel_mul_mv_1row( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const T * x = (device const T *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + device float * dst_f32 = (device float *) dst + 
(uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + float sumf = 0; + if (args.ne00 < 128) { + for (int i = tiisg; i < args.ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + float sum_all = simd_sum(sumf); + if (tiisg == 0) { + dst_f32[r0] = sum_all; + } + } else { + device const T4 * x4 = (device const T4 *) x; + device const float4 * y4 = (device const float4 *) y; + + for (int i = tiisg; i < args.ne00/4; i += 32) { + sumf += dot((float4) x4[i], y4[i]); + } + + float sum_all = simd_sum(sumf); + + if (tiisg == 0) { + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]); + dst_f32[r0] = sum_all; + } + } +} + +typedef decltype(kernel_mul_mv_1row) mul_mv_1row_t; + +template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row; +#endif + +// Assumes row size (ne00) is a multiple of 4 +template +kernel void kernel_mul_mv_l4( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { + + const int nrows = args.ne11; + const int r0 = tgpig.x; + const int im = tgpig.z; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + device const T4 * x4 = (device const T4 *) (src0 + offset0); + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; + + for (int r1 = 0; r1 < nrows; ++r1) { + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const float4 * y4 = (device const float4 *) (src1 + offset1); + + float sumf = 0; + for (int i = tiisg; i < args.ne00/4; i += 32) { + sumf += dot((float4) x4[i], y4[i]); + } + + float sum_all = simd_sum(sumf); + 
if (tiisg == 0) { + dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; + } + } +} + +typedef decltype(kernel_mul_mv_l4) mul_mv_l4_t; + +template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4; +#endif + +static float rope_yarn_ramp(const float low, const float high, const int i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +static void rope_yarn( + float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale, + thread float * cos_theta, thread float * sin_theta) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); + } + *cos_theta = cos(theta) * mscale; + *sin_theta = sin(theta) * mscale; +} + +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base)); +} + +static void rope_yarn_corr_dims( + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2] +) { + // start and end correction dims + dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, 
freq_base))); + dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))); +} + +template +kernel void kernel_rope_norm( + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); + + device const int32_t * pos = (device const int32_t *) src1; + + const float theta_base = (float) pos[i2]; + const float inv_ndims = -1.f/args.n_dims; + + float cos_theta; + float sin_theta; + + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < args.n_dims) { + const int ic = i0/2; + + const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
((device const float *) src2)[ic] : 1.0f; + + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + const float x0 = src[0]; + const float x1 = src[1]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; + } else { + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +template +kernel void kernel_rope_neox( + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); + + device const int32_t * pos = (device const int32_t *) src1; + + const float theta_base = (float) pos[i2]; + const float inv_ndims = -1.f/args.n_dims; + + float cos_theta; + float sin_theta; + + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < args.n_dims) { + const int ic = i0/2; + + const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
((device const float *) src2)[ic] : 1.0f; + + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); + + const float x0 = src[0]; + const float x1 = src[args.n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +typedef decltype(kernel_rope_norm) kernel_rope_norm_t; +typedef decltype(kernel_rope_neox) kernel_rope_neox_t; + +template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm; +template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm; + +template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox; +template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox; + +typedef void (im2col_t)( + device const float * x, + device char * dst, + constant ggml_metal_kargs_im2col & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template +kernel void kernel_im2col( + device const float * x, + device char * dst, + constant ggml_metal_kargs_im2col & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { +// const int64_t IC = tgpg[0]; + const int64_t OH = tgpg[1]; + const int64_t 
OW = tgpg[2]; + +// const int64_t N = ntg[0]; + const int64_t KH = ntg[1]; + const int64_t KW = ntg[2]; + + const int64_t in = tpitg[0]; + const int64_t ikh = tpitg[1]; + const int64_t ikw = tpitg[2]; + + const int64_t iic = tgpig[0]; + const int64_t ioh = tgpig[1]; + const int64_t iow = tgpig[2]; + + const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0; + const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1; + + const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw); + + device T * pdst = (device T *) (dst); + + if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { + pdst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw; + pdst[offset_dst] = x[offset_src]; + } +} + +template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; +template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; + +typedef void (im2col_ext_t)( + device const float * x, + device char * dst, + constant ggml_metal_kargs_im2col & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template +kernel void kernel_im2col_ext( + device const float * x, + device char * dst, + constant ggml_metal_kargs_im2col & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] + const int64_t KHW = (int64_t)args.KHW; + + const int64_t d = tgpig[0] / args.CHW; + const int64_t chw = tgpig[0] % args.CHW; + const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) + const int64_t HW = tgpig[0] % KHW; + + const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; + if (tpitg_0 >= args.N) { + return; + } + + const int64_t tpitg_1 = HW / args.KW; + const int64_t tpitg_2 = HW % 
args.KW; + + const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; + const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; + + const int64_t offset_dst = + (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + + (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); + + device T * pdst = (device T *) (dst); + + if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { + pdst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; + pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; + } +} + +template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext; +template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext; + +typedef void (conv_transpose_1d_t)( + device const float * src0, + device const float * src1, + device char * dst, + constant ggml_metal_kargs_conv_transpose_1d & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]); + +template +kernel void kernel_conv_transpose_1d( + device const T * src0, + device const float * src1, + device char * dst, + constant ggml_metal_kargs_conv_transpose_1d & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]) { + + float v = 0.0f; + + for (int64_t c = 0; c < args.IC; c++) { + const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1]; + const int32_t input_offset = c * args.IL; + + for (int64_t i = 0; i < args.IL; i++) { + if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) { + v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i]; + } + } + } + + device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1); + + dst_ptr[0] = v; +} + +template [[host_name("kernel_conv_transpose_1d_f32_f32")]] +kernel void kernel_conv_transpose_1d( + device const float * src0, + device const float * src1, + device char * dst, + constant 
ggml_metal_kargs_conv_transpose_1d & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]); + +template [[host_name("kernel_conv_transpose_1d_f16_f32")]] +kernel void kernel_conv_transpose_1d( + device const half * src0, + device const float * src1, + device char * dst, + constant ggml_metal_kargs_conv_transpose_1d & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]); + +kernel void kernel_upscale_f32( + device const char * src0, + device char * dst, + constant ggml_metal_kargs_upscale & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3/args.sf3; + const int64_t i02 = i2/args.sf2; + const int64_t i01 = i1/args.sf1; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int64_t i00 = i0/args.sf0; + + device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + dst_ptr[0] = src0_ptr[0]; + } +} + +kernel void kernel_pad_f32( + device const char * src0, + device char * dst, + constant ggml_metal_kargs_pad & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3; + const int64_t i02 = i2; + const int64_t i01 = i1; + + device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); + + if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { + for (int i0 = 
tpitg.x; i0 < args.ne0; i0 += ntg.x) { + if (i0 < args.ne00) { + dst_ptr[i0] = src0_ptr[i0]; + } else { + dst_ptr[i0] = 0.0f; + } + } + + return; + } + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + dst_ptr[i0] = 0.0f; + } +} + +kernel void kernel_pad_reflect_1d_f32( + device const char * src0, + device char * dst, + constant ggml_metal_kargs_pad_reflect_1d & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3; + const int64_t i02 = i2; + const int64_t i01 = i1; + + device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); + + if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + if (i0 < args.p0) { + dst_ptr[i0] = src0_ptr[args.p0 - i0]; + } else if (i0 < args.ne0 - args.p1) { + dst_ptr[i0] = src0_ptr[i0 - args.p0]; + } else { + dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1]; + } + } + } +} + +kernel void kernel_arange_f32( + device char * dst, + constant ggml_metal_kargs_arange & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + device float * dst_ptr = (device float *) dst; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + dst_ptr[i0] = args.start + args.step * i0; + } +} + +kernel void kernel_timestep_embedding_f32( + device const char * src0, + device char * dst, + constant ggml_metal_kargs_timestep_embedding & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { 
+ + int i = tgpig.x; + device float * embed_data = (device float *)(dst + i*args.nb1); + + int half_ = args.dim / 2; + for (int j = tpitg.x; j < half_; j += ntg.x) { + float timestep = ((device float *)src0)[i]; + float freq = (float)exp(-log((float)args.max_period) * j / half_); + float arg = timestep * freq; + embed_data[j ] = cos(arg); + embed_data[j + half_] = sin(arg); + } + + if (args.dim % 2 != 0 && tpitg.x == 0) { + embed_data[args.dim] = 0.f; + } +} + +// bitonic sort implementation following the CUDA kernels as reference +typedef void (argsort_t)( + device const float * x, + device int32_t * dst, + constant ggml_metal_kargs_argsort & args, + threadgroup int32_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]); + +template +kernel void kernel_argsort_f32_i32( + device const float * x, + device int32_t * dst, + constant ggml_metal_kargs_argsort & args, + threadgroup int32_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]) { + // bitonic sort + int col = tpitg[0]; + int row = tgpig[1]; + + if (col >= args.ncols_pad) return; + + device const float * x_row = x + row * args.ncols; + threadgroup int32_t * dst_row = shared_values; + + // initialize indices + dst_row[col] = col; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int k = 2; k <= args.ncols_pad; k *= 2) { + for (int j = k / 2; j > 0; j /= 2) { + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (dst_row[col] >= args.ncols || + (dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj]); + } + } else { + if (dst_row[ixj] >= args.ncols || + (dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ? 
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj]); + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } + + // copy the result to dst without the padding + if (col < args.ncols) { + dst[row * args.ncols + col] = dst_row[col]; + } +} + +template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; +template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; + +kernel void kernel_leaky_relu_f32( + device const float * src0, + device float * dst, + constant ggml_metal_kargs_leaky_relu & args, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * args.slope; +} + +// ref: https://arxiv.org/pdf/2307.08691.pdf +template< + typename q_t, // query types in shared memory + typename q4_t, + typename q8x8_t, + typename k_t, // key types in shared memory + typename k4x4_t, + typename k8x8_t, + typename v_t, // value types in shared memory + typename v4x4_t, + typename v8x8_t, + typename qk_t, // Q*K types + typename qk8x8_t, + typename s_t, // soft-max types + typename s8x8_t, + typename o_t, // attention accumulation types + typename o4_t, + typename o8x8_t, + typename kd4x4_t, // key type in device memory + short nl_k, + void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &), + typename vd4x4_t, // key type in device memory + short nl_v, + void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), + short DK, // K head size + short DV, // V head size + short Q = 8, // queries per threadgroup + short KV = 8, // key/value processed per each simdgroup + short C = 32> // cache items per threadgroup +kernel void kernel_flash_attn_ext( + constant ggml_metal_kargs_flash_attn_ext & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], 
+ uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const int iq3 = tgpig[2]; + const int iq2 = tgpig[1]; + const int iq1 = tgpig[0]*Q; + + const short DK4 = DK/4; + const short DK8 = DK/8; + const short DK16 = DK/16; + const short DV4 = DV/4; + const short DV8 = DV/8; + const short DV16 = DV/16; + const short NW = N_SIMDWIDTH; + const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) + + const short TS = nsg*SH; // shared memory size per query in (s_t == float) + const short T = DK + 2*TS; // shared memory size per query in (half) + + threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t + threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix + + threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory + threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t + + threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory + threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + o8x8_t lo[DV8]; + + // load heads from Q to shared memory + for (short j = sgitg; j < Q; j += nsg) { + device const float4 * 
q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); + + for (short i = tiisg; i < DK4; i += NW) { + if (iq1 + j < args.ne01) { + sq4[j*DK4 + i] = (q4_t) q4[i]; + } else { + sq4[j*DK4 + i] = (q4_t) 0.0f; + } + } + } + + // zero out lo + for (short i = 0; i < DV8; ++i) { + lo[i] = make_filled_simdgroup_matrix((o_t) 0.0f); + } + + // zero out shared memory SH + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < SH; i += NW) { + ss[j*TS + i] = 0.0f; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + half S[Q] = { [0 ... Q-1] = 0.0f }; + half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 }; + + // thread indices inside the simdgroup + // TODO: see if we can utilize quad-group functions for better performance + // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3) + const short tx = tiisg%4; + const short ty = tiisg/4; + + // broadcast kv + //const short rk2 = args.ne02/args.ne12; + //const short rk3 = args.ne03/args.ne13; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + const bool has_mask = mask != q; + + half slope = 1.0f; + + // ALiBi + if (args.max_bias > 0.0f) { + const short h = iq2; + + const half base = h < args.n_head_log2 ? args.m0 : args.m1; + const short exph = h < args.n_head_log2 ? 
h + 1 : 2*(h - args.n_head_log2) + 1; + + slope = pow(base, exph); + } + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= args.ne11) { + break; + } + + if (has_mask) { + // used to detect blocks full of -INF + half smax = -INFINITY; + + // load the mask in shared memory + #pragma unroll(Q) + for (short j = 0; j < Q; ++j) { + device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31); + + const half m = pm[ic + tiisg]; + + ss[j*TS + C + tiisg] = m; + smax = max(smax, m); + } + + smax = simd_max(smax); + + if (smax == -INFINITY) { + continue; + } + } + + // Q*K^T + { + for (short cc = 0; cc < C/8; ++cc) { + qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); + + // this is compile-time check, so it does not have runtime overhead + if (is_same::value) { + // we can read directly from global memory + device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + + #pragma unroll(DK8) + for (short i = 0; i < DK8; ++i) { + k8x8_t mk; + simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10 + + q8x8_t mq; + simdgroup_load(mq, sq + i*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + } + } else { + for (short ii = 0; ii < DK16; ii += 4) { + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + + if (DK16%4 == 0) { + // the head is evenly divisible by 4*16 = 64, so no need for bound checks + { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll(4) + for (short k = 0; k < 4; ++k) { + k8x8_t mk; + q8x8_t mq; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose 
+ simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + } + } else { + if (ii + tx < DK16) { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + for (short k = 0; k < 4 && ii + k < DK16; ++k) { + k8x8_t mk; + q8x8_t mq; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + } + } + } + } + + // cast qk_t -> s_t + //s8x8_t mqks(1.0f); + //simdgroup_multiply(mqks, mqk, mqks); + //simdgroup_store(mqks, ss + 8*cc, TS, 0, false); + + simdgroup_store(mqk, ss + 8*cc, TS, 0, false); + } + } + + // online softmax + { + for (ushort j = 0; j < Q; ++j) { + const half m = M[j]; + + // scale and apply the logitcap / mask + half s = ss[j*TS + tiisg]*args.scale; + + if (args.logit_softcap != 0.0f) { + s = args.logit_softcap*precise::tanh(s); + } + + // mqk = mqk + mask*slope + s += slope*ss[j*TS + C + tiisg]; + + M[j] = simd_max(max(M[j], s)); + + const half ms = exp(m - M[j]); + const half vs = exp(s - M[j]); + + S[j] = S[j]*ms + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*TS + tiisg] = vs; + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg == j) { + ss[j*TS + 2*C + j] = ms; + } + } + } + + // O = diag(ms)*O + { + s8x8_t mm; + simdgroup_load(mm, ss + 2*C, TS, 0, false); + + #pragma unroll(DV8) + for (short i = 0; i < DV8; ++i) { + simdgroup_multiply(lo[i], mm, lo[i]); + } + } + + // O = O + (Q*K^T)*V + { + for (short cc = 0; cc < 
C/8; ++cc) { + s8x8_t ms; + simdgroup_load(ms, ss + 8*cc, TS, 0, false); + + if (is_same::value) { + // we can read directly from global memory + device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + + #pragma unroll(DV8) + for (short i = 0; i < DV8; ++i) { + v8x8_t mv; + simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20 + + simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]); + } + } else { + for (short ii = 0; ii < DV16; ii += 4) { + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + + if (DV16%4 == 0) { + // no need for bound checks + { + v4x4_t tmp; + deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); + sv4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll(4) + for (short k = 0; k < 4; ++k) { + v8x8_t mv; + + simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]); + + simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]); + } + } else { + if (ii + tx < DV16) { + v4x4_t tmp; + deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); + sv4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + for (short k = 0; k < 4 && ii + k < DV16; ++k) { + v8x8_t mv; + + simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]); + + simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]); + } + } + } + } + } + } + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + for (short j = 0; j < Q; ++j) { + if (tiisg == 0) { + ss[j*TS + 0] = S[j]; + ss[j*TS + 1] = M[j]; 
+ } + } + } + + // reduce the warps sequentially + for (ushort sg = 1; sg < nsg; ++sg) { + half S = { 0.0f }; + half M = { -__FLT16_MAX__/2 }; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // each simdgroup stores its output to shared memory, reusing sq + if (sgitg == sg) { + for (short i = 0; i < DV8; ++i) { + simdgroup_store(lo[i], so + i*8, DV, 0, false); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // the first simdgroup accumulates the results from the other simdgroups + if (sgitg == 0) { + for (short j = 0; j < Q; ++j) { + const half S0 = ss[j*TS + 0]; + const half S1 = ss[j*TS + sg*SH + 0]; + + const half M0 = ss[j*TS + 1]; + const half M1 = ss[j*TS + sg*SH + 1]; + + M = max(M0, M1); + + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[j*TS + 0] = S; + ss[j*TS + 1] = M; + + ss[j*TS + 2*C + j ] = ms0; + ss[j*TS + 2*C + j + sg*SH] = ms1; + } + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + { + s8x8_t ms0; + s8x8_t ms1; + + simdgroup_load(ms0, ss + 2*C, TS, 0, false); + simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false); + + #pragma unroll(DV8) + for (short i = 0; i < DV8; ++i) { + o8x8_t t; + + simdgroup_load (t, so + i*8, DV, 0, false); + simdgroup_multiply(t, ms1, t); + + simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + } + } + } + } + + // store result to shared memory (reuse sq) + if (sgitg == 0) { + for (short i = 0; i < DV8; ++i) { + simdgroup_store(lo[i], so + i*8, DV, 0, false); + } + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) { + const float S = ss[j*TS + 0]; + + for (short i = tiisg; i < DV4; i += NW) { + dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S; + } + } + } +} + +// TODO: this is quite ugly. 
in the future these types will be hardcoded in the kernel, but for now keep them as +// template to be able to explore different combinations +// +#define FA_TYPES \ + half, half4, simdgroup_half8x8, \ + half, half4x4, simdgroup_half8x8, \ + half, half4x4, simdgroup_half8x8, \ + float, simdgroup_float8x8, \ + float, simdgroup_float8x8, \ + half, half4, simdgroup_half8x8 + +typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; + +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template 
[[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +#endif + +template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel 
flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template 
[[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +#undef FA_TYPES + +template< + typename q4_t, // query types in shared memory + typename k4_t, // key types in shared memory + typename v4_t, // value types in shared memory + typename qk_t, // Q*K types + typename s_t, // soft-max types + typename s4_t, + typename o4_t, // attention accumulation types + typename kd4_t, // key type in device memory + short nl_k, + void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &), + typename vd4_t, // key type in device memory + short nl_v, + void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), + short DK, // K head size + short DV, // V head size + short NE = 4, // head elements per thread + short Q = 1, // queries per threadgroup + short C = 32> // cache items per threadgroup +kernel void kernel_flash_attn_ext_vec( + constant ggml_metal_kargs_flash_attn_ext & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const int iq3 = tgpig[2]; + const int iq2 = tgpig[1]; + const int iq1 = tgpig[0]; + + const short DK4 = DK/4; + const short DV4 = DV/4; + const short NW = N_SIMDWIDTH; + const short NL = NW/NE; // note: this can be 
adjusted to support different head sizes and simdgroup work loads + const short SH = 2*C; // shared memory per simdgroup + + const short T = DK + nsg*SH; // shared memory size per query in (half) + + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t + threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*DK); // scratch buffer for mask + threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results + + // store the result for all queries in local memory (the O matrix from the paper) + o4_t lo[DV4/NL]; + + // load heads from Q to shared memory + device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); + + for (short i = tiisg; i < DK4; i += NW) { + if (iq1 < args.ne01) { + sq4[i] = (q4_t) q4[i]; + } else { + sq4[i] = (q4_t) 0.0f; + } + } + + // zero out lo + for (short i = 0; i < DV4/NL; ++i) { + lo[i] = (o4_t) 0.0f; + } + + // zero out shared memory SH + for (short i = tiisg; i < SH/4; i += NW) { + ss4[i] = (s4_t) 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + half S = 0.0f; + half M = -__FLT16_MAX__/2; + + // thread indices inside the simdgroup + const short tx = tiisg%NL; + const short ty = tiisg/NL; + + // broadcast kv + //const short rk2 = args.ne02/args.ne12; + //const short rk3 = args.ne03/args.ne13; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + const bool has_mask = mask != q; + + // pointer to the mask + device const half * pm = (device const half *) (mask + iq1*args.nb31); + + half slope = 
1.0f; + + // ALiBi + if (args.max_bias > 0.0f) { + const short h = iq2; + + const half base = h < args.n_head_log2 ? args.m0 : args.m1; + const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; + + slope = pow(base, exph); + } + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= args.ne11) { + break; + } + + if (has_mask) { + sm[tiisg] = pm[ic + tiisg]; + } + + // Q*K^T + { + // each simdgroup processes 1 query and NE (NW/NL) head elements + for (short cc = 0; cc < C/NE; ++cc) { + qk_t mqk = 0.0f; + + device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + + #pragma unroll(DK4/NL) + for (short ii = 0; ii < DK4; ii += NL) { + const short i = ii + tx; + + k4_t mk; + deq_k_t4(pk + i/nl_k, i%nl_k, mk); + + // note: this is less precise than the version below + //mqka[0] += dot(mq[0], mk[0]); + //mqka[1] += dot(mq[1], mk[1]); + //mqka[2] += dot(mq[2], mk[2]); + //mqka[3] += dot(mq[3], mk[3]); + + //q4x4_t mq = sq4x4[i]; + //mqka[0] += dot((float4) mq[0], (float4) mk[0]); + //mqka[1] += dot((float4) mq[1], (float4) mk[1]); + //mqka[2] += dot((float4) mq[2], (float4) mk[2]); + //mqka[3] += dot((float4) mq[3], (float4) mk[3]); + + mqk += dot((float4) mk, (float4) sq4[i]); + } + + static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails + + // simdgroup reduce (NE = 4) + // [ 0 .. 7] -> [ 0] + // [ 8 .. 15] -> [ 8] + // [16 .. 23] -> [16] + // [24 .. 
31] -> [24] + if (NE <= 1) { + mqk += simd_shuffle_down(mqk, 16); + } + if (NE <= 2) { + mqk += simd_shuffle_down(mqk, 8); + } + if (NE <= 4) { + mqk += simd_shuffle_down(mqk, 4); + } + if (NE <= 8) { + mqk += simd_shuffle_down(mqk, 2); + } + if (NE <= 16) { + mqk += simd_shuffle_down(mqk, 1); + } + + // mqk = mqk*scale + mask*slope + if (tx == 0) { + mqk *= args.scale; + + if (args.logit_softcap != 0.0f) { + mqk = args.logit_softcap*precise::tanh(mqk); + } + + mqk += sm[NE*cc + ty]*slope; + + ss[NE*cc + ty] = mqk; + } + } + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + // online softmax + { + const half m = M; + const half s = ss[tiisg]; + + M = simd_max(max(M, s)); + + const half ms = exp(m - M); + const half vs = exp(s - M); + + S = S*ms + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[tiisg] = vs; + + // O = diag(ms)*O + #pragma unroll(DV4/NL) + for (short ii = 0; ii < DV4; ii += NL) { + lo[ii/NL] *= ms; + } + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + // O = O + (Q*K^T)*V + { + //#pragma unroll(C/NE) + for (short cc = 0; cc < C/NE; ++cc) { + device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + + const s4_t ms(ss[NE*cc + ty]); + + #pragma unroll(DV4/NL) + for (short ii = 0; ii < DV4; ii += NL) { + const short i = ii + tx; + + v4_t mv; + deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); + + lo[ii/NL] += mv*ms; + } + } + } + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + if (tiisg == 0) { + ss[0] = (s_t) S; + ss[1] = (s_t) M; + } + } + + // simdgroup reduce (NE = 4) + // [ 0, 8, 16, 24] -> [ 0] + // [ 1, 9, 17, 25] -> [ 1] + // [ 2, 10, 18, 26] -> [ 2] + // [ 3, 11, 19, 27] -> [ 3] + // [ 4, 12, 20, 28] -> [ 4] + // [ 5, 13, 21, 29] -> [ 5] + // [ 6, 14, 22, 30] -> [ 6] + // [ 7, 15, 23, 31] -> [ 7] + for (short ii = 0; ii < DV4; ii += NL) { + if (NE > 1) { + lo[ii/NL][0] += 
simd_shuffle_down(lo[ii/NL][0], 16); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16); + } + + if (NE > 2) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8); + } + + if (NE > 4) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4); + } + + if (NE > 8) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2); + } + + if (NE > 16) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // store results to shared memory + for (short i = tiisg; i < DV4; i += NL) { + sr4[i] = lo[i/NL]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // parallel reduce + for (short r = nsg/2; r > 0; r >>= 1) { + if (sgitg < r) { + const half S0 = ss[ 0]; + const half S1 = ss[r*SH + 0]; + + const half M0 = ss[ 1]; + const half M1 = ss[r*SH + 1]; + + const half M = max(M0, M1); + + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); + + const half S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (short i = tiisg; i < DV4; i += NW) { + sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float4 * dst4 = (device 
float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + const float S = ss[0]; + + for (short i = tiisg; i < DV4; i += NW) { + dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S; + } + } +} + +// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem +// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max +// +#define FA_TYPES \ + half4, \ + half4, \ + half4, \ + float, \ + half, half4, \ + half4 + +typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel 
flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template 
[[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +#undef FA_TYPES + +template +kernel void kernel_set( + constant ggml_metal_kargs_set & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i13 = tgpig[2]; + const int i12 = tgpig[1]; + const int i11 = tgpig[0]; + + const int64_t n = i13*args.ne12*args.ne11*args.ne10 + i12*args.ne11*args.ne10 + i11*args.ne10; + + const int64_t i3 = n / (args.ne12*args.ne11*args.ne10); + const int64_t i2 = (n - i3*args.ne12*args.ne11*args.ne10) / (args.ne11*args.ne10); + const int64_t i1 = (n - i3*args.ne12*args.ne11*args.ne10 - i2*args.ne11*args.ne10) / args.ne10; + + device T * dst_data = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + args.offs); + + for (int64_t i10 = tpitg.x; i10 < args.ne10; i10 += ntg.x) { + device const T * src = (device T *) (src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10); + dst_data[i10] = (T) src[0]; + } +} + +typedef decltype(kernel_set) kernel_set_t; + +template [[host_name("kernel_set_f32")]] kernel kernel_set_t kernel_set; +template [[host_name("kernel_set_i32")]] kernel kernel_set_t kernel_set; + +template
movie.en.IDENTIFIER.srt
--gpus all
ffprobe
Loading target server libraries...
~/.config/jellyfin/plugins/Whisper Subtitles_*/whisper/
~/.cache/whisper-cpp/
Generates high-fidelity subtitle tracking sheets automatically. Harnesses advanced local transcription engines without needing to route external data packets out to third party provider cloud spaces.