Skip to content

Commit f3d717c

Browse files
authored
Merge pull request LykosAI#1126 from ionite34/add-nrs
Add negative rejection steering (NRS) addon to Inference
2 parents bd12e1f + cc7639f commit f3d717c

11 files changed

Lines changed: 214 additions & 28 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning 2.0](https://semver.org/spec/v2
99
### Added
1010
- Added Manual Install button for installing Package extensions that aren't in the indexes
1111
- Added Next and Previous buttons to the Civitai details page to navigate between results
12+
- Added Negative Rejection Steering (NRS) by @reithan to Inference
1213
### Changed
1314
- Brought back the "size remaining after download" tooltip in the new Civitai details page
1415
- Updated ComfyUI installs for AMD users on Linux to use the latest rocm6.4 torch index

StabilityMatrix.Avalonia/App.axaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
<StyleInclude Source="Controls/MarkdownViewer.axaml" />
100100
<StyleInclude Source="Controls/Inference/WanModelCard.axaml" />
101101
<StyleInclude Source="Controls/Inference/PlasmaNoiseCard.axaml" />
102+
<StyleInclude Source="Controls/Inference/NrsCard.axaml" />
102103
<labs:ControlThemes />
103104

104105
<Style Selector="DockControl">
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
<Styles
2+
xmlns="https://github.com/avaloniaui"
3+
xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"
4+
xmlns:controls="using:StabilityMatrix.Avalonia.Controls"
5+
xmlns:mocks="clr-namespace:StabilityMatrix.Avalonia.DesignData"
6+
xmlns:ui="clr-namespace:FluentAvalonia.UI.Controls;assembly=FluentAvalonia"
7+
xmlns:vmInference="clr-namespace:StabilityMatrix.Avalonia.ViewModels.Inference"
8+
x:DataType="vmInference:NrsCardViewModel">
9+
<Design.PreviewWith>
10+
<controls:NrsCard DataContext="{x:Static mocks:DesignData.NrsCardViewModel}" />
11+
</Design.PreviewWith>
12+
13+
<Style Selector="controls|NrsCard">
14+
<Setter Property="Template">
15+
<ControlTemplate>
16+
<controls:Card x:Name="PART_Card">
17+
<controls:Card.Styles>
18+
<Style Selector="ui|NumberBox">
19+
<Setter Property="Margin" Value="12,4,0,4" />
20+
<Setter Property="MinWidth" Value="70" />
21+
<Setter Property="HorizontalAlignment" Value="Stretch" />
22+
<Setter Property="ValidationMode" Value="InvalidInputOverwritten" />
23+
<Setter Property="SimpleNumberFormat" Value="F2" />
24+
<Setter Property="SpinButtonPlacementMode" Value="Inline" />
25+
</Style>
26+
</controls:Card.Styles>
27+
<Grid ColumnDefinitions="Auto,*" RowDefinitions="*,*,*,*">
28+
<TextBlock
29+
Grid.Row="0"
30+
Grid.Column="0"
31+
VerticalAlignment="Center"
32+
Text="Skew" />
33+
<ui:NumberBox
34+
Grid.Row="0"
35+
Grid.Column="1"
36+
Margin="12,0,0,4"
37+
SmallChange="0.25"
38+
Value="{Binding Skew}" />
39+
40+
<TextBlock
41+
Grid.Row="1"
42+
Grid.Column="0"
43+
VerticalAlignment="Center"
44+
Text="Stretch" />
45+
<ui:NumberBox
46+
Grid.Row="1"
47+
Grid.Column="1"
48+
SmallChange="0.25"
49+
Value="{Binding Stretch}" />
50+
51+
<TextBlock
52+
Grid.Row="2"
53+
Grid.Column="0"
54+
VerticalAlignment="Center"
55+
Text="Squash" />
56+
<ui:NumberBox
57+
Grid.Row="2"
58+
Grid.Column="1"
59+
SmallChange="0.01"
60+
Value="{Binding Squash}" />
61+
62+
</Grid>
63+
</controls:Card>
64+
</ControlTemplate>
65+
</Setter>
66+
</Style>
67+
</Styles>
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
using Injectio.Attributes;
2+
3+
namespace StabilityMatrix.Avalonia.Controls;
4+
5+
[RegisterTransient<NrsCard>]
6+
public class NrsCard : TemplatedControlBase { }

StabilityMatrix.Avalonia/DesignData/DesignData.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,6 +1068,8 @@ public static UpdateSettingsViewModel UpdateSettingsViewModel
10681068

10691069
public static FreeUCardViewModel FreeUCardViewModel => DialogFactory.Get<FreeUCardViewModel>();
10701070

1071+
public static NrsCardViewModel NrsCardViewModel => DialogFactory.Get<NrsCardViewModel>();
1072+
10711073
public static PromptCardViewModel PromptCardViewModel =>
10721074
DialogFactory.Get<PromptCardViewModel>(vm =>
10731075
{

StabilityMatrix.Avalonia/StabilityMatrix.Avalonia.csproj

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,10 @@
263263
<DependentUpon>InferenceWanTextToVideoView.axaml</DependentUpon>
264264
<SubType>Code</SubType>
265265
</Compile>
266+
<Compile Update="Controls\Inference\NrsCard.axaml.cs">
267+
<DependentUpon>NrsCard.axaml</DependentUpon>
268+
<SubType>Code</SubType>
269+
</Compile>
266270
</ItemGroup>
267271

268272
<!-- set HUSKY to 0 to disable, or opt-in during CI by setting HUSKY to 1 -->

StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ namespace StabilityMatrix.Avalonia.ViewModels.Base;
2525
[JsonDerivedType(typeof(DiscreteModelSamplingCardViewModel), DiscreteModelSamplingCardViewModel.ModuleKey)]
2626
[JsonDerivedType(typeof(RescaleCfgCardViewModel), RescaleCfgCardViewModel.ModuleKey)]
2727
[JsonDerivedType(typeof(PlasmaNoiseCardViewModel), PlasmaNoiseCardViewModel.ModuleKey)]
28+
[JsonDerivedType(typeof(NrsCardViewModel), NrsCardViewModel.ModuleKey)]
2829
[JsonDerivedType(typeof(FreeUModule))]
2930
[JsonDerivedType(typeof(HiresFixModule))]
3031
[JsonDerivedType(typeof(FluxHiresFixModule))]
@@ -39,6 +40,7 @@ namespace StabilityMatrix.Avalonia.ViewModels.Base;
3940
[JsonDerivedType(typeof(DiscreteModelSamplingModule))]
4041
[JsonDerivedType(typeof(RescaleCfgModule))]
4142
[JsonDerivedType(typeof(PlasmaNoiseModule))]
43+
[JsonDerivedType(typeof(NRSModule))]
4244
public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
4345
{
4446
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
@@ -47,8 +49,10 @@ public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
4749

4850
private static readonly string[] SerializerIgnoredNames = { nameof(HasErrors) };
4951

50-
private static readonly JsonSerializerOptions SerializerOptions =
51-
new() { IgnoreReadOnlyProperties = true };
52+
private static readonly JsonSerializerOptions SerializerOptions = new()
53+
{
54+
IgnoreReadOnlyProperties = true,
55+
};
5256

5357
private static bool ShouldIgnoreProperty(PropertyInfo property)
5458
{
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
using Injectio.Attributes;
2+
using StabilityMatrix.Avalonia.Models.Inference;
3+
using StabilityMatrix.Avalonia.Services;
4+
using StabilityMatrix.Avalonia.ViewModels.Base;
5+
using StabilityMatrix.Core.Attributes;
6+
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
7+
8+
namespace StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
9+
10+
[ManagedService]
11+
[RegisterTransient<NRSModule>]
12+
public class NRSModule : ModuleBase
13+
{
14+
/// <inheritdoc />
15+
public NRSModule(IServiceManager<ViewModelBase> vmFactory)
16+
: base(vmFactory)
17+
{
18+
Title = "Negative Rejection Steering (NRS)";
19+
AddCards(vmFactory.Get<NrsCardViewModel>());
20+
}
21+
22+
/// <summary>
23+
/// Applies NRS to the Model property
24+
/// </summary>
25+
protected override void OnApplyStep(ModuleApplyStepEventArgs e)
26+
{
27+
var card = GetCard<NrsCardViewModel>();
28+
29+
// Currently applies to all models
30+
// TODO: Add option to apply to either base or refiner
31+
32+
foreach (var modelConnections in e.Builder.Connections.Models.Values.Where(m => m.Model is not null))
33+
{
34+
var nrsOutput = e
35+
.Nodes.AddTypedNode(
36+
new ComfyNodeBuilder.NRS
37+
{
38+
Name = e.Nodes.GetUniqueName($"NRS_{modelConnections.Name}"),
39+
Model = modelConnections.Model!,
40+
Skew = card.Skew,
41+
Stretch = card.Stretch,
42+
Squash = card.Squash,
43+
}
44+
)
45+
.Output;
46+
47+
modelConnections.Model = nrsOutput;
48+
e.Temp.Base.Model = nrsOutput;
49+
}
50+
}
51+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using System.ComponentModel.DataAnnotations;
2+
using CommunityToolkit.Mvvm.ComponentModel;
3+
using Injectio.Attributes;
4+
using StabilityMatrix.Avalonia.Controls;
5+
using StabilityMatrix.Avalonia.ViewModels.Base;
6+
using StabilityMatrix.Core.Attributes;
7+
8+
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
9+
10+
[View(typeof(NrsCard))]
11+
[ManagedService]
12+
[RegisterTransient<NrsCardViewModel>]
13+
public partial class NrsCardViewModel : LoadableViewModelBase
14+
{
15+
public const string ModuleKey = "NRS";
16+
17+
[ObservableProperty]
18+
[NotifyDataErrorInfo]
19+
[Required]
20+
[Range(-30.0d, 30.0d)]
21+
public partial double Skew { get; set; } = 4;
22+
23+
[ObservableProperty]
24+
[NotifyDataErrorInfo]
25+
[Required]
26+
[Range(-30.0d, 30.0d)]
27+
public partial double Stretch { get; set; } = 2;
28+
29+
[ObservableProperty]
30+
[NotifyDataErrorInfo]
31+
[Required]
32+
[Range(0d, 1d)]
33+
public partial double Squash { get; set; } = 0;
34+
}

StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ TabContext tabContext
154154
typeof(DiscreteModelSamplingModule),
155155
typeof(RescaleCfgModule),
156156
typeof(PlasmaNoiseModule),
157+
typeof(NRSModule),
157158
];
158159
});
159160
}

0 commit comments

Comments
 (0)