@@ -30,6 +30,56 @@ public async Task ExecuteAsync(IProgress<ProgressReport>? progress = null)
3030 ) ;
3131 }
3232
33+ var venvDir = WorkingDirectory . JoinDir ( "venv" ) ;
34+
35+ await using var venvRunner = PyBaseInstall . Default . CreateVenvRunner (
36+ venvDir ,
37+ workingDirectory : WorkingDirectory ,
38+ environmentVariables : EnvironmentVariables
39+ ) ;
40+
41+ var torchInfo = await venvRunner . PipShow ( "torch" ) . ConfigureAwait ( false ) ;
42+ var sageWheelUrl = string . Empty ;
43+
44+ if ( torchInfo == null )
45+ {
46+ sageWheelUrl = string . Empty ;
47+ }
48+ else if ( torchInfo . Version . Contains ( "2.5.1" ) && torchInfo . Version . Contains ( "cu124" ) )
49+ {
50+ sageWheelUrl =
51+ "https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu124torch2.5.1-cp310-cp310-win_amd64.whl" ;
52+ }
53+ else if ( torchInfo . Version . Contains ( "2.6.0" ) && torchInfo . Version . Contains ( "cu126" ) )
54+ {
55+ sageWheelUrl =
56+ "https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu126torch2.6.0-cp310-cp310-win_amd64.whl" ;
57+ }
58+ else if ( torchInfo . Version . Contains ( "2.7.0" ) && torchInfo . Version . Contains ( "cu128" ) )
59+ {
60+ sageWheelUrl =
61+ "https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu128torch2.7.0-cp310-cp310-win_amd64.whl" ;
62+ }
63+
64+ var pipArgs = new PipInstallArgs ( ) ;
65+ if ( IsBlackwellGpu )
66+ {
67+ pipArgs = pipArgs . AddArg ( "--pre" ) ;
68+ }
69+ pipArgs = pipArgs . AddArg ( "triton-windows" ) ;
70+
71+ if ( ! string . IsNullOrWhiteSpace ( sageWheelUrl ) )
72+ {
73+ pipArgs = pipArgs . AddArg ( sageWheelUrl ) ;
74+
75+ progress ? . Report (
76+ new ProgressReport ( - 1f , message : "Installing Triton & SageAttention" , isIndeterminate : true )
77+ ) ;
78+ await venvRunner . PipInstall ( pipArgs , progress . AsProcessOutputHandler ( ) ) . ConfigureAwait ( false ) ;
79+ return ;
80+ }
81+
82+ // no wheels, gotta build
3383 if ( ! prerequisiteHelper . IsVcBuildToolsInstalled )
3484 {
3585 throw new MissingPrerequisiteException (
@@ -63,14 +113,6 @@ public async Task ExecuteAsync(IProgress<ProgressReport>? progress = null)
63113 : cuda126ExpectedPath . JoinFile ( "nvcc.exe" ) . ToString ( ) ;
64114 }
65115
66- var venvDir = WorkingDirectory . JoinDir ( "venv" ) ;
67-
68- await using var venvRunner = PyBaseInstall . Default . CreateVenvRunner (
69- venvDir ,
70- workingDirectory : WorkingDirectory ,
71- environmentVariables : EnvironmentVariables
72- ) ;
73-
74116 venvRunner . UpdateEnvironmentVariables ( env =>
75117 {
76118 var cudaBinPath = Path . GetDirectoryName ( nvccPath ) ! ;
@@ -88,47 +130,6 @@ public async Task ExecuteAsync(IProgress<ProgressReport>? progress = null)
88130 return env ;
89131 } ) ;
90132
91- var torchInfo = await venvRunner . PipShow ( "torch" ) . ConfigureAwait ( false ) ;
92- var sageWheelUrl = string . Empty ;
93-
94- if ( torchInfo == null )
95- {
96- sageWheelUrl = string . Empty ;
97- }
98- else if ( torchInfo . Version . Contains ( "2.5.1" ) && torchInfo . Version . Contains ( "cu124" ) )
99- {
100- sageWheelUrl =
101- "https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu124torch2.5.1-cp310-cp310-win_amd64.whl" ;
102- }
103- else if ( torchInfo . Version . Contains ( "2.6.0" ) && torchInfo . Version . Contains ( "cu126" ) )
104- {
105- sageWheelUrl =
106- "https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu126torch2.6.0-cp310-cp310-win_amd64.whl" ;
107- }
108- else if ( torchInfo . Version . Contains ( "2.7.0" ) && torchInfo . Version . Contains ( "cu128" ) )
109- {
110- sageWheelUrl =
111- "https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu128torch2.7.0-cp310-cp310-win_amd64.whl" ;
112- }
113-
114- var pipArgs = new PipInstallArgs ( ) ;
115- if ( IsBlackwellGpu )
116- {
117- pipArgs = pipArgs . AddArg ( "--pre" ) ;
118- }
119- pipArgs = pipArgs . AddArg ( "triton-windows" ) ;
120-
121- if ( ! string . IsNullOrWhiteSpace ( sageWheelUrl ) )
122- {
123- pipArgs = pipArgs . AddArg ( sageWheelUrl ) ;
124-
125- progress ? . Report (
126- new ProgressReport ( - 1f , message : "Installing Triton & SageAttention" , isIndeterminate : true )
127- ) ;
128- await venvRunner . PipInstall ( pipArgs , progress . AsProcessOutputHandler ( ) ) . ConfigureAwait ( false ) ;
129- return ;
130- }
131-
132133 progress ? . Report ( new ProgressReport ( - 1f , message : "Installing Triton" , isIndeterminate : true ) ) ;
133134
134135 await venvRunner . PipInstall ( pipArgs , progress . AsProcessOutputHandler ( ) ) . ConfigureAwait ( false ) ;
@@ -138,19 +139,7 @@ public async Task ExecuteAsync(IProgress<ProgressReport>? progress = null)
138139 progress ? . Report (
139140 new ProgressReport ( - 1f , message : "Downloading Python libraries" , isIndeterminate : true )
140141 ) ;
141- var downloadPath = WorkingDirectory . JoinFile ( "python_libs_for_sage.zip" ) ;
142- await downloadService
143- . DownloadToFileAsync ( PythonLibsDownloadUrl , downloadPath , progress )
144- . ConfigureAwait ( false ) ;
145-
146- progress ? . Report (
147- new ProgressReport ( - 1f , message : "Extracting Python libraries" , isIndeterminate : true )
148- ) ;
149- await ArchiveHelper . Extract7Z ( downloadPath , venvDir , progress ) . ConfigureAwait ( false ) ;
150-
151- var includeFolder = venvDir . JoinDir ( "include" ) ;
152- var scriptsIncludeFolder = venvDir . JoinDir ( "Scripts" ) . JoinDir ( "include" ) ;
153- await includeFolder . CopyToAsync ( scriptsIncludeFolder ) . ConfigureAwait ( false ) ;
142+ await AddMissingLibsToVenv ( WorkingDirectory , progress ) . ConfigureAwait ( false ) ;
154143
155144 var sageDir = WorkingDirectory . JoinDir ( "SageAttention" ) ;
156145
@@ -176,5 +165,40 @@ await venvRunner
176165 . ConfigureAwait ( false ) ;
177166 }
178167
168+ private async Task AddMissingLibsToVenv (
169+ DirectoryPath installedPackagePath ,
170+ IProgress < ProgressReport > ? progress = null
171+ )
172+ {
173+ var venvLibsDir = installedPackagePath . JoinDir ( "venv" , "libs" ) ;
174+ var venvIncludeDir = installedPackagePath . JoinDir ( "venv" , "include" ) ;
175+ if (
176+ venvLibsDir . Exists
177+ && venvIncludeDir . Exists
178+ && venvLibsDir . JoinFile ( "python3.lib" ) . Exists
179+ && venvLibsDir . JoinFile ( "python310.lib" ) . Exists
180+ )
181+ {
182+ return ;
183+ }
184+
185+ var downloadPath = installedPackagePath . JoinFile ( "python_libs_for_sage.zip" ) ;
186+ var venvDir = installedPackagePath . JoinDir ( "venv" ) ;
187+ await downloadService
188+ . DownloadToFileAsync ( PythonLibsDownloadUrl , downloadPath , progress )
189+ . ConfigureAwait ( false ) ;
190+
191+ progress ? . Report (
192+ new ProgressReport ( - 1f , message : "Extracting Python libraries" , isIndeterminate : true )
193+ ) ;
194+ await ArchiveHelper . Extract7Z ( downloadPath , venvDir , progress ) . ConfigureAwait ( false ) ;
195+
196+ var includeFolder = venvDir . JoinDir ( "include" ) ;
197+ var scriptsIncludeFolder = venvDir . JoinDir ( "Scripts" ) . JoinDir ( "include" ) ;
198+ await includeFolder . CopyToAsync ( scriptsIncludeFolder ) . ConfigureAwait ( false ) ;
199+
200+ await downloadPath . DeleteAsync ( ) . ConfigureAwait ( false ) ;
201+ }
202+
179203 public string ProgressTitle => "Installing Triton and SageAttention" ;
180204}
0 commit comments