Skip to content

Commit 545c5f5

Browse files
committed
Add AI Toolkit and fix a couple a1111/forge bugs
1 parent f3d717c commit 545c5f5

10 files changed

Lines changed: 318 additions & 4 deletions

File tree

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning 2.0](https://semver.org/spec/v2
77

88
## v2.15.0-pre.2
99
### Added
10+
- Added new package - [AI Toolkit](https://github.com/ostris/ai-toolkit/)
1011
- Added Manual Install button for installing Package extensions that aren't in the indexes
1112
- Added Next and Previous buttons to the Civitai details page to navigate between results
1213
- Added Negative Rejection Steering (NRS) by @reithan to Inference
@@ -16,6 +17,9 @@ and this project adheres to [Semantic Versioning 2.0](https://semver.org/spec/v2
1617
### Fixed
1718
- Fixed Inference custom step (e.g. HiresFix) Samplers potentially sharing state with other card UIs like model browser.
1819
- Fixed extension manager failing to install extensions due to incorrect clone directory
20+
- Fixed duplicate Python versions appearing in the Advanced Options when installing a package
21+
- Fixed [#1360](https://github.com/LykosAI/StabilityMatrix/issues/1360) - A1111 install not using correct torch for 5000-series GPUs
22+
- Fixed [#1361](https://github.com/LykosAI/StabilityMatrix/issues/1361) - numpy and other Forge startup errors
1923

2024
## v2.15.0-pre.1
2125
### Added

StabilityMatrix.Avalonia/Helpers/UnixPrerequisiteHelper.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,22 @@ public async Task RunNpm(
416416
);
417417
}
418418

419+
public AnsiProcess RunNpmDetached(
420+
ProcessArgs args,
421+
string? workingDirectory = null,
422+
Action<ProcessOutput>? onProcessOutput = null,
423+
IReadOnlyDictionary<string, string>? envVars = null
424+
)
425+
{
426+
return ProcessRunner.StartAnsiProcess(
427+
NpmPath,
428+
args,
429+
workingDirectory,
430+
onProcessOutput,
431+
envVars ?? new Dictionary<string, string>()
432+
);
433+
}
434+
419435
[SupportedOSPlatform("Linux")]
420436
[SupportedOSPlatform("macOS")]
421437
public async Task<Process> RunDotnet(

StabilityMatrix.Avalonia/Helpers/WindowsPrerequisiteHelper.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,22 @@ public async Task RunNpm(
192192
onProcessOutput?.Invoke(ProcessOutput.FromStdErrLine(result.StandardError));
193193
}
194194

195+
public AnsiProcess RunNpmDetached(
196+
ProcessArgs args,
197+
string? workingDirectory = null,
198+
Action<ProcessOutput>? onProcessOutput = null,
199+
IReadOnlyDictionary<string, string>? envVars = null
200+
)
201+
{
202+
return ProcessRunner.StartAnsiProcess(
203+
NodeExistsPath,
204+
args,
205+
workingDirectory,
206+
onProcessOutput,
207+
envVars ?? new Dictionary<string, string>()
208+
);
209+
}
210+
195211
public Task InstallPackageRequirements(
196212
BasePackage package,
197213
PyVersion? pyVersion = null,

StabilityMatrix.Core/Helper/Factory/PackageFactory.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,13 @@ public BasePackage GetNewBasePackage(InstalledPackage installedPackage)
229229
prerequisiteHelper,
230230
pyInstallationManager
231231
),
232+
"ai-toolkit" => new AiToolkit(
233+
githubApiCache,
234+
settingsManager,
235+
downloadService,
236+
prerequisiteHelper,
237+
pyInstallationManager
238+
),
232239
_ => throw new ArgumentOutOfRangeException(nameof(installedPackage)),
233240
};
234241
}

StabilityMatrix.Core/Helper/IPrerequisiteHelper.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,13 @@ Task RunNpm(
258258
Action<ProcessOutput>? onProcessOutput = null,
259259
IReadOnlyDictionary<string, string>? envVars = null
260260
);
261+
262+
AnsiProcess RunNpmDetached(
263+
ProcessArgs args,
264+
string? workingDirectory = null,
265+
Action<ProcessOutput>? onProcessOutput = null,
266+
IReadOnlyDictionary<string, string>? envVars = null
267+
);
261268
Task InstallNodeIfNecessary(IProgress<ProgressReport>? progress = null);
262269
Task InstallPackageRequirements(
263270
BasePackage package,

StabilityMatrix.Core/Models/Packages/A3WebUI.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,18 +216,22 @@ public override async Task InstallPackage(
216216
progress?.Report(new ProgressReport(-1f, "Installing requirements...", isIndeterminate: true));
217217

218218
var torchVersion = options.PythonOptions.TorchIndex ?? GetRecommendedTorchVersion();
219+
var isBlackwell =
220+
torchVersion is TorchIndex.Cuda
221+
&& (SettingsManager.Settings.PreferredGpu?.IsBlackwellGpu() ?? HardwareHelper.HasBlackwellGpu());
219222

220223
var requirements = new FilePath(installLocation, "requirements_versions.txt");
221224
var pipArgs = torchVersion switch
222225
{
223226
TorchIndex.Mps => new PipInstallArgs().WithTorch("==2.3.1").WithTorchVision("==0.18.1"),
224227
_ => new PipInstallArgs()
225-
.WithTorch("==2.1.2")
226-
.WithTorchVision("==0.16.2")
228+
.WithTorch(isBlackwell ? string.Empty : "==2.1.2")
229+
.WithTorchVision(isBlackwell ? string.Empty : "==0.16.2")
227230
.WithTorchExtraIndex(
228231
torchVersion switch
229232
{
230233
TorchIndex.Cpu => "cpu",
234+
TorchIndex.Cuda when isBlackwell => "cu128",
231235
TorchIndex.Cuda => "cu121",
232236
TorchIndex.Rocm => "rocm5.6",
233237
TorchIndex.Mps => "cpu",
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
using System.Collections.Immutable;
2+
using System.Text.RegularExpressions;
3+
using Injectio.Attributes;
4+
using StabilityMatrix.Core.Extensions;
5+
using StabilityMatrix.Core.Helper;
6+
using StabilityMatrix.Core.Helper.Cache;
7+
using StabilityMatrix.Core.Helper.HardwareInfo;
8+
using StabilityMatrix.Core.Models.FileInterfaces;
9+
using StabilityMatrix.Core.Models.Progress;
10+
using StabilityMatrix.Core.Processes;
11+
using StabilityMatrix.Core.Python;
12+
using StabilityMatrix.Core.Services;
13+
14+
namespace StabilityMatrix.Core.Models.Packages;
15+
16+
[RegisterSingleton<BasePackage, AiToolkit>(Duplicate = DuplicateStrategy.Append)]
17+
public class AiToolkit(
18+
IGithubApiCache githubApi,
19+
ISettingsManager settingsManager,
20+
IDownloadService downloadService,
21+
IPrerequisiteHelper prerequisiteHelper,
22+
IPyInstallationManager pyInstallationManager
23+
) : BaseGitPackage(githubApi, settingsManager, downloadService, prerequisiteHelper, pyInstallationManager)
24+
{
25+
private AnsiProcess? npmProcess;
26+
27+
public override string Name => "ai-toolkit";
28+
public override string DisplayName { get; set; } = "AI-Toolkit";
29+
public override string Author => "ostris";
30+
public override string Blurb => "AI Toolkit is an all in one training suite for diffusion models";
31+
public override string LicenseType => "MIT";
32+
public override string LicenseUrl => "https://github.com/ostris/ai-toolkit/blob/main/LICENSE";
33+
public override string LaunchCommand => string.Empty;
34+
35+
public override Uri PreviewImageUri =>
36+
new(
37+
"https://camo.githubusercontent.com/ea35b399e0d659f9f2ee09cbedb58e1a3ec7a0eab763e8ae8d11d076aad5be40/68747470733a2f2f6f73747269732e636f6d2f77702d636f6e74656e742f75706c6f6164732f323032352f30322f746f6f6c6b69742d75692e6a7067"
38+
);
39+
40+
public override string OutputFolderName => "output";
41+
public override IEnumerable<TorchIndex> AvailableTorchIndices => [TorchIndex.Cuda];
42+
public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Advanced;
43+
public override SharedFolderMethod RecommendedSharedFolderMethod => SharedFolderMethod.None;
44+
public override List<LaunchOptionDefinition> LaunchOptions => [];
45+
public override Dictionary<SharedOutputType, IReadOnlyList<string>>? SharedOutputFolders => [];
46+
public override string MainBranch => "main";
47+
public override bool IsCompatible => HardwareHelper.HasNvidiaGpu();
48+
49+
public override TorchIndex GetRecommendedTorchVersion() => TorchIndex.Cuda;
50+
51+
public override PackageType PackageType => PackageType.SdTraining;
52+
public override bool OfferInOneClickInstaller => false;
53+
public override bool ShouldIgnoreReleases => true;
54+
public override PyVersion RecommendedPythonVersion => Python.PyInstallationManager.Python_3_12_10;
55+
56+
public override IEnumerable<PackagePrerequisite> Prerequisites =>
57+
base.Prerequisites.Concat([PackagePrerequisite.Node]);
58+
59+
public override async Task InstallPackage(
60+
string installLocation,
61+
InstalledPackage installedPackage,
62+
InstallPackageOptions options,
63+
IProgress<ProgressReport>? progress = null,
64+
Action<ProcessOutput>? onConsoleOutput = null,
65+
CancellationToken cancellationToken = default
66+
)
67+
{
68+
progress?.Report(new ProgressReport(-1, "Setting up venv", isIndeterminate: true));
69+
await using var venvRunner = await SetupVenvPure(
70+
installLocation,
71+
pythonVersion: options.PythonOptions.PythonVersion
72+
)
73+
.ConfigureAwait(false);
74+
venvRunner.UpdateEnvironmentVariables(GetEnvVars);
75+
76+
await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false);
77+
78+
var isBlackwell =
79+
SettingsManager.Settings.PreferredGpu?.IsBlackwellGpu() ?? HardwareHelper.HasBlackwellGpu();
80+
var pipArgs = new PipInstallArgs()
81+
.AddArg("--upgrade")
82+
.WithTorch("==2.7.0")
83+
.WithTorchVision("==0.22.0")
84+
.WithTorchAudio("==2.7.0")
85+
.WithTorchExtraIndex(isBlackwell ? "cu128" : "cu126");
86+
87+
if (installedPackage.PipOverrides != null)
88+
{
89+
pipArgs = pipArgs.WithUserOverrides(installedPackage.PipOverrides);
90+
}
91+
92+
progress?.Report(new ProgressReport(-1f, "Installing torch...", isIndeterminate: true));
93+
await venvRunner.PipInstall(pipArgs, onConsoleOutput).ConfigureAwait(false);
94+
95+
// install requirements.txt
96+
var requirements = new FilePath(installLocation, "requirements.txt");
97+
98+
pipArgs = new PipInstallArgs("--upgrade")
99+
.WithParsedFromRequirementsTxt(
100+
await requirements.ReadAllTextAsync(cancellationToken).ConfigureAwait(false),
101+
excludePattern: "torch$|numpy"
102+
)
103+
.AddArg(Compat.IsWindows ? "triton-windows" : "triton");
104+
105+
if (installedPackage.PipOverrides != null)
106+
{
107+
pipArgs = pipArgs.WithUserOverrides(installedPackage.PipOverrides);
108+
}
109+
110+
progress?.Report(
111+
new ProgressReport(-1f, "Installing Package Requirements...", isIndeterminate: true)
112+
);
113+
await venvRunner.PipInstall(pipArgs, onConsoleOutput).ConfigureAwait(false);
114+
115+
progress?.Report(new ProgressReport(-1f, "Installing AI Toolkit UI...", isIndeterminate: true));
116+
117+
var uiDirectory = new DirectoryPath(installLocation, "ui");
118+
var envVars = GetEnvVars(venvRunner.EnvironmentVariables);
119+
await PrerequisiteHelper
120+
.RunNpm("install", uiDirectory, progress?.AsProcessOutputHandler(), envVars)
121+
.ConfigureAwait(false);
122+
await PrerequisiteHelper
123+
.RunNpm("run update_db", uiDirectory, progress?.AsProcessOutputHandler(), envVars)
124+
.ConfigureAwait(false);
125+
await PrerequisiteHelper
126+
.RunNpm("run build", uiDirectory, progress?.AsProcessOutputHandler(), envVars)
127+
.ConfigureAwait(false);
128+
}
129+
130+
public override async Task RunPackage(
131+
string installLocation,
132+
InstalledPackage installedPackage,
133+
RunPackageOptions options,
134+
Action<ProcessOutput>? onConsoleOutput = null,
135+
CancellationToken cancellationToken = default
136+
)
137+
{
138+
await SetupVenv(installLocation, pythonVersion: PyVersion.Parse(installedPackage.PythonVersion))
139+
.ConfigureAwait(false);
140+
VenvRunner.UpdateEnvironmentVariables(GetEnvVars);
141+
142+
var uiDirectory = new DirectoryPath(installLocation, "ui");
143+
var envVars = GetEnvVars(VenvRunner.EnvironmentVariables);
144+
npmProcess = PrerequisiteHelper.RunNpmDetached(
145+
"run start",
146+
uiDirectory,
147+
HandleConsoleOutput,
148+
envVars
149+
);
150+
npmProcess.EnableRaisingEvents = true;
151+
if (Compat.IsWindows)
152+
{
153+
ProcessTracker.AttachExitHandlerJobToProcess(npmProcess);
154+
}
155+
156+
return;
157+
158+
void HandleConsoleOutput(ProcessOutput s)
159+
{
160+
onConsoleOutput?.Invoke(s);
161+
162+
if (!s.Text.Contains("Local: ", StringComparison.OrdinalIgnoreCase))
163+
return;
164+
165+
var regex = new Regex(@"(https?:\/\/)([^:\s]+):(\d+)");
166+
var match = regex.Match(s.Text);
167+
if (match.Success)
168+
{
169+
WebUrl = match.Value;
170+
}
171+
OnStartupComplete(WebUrl);
172+
}
173+
}
174+
175+
public override async Task WaitForShutdown()
176+
{
177+
if (npmProcess is { HasExited: false })
178+
{
179+
npmProcess.Kill(true);
180+
try
181+
{
182+
await npmProcess
183+
.WaitForExitAsync(new CancellationTokenSource(5000).Token)
184+
.ConfigureAwait(false);
185+
}
186+
catch (OperationCanceledException e)
187+
{
188+
Console.WriteLine(e);
189+
}
190+
}
191+
192+
npmProcess = null;
193+
GC.SuppressFinalize(this);
194+
}
195+
196+
private ImmutableDictionary<string, string> GetEnvVars(ImmutableDictionary<string, string> env)
197+
{
198+
var pathBuilder = new EnvPathBuilder();
199+
200+
if (env.TryGetValue("PATH", out var value))
201+
{
202+
pathBuilder.AddPath(value);
203+
}
204+
205+
pathBuilder.AddPath(
206+
Compat.IsWindows
207+
? Environment.GetFolderPath(Environment.SpecialFolder.System)
208+
: Path.Combine(SettingsManager.LibraryDir, "Assets", "nodejs", "bin")
209+
);
210+
211+
pathBuilder.AddPath(Path.Combine(SettingsManager.LibraryDir, "Assets", "nodejs"));
212+
213+
return env.SetItem("PATH", pathBuilder.ToString());
214+
}
215+
}

StabilityMatrix.Core/Models/Packages/ForgeClassic.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,22 @@ public override async Task InstallPackage(
160160
.ReadAllTextAsync(cancellationToken)
161161
.ConfigureAwait(false);
162162

163+
var extensionsBuiltinDir = new DirectoryPath(installLocation, "extensions-builtin");
164+
if (extensionsBuiltinDir.Exists)
165+
{
166+
var requirementsFiles = extensionsBuiltinDir.EnumerateFiles(
167+
"requirements.txt",
168+
EnumerationOptionConstants.AllDirectories
169+
);
170+
171+
foreach (var requirementsFile in requirementsFiles)
172+
{
173+
requirementsContent += await requirementsFile
174+
.ReadAllTextAsync(cancellationToken)
175+
.ConfigureAwait(false);
176+
}
177+
}
178+
163179
var isLegacyNvidia =
164180
SettingsManager.Settings.PreferredGpu?.IsLegacyNvidiaGpu() ?? HardwareHelper.HasLegacyNvidiaGpu();
165181
var torchExtraIndex = isLegacyNvidia ? "cu126" : "cu128";

0 commit comments

Comments
 (0)