|
| 1 | +using System.Collections.Immutable; |
| 2 | +using System.Text.RegularExpressions; |
| 3 | +using Injectio.Attributes; |
| 4 | +using StabilityMatrix.Core.Extensions; |
| 5 | +using StabilityMatrix.Core.Helper; |
| 6 | +using StabilityMatrix.Core.Helper.Cache; |
| 7 | +using StabilityMatrix.Core.Helper.HardwareInfo; |
| 8 | +using StabilityMatrix.Core.Models.FileInterfaces; |
| 9 | +using StabilityMatrix.Core.Models.Progress; |
| 10 | +using StabilityMatrix.Core.Processes; |
| 11 | +using StabilityMatrix.Core.Python; |
| 12 | +using StabilityMatrix.Core.Services; |
| 13 | + |
| 14 | +namespace StabilityMatrix.Core.Models.Packages; |
| 15 | + |
| 16 | +[RegisterSingleton<BasePackage, AiToolkit>(Duplicate = DuplicateStrategy.Append)] |
| 17 | +public class AiToolkit( |
| 18 | + IGithubApiCache githubApi, |
| 19 | + ISettingsManager settingsManager, |
| 20 | + IDownloadService downloadService, |
| 21 | + IPrerequisiteHelper prerequisiteHelper, |
| 22 | + IPyInstallationManager pyInstallationManager |
| 23 | +) : BaseGitPackage(githubApi, settingsManager, downloadService, prerequisiteHelper, pyInstallationManager) |
| 24 | +{ |
| 25 | + private AnsiProcess? npmProcess; |
| 26 | + |
| 27 | + public override string Name => "ai-toolkit"; |
| 28 | + public override string DisplayName { get; set; } = "AI-Toolkit"; |
| 29 | + public override string Author => "ostris"; |
| 30 | + public override string Blurb => "AI Toolkit is an all in one training suite for diffusion models"; |
| 31 | + public override string LicenseType => "MIT"; |
| 32 | + public override string LicenseUrl => "https://github.com/ostris/ai-toolkit/blob/main/LICENSE"; |
| 33 | + public override string LaunchCommand => string.Empty; |
| 34 | + |
| 35 | + public override Uri PreviewImageUri => |
| 36 | + new( |
| 37 | + "https://camo.githubusercontent.com/ea35b399e0d659f9f2ee09cbedb58e1a3ec7a0eab763e8ae8d11d076aad5be40/68747470733a2f2f6f73747269732e636f6d2f77702d636f6e74656e742f75706c6f6164732f323032352f30322f746f6f6c6b69742d75692e6a7067" |
| 38 | + ); |
| 39 | + |
| 40 | + public override string OutputFolderName => "output"; |
| 41 | + public override IEnumerable<TorchIndex> AvailableTorchIndices => [TorchIndex.Cuda]; |
| 42 | + public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Advanced; |
| 43 | + public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.None; |
| 44 | + public override List<LaunchOptionDefinition> LaunchOptions => []; |
| 45 | + public override Dictionary<SharedOutputType, IReadOnlyList<string>>? SharedOutputFolders => []; |
| 46 | + public override string MainBranch => "main"; |
| 47 | + public override bool IsCompatible => HardwareHelper.HasNvidiaGpu(); |
| 48 | + |
| 49 | + public override TorchIndex GetRecommendedTorchVersion() => TorchIndex.Cuda; |
| 50 | + |
| 51 | + public override PackageType PackageType => PackageType.SdTraining; |
| 52 | + public override bool OfferInOneClickInstaller => false; |
| 53 | + public override bool ShouldIgnoreReleases => true; |
| 54 | + public override PyVersion RecommendedPythonVersion => Python.PyInstallationManager.Python_3_12_10; |
| 55 | + |
| 56 | + public override IEnumerable<PackagePrerequisite> Prerequisites => |
| 57 | + base.Prerequisites.Concat([PackagePrerequisite.Node]); |
| 58 | + |
| 59 | + public override async Task InstallPackage( |
| 60 | + string installLocation, |
| 61 | + InstalledPackage installedPackage, |
| 62 | + InstallPackageOptions options, |
| 63 | + IProgress<ProgressReport>? progress = null, |
| 64 | + Action<ProcessOutput>? onConsoleOutput = null, |
| 65 | + CancellationToken cancellationToken = default |
| 66 | + ) |
| 67 | + { |
| 68 | + progress?.Report(new ProgressReport(-1, "Setting up venv", isIndeterminate: true)); |
| 69 | + await using var venvRunner = await SetupVenvPure( |
| 70 | + installLocation, |
| 71 | + pythonVersion: options.PythonOptions.PythonVersion |
| 72 | + ) |
| 73 | + .ConfigureAwait(false); |
| 74 | + venvRunner.UpdateEnvironmentVariables(GetEnvVars); |
| 75 | + |
| 76 | + await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); |
| 77 | + |
| 78 | + var isBlackwell = |
| 79 | + SettingsManager.Settings.PreferredGpu?.IsBlackwellGpu() ?? HardwareHelper.HasBlackwellGpu(); |
| 80 | + var pipArgs = new PipInstallArgs() |
| 81 | + .AddArg("--upgrade") |
| 82 | + .WithTorch("==2.7.0") |
| 83 | + .WithTorchVision("==0.22.0") |
| 84 | + .WithTorchAudio("==2.7.0") |
| 85 | + .WithTorchExtraIndex(isBlackwell ? "cu128" : "cu126"); |
| 86 | + |
| 87 | + if (installedPackage.PipOverrides != null) |
| 88 | + { |
| 89 | + pipArgs = pipArgs.WithUserOverrides(installedPackage.PipOverrides); |
| 90 | + } |
| 91 | + |
| 92 | + progress?.Report(new ProgressReport(-1f, "Installing torch...", isIndeterminate: true)); |
| 93 | + await venvRunner.PipInstall(pipArgs, onConsoleOutput).ConfigureAwait(false); |
| 94 | + |
| 95 | + // install requirements.txt |
| 96 | + var requirements = new FilePath(installLocation, "requirements.txt"); |
| 97 | + |
| 98 | + pipArgs = new PipInstallArgs("--upgrade") |
| 99 | + .WithParsedFromRequirementsTxt( |
| 100 | + await requirements.ReadAllTextAsync(cancellationToken).ConfigureAwait(false), |
| 101 | + excludePattern: "torch$|numpy" |
| 102 | + ) |
| 103 | + .AddArg(Compat.IsWindows ? "triton-windows" : "triton"); |
| 104 | + |
| 105 | + if (installedPackage.PipOverrides != null) |
| 106 | + { |
| 107 | + pipArgs = pipArgs.WithUserOverrides(installedPackage.PipOverrides); |
| 108 | + } |
| 109 | + |
| 110 | + progress?.Report( |
| 111 | + new ProgressReport(-1f, "Installing Package Requirements...", isIndeterminate: true) |
| 112 | + ); |
| 113 | + await venvRunner.PipInstall(pipArgs, onConsoleOutput).ConfigureAwait(false); |
| 114 | + |
| 115 | + progress?.Report(new ProgressReport(-1f, "Installing AI Toolkit UI...", isIndeterminate: true)); |
| 116 | + |
| 117 | + var uiDirectory = new DirectoryPath(installLocation, "ui"); |
| 118 | + var envVars = GetEnvVars(venvRunner.EnvironmentVariables); |
| 119 | + await PrerequisiteHelper |
| 120 | + .RunNpm("install", uiDirectory, progress?.AsProcessOutputHandler(), envVars) |
| 121 | + .ConfigureAwait(false); |
| 122 | + await PrerequisiteHelper |
| 123 | + .RunNpm("run update_db", uiDirectory, progress?.AsProcessOutputHandler(), envVars) |
| 124 | + .ConfigureAwait(false); |
| 125 | + await PrerequisiteHelper |
| 126 | + .RunNpm("run build", uiDirectory, progress?.AsProcessOutputHandler(), envVars) |
| 127 | + .ConfigureAwait(false); |
| 128 | + } |
| 129 | + |
| 130 | + public override async Task RunPackage( |
| 131 | + string installLocation, |
| 132 | + InstalledPackage installedPackage, |
| 133 | + RunPackageOptions options, |
| 134 | + Action<ProcessOutput>? onConsoleOutput = null, |
| 135 | + CancellationToken cancellationToken = default |
| 136 | + ) |
| 137 | + { |
| 138 | + await SetupVenv(installLocation, pythonVersion: PyVersion.Parse(installedPackage.PythonVersion)) |
| 139 | + .ConfigureAwait(false); |
| 140 | + VenvRunner.UpdateEnvironmentVariables(GetEnvVars); |
| 141 | + |
| 142 | + var uiDirectory = new DirectoryPath(installLocation, "ui"); |
| 143 | + var envVars = GetEnvVars(VenvRunner.EnvironmentVariables); |
| 144 | + npmProcess = PrerequisiteHelper.RunNpmDetached( |
| 145 | + "run start", |
| 146 | + uiDirectory, |
| 147 | + HandleConsoleOutput, |
| 148 | + envVars |
| 149 | + ); |
| 150 | + npmProcess.EnableRaisingEvents = true; |
| 151 | + if (Compat.IsWindows) |
| 152 | + { |
| 153 | + ProcessTracker.AttachExitHandlerJobToProcess(npmProcess); |
| 154 | + } |
| 155 | + |
| 156 | + return; |
| 157 | + |
| 158 | + void HandleConsoleOutput(ProcessOutput s) |
| 159 | + { |
| 160 | + onConsoleOutput?.Invoke(s); |
| 161 | + |
| 162 | + if (!s.Text.Contains("Local: ", StringComparison.OrdinalIgnoreCase)) |
| 163 | + return; |
| 164 | + |
| 165 | + var regex = new Regex(@"(https?:\/\/)([^:\s]+):(\d+)"); |
| 166 | + var match = regex.Match(s.Text); |
| 167 | + if (match.Success) |
| 168 | + { |
| 169 | + WebUrl = match.Value; |
| 170 | + } |
| 171 | + OnStartupComplete(WebUrl); |
| 172 | + } |
| 173 | + } |
| 174 | + |
| 175 | + public override async Task WaitForShutdown() |
| 176 | + { |
| 177 | + if (npmProcess is { HasExited: false }) |
| 178 | + { |
| 179 | + npmProcess.Kill(true); |
| 180 | + try |
| 181 | + { |
| 182 | + await npmProcess |
| 183 | + .WaitForExitAsync(new CancellationTokenSource(5000).Token) |
| 184 | + .ConfigureAwait(false); |
| 185 | + } |
| 186 | + catch (OperationCanceledException e) |
| 187 | + { |
| 188 | + Console.WriteLine(e); |
| 189 | + } |
| 190 | + } |
| 191 | + |
| 192 | + npmProcess = null; |
| 193 | + GC.SuppressFinalize(this); |
| 194 | + } |
| 195 | + |
| 196 | + private ImmutableDictionary<string, string> GetEnvVars(ImmutableDictionary<string, string> env) |
| 197 | + { |
| 198 | + var pathBuilder = new EnvPathBuilder(); |
| 199 | + |
| 200 | + if (env.TryGetValue("PATH", out var value)) |
| 201 | + { |
| 202 | + pathBuilder.AddPath(value); |
| 203 | + } |
| 204 | + |
| 205 | + pathBuilder.AddPath( |
| 206 | + Compat.IsWindows |
| 207 | + ? Environment.GetFolderPath(Environment.SpecialFolder.System) |
| 208 | + : Path.Combine(SettingsManager.LibraryDir, "Assets", "nodejs", "bin") |
| 209 | + ); |
| 210 | + |
| 211 | + pathBuilder.AddPath(Path.Combine(SettingsManager.LibraryDir, "Assets", "nodejs")); |
| 212 | + |
| 213 | + return env.SetItem("PATH", pathBuilder.ToString()); |
| 214 | + } |
| 215 | +} |
0 commit comments