11using System . Data . Common ;
2+ using System . Linq . Expressions ;
3+ using System . Text ;
4+
5+ using EntityFrameworkCore . ExecuteInsert . OnConflict ;
26
37using Microsoft . EntityFrameworkCore ;
8+ using Microsoft . EntityFrameworkCore . Metadata ;
49
510using Npgsql ;
611
@@ -15,7 +20,7 @@ public class PostgreSqlBulkInsertProvider : BulkInsertProviderBase
1520 protected override string CreateTableCopySql => "CREATE TEMPORARY TABLE {0} AS TABLE {1} WITH NO DATA;" ;
1621
1722 //language=sql
18- protected override string AddTableCopyBulkInsertId => "ALTER TABLE {0} ADD COLUMN _bulk_insert_id SERIAL PRIMARY KEY;" ;
23+ protected override string AddTableCopyBulkInsertId => $ "ALTER TABLE {{0}} ADD COLUMN { BulkInsertId } SERIAL PRIMARY KEY;";
1924
2025 private string GetBinaryImportCommand ( DbContext context , Type entityType , string tableName )
2126 {
@@ -24,6 +29,204 @@ private string GetBinaryImportCommand(DbContext context, Type entityType, string
2429 return $ "COPY { tableName } ({ string . Join ( ", " , columns ) } ) FROM STDIN (FORMAT BINARY)";
2530 }
2631
32+ protected override string BuildInsertSelectQuery < T > ( string tableName ,
33+ string targetTableName ,
34+ IProperty [ ] insertedProperties ,
35+ IProperty [ ] properties ,
36+ BulkInsertOptions options , OnConflictOptions ? onConflict = null )
37+ {
38+ var insertedColumns = insertedProperties . Select ( p => Escape ( p . GetColumnName ( ) ) ) ;
39+ var insertedColumnList = string . Join ( ", " , insertedColumns ) ;
40+
41+ var returnedColumns = properties . Select ( p => $ "{ Escape ( p . GetColumnName ( ) ) } AS { Escape ( p . Name ) } ") ;
42+ var columnList = string . Join ( ", " , returnedColumns ) ;
43+
44+ var q = new StringBuilder ( ) ;
45+
46+ if ( options . MoveRows )
47+ {
48+ q . AppendLine ( $ """
49+ WITH moved_rows AS (
50+ DELETE FROM { tableName }
51+ RETURNING { insertedColumnList }
52+ )
53+ """ ) ;
54+ tableName = "moved_rows" ;
55+ }
56+
57+ q . AppendLine ( $ """
58+ INSERT INTO { targetTableName } ({ insertedColumnList } )
59+ SELECT { insertedColumnList }
60+ FROM { tableName }
61+ """ ) ;
62+
63+ if ( onConflict is OnConflictOptions < T > onConflictTyped )
64+ {
65+ q . AppendLine ( "ON CONFLICT" ) ;
66+
67+ if ( onConflictTyped . Update != null )
68+ {
69+ if ( onConflictTyped . Match != null )
70+ {
71+ q . AppendLine ( $ "({ string . Join ( ", " , GetColumns ( onConflictTyped . Match ) . Select ( Escape ) ) } )") ;
72+ }
73+
74+ if ( onConflictTyped . Update != null )
75+ {
76+ q . AppendLine ( $ "DO UPDATE SET { string . Join ( ", " , GetUpdates ( onConflictTyped . Update ) ) } ") ;
77+ }
78+
79+ if ( onConflictTyped . Condition != null )
80+ {
81+ q . AppendLine ( $ "WHERE { onConflictTyped . Condition } ") ;
82+ }
83+ }
84+ else
85+ {
86+ q . AppendLine ( "DO NOTHING" ) ;
87+ }
88+ }
89+
90+ if ( columnList . Length != 0 )
91+ {
92+ q . AppendLine ( $ "RETURNING { columnList } ") ;
93+ }
94+
95+ q . AppendLine ( ";" ) ;
96+
97+ return q . ToString ( ) ;
98+ }
99+
100+ private IEnumerable < string > GetUpdates < T > ( Expression < Func < T , object > > update )
101+ {
102+ if ( update . Body is NewExpression { Members : not null } newExpr )
103+ {
104+ foreach ( var arg in newExpr . Arguments . Zip ( newExpr . Members , ( expr , member ) => ( expr , member ) ) )
105+ {
106+ yield return $ "{ Escape ( arg . member . Name ) } = { ToSqlExpression ( arg . expr ) } ";
107+ }
108+ }
109+ else if ( update . Body is MemberInitExpression memberInit )
110+ {
111+ foreach ( var binding in memberInit . Bindings . OfType < MemberAssignment > ( ) )
112+ {
113+ yield return $ "{ Escape ( binding . Member . Name ) } = { ToSqlExpression ( binding . Expression ) } ";
114+ }
115+ }
116+ else if ( update . Body is MemberExpression memberExpr )
117+ {
118+ yield return $ "{ Escape ( memberExpr . Member . Name ) } = { ToSqlExpression ( memberExpr ) } ";
119+ }
120+ else
121+ {
122+ throw new NotSupportedException ( "Unsupported expression type for update" ) ;
123+ }
124+ }
125+
126+ private string ToSqlExpression ( Expression expr )
127+ {
128+ switch ( expr )
129+ {
130+ case MemberExpression m :
131+ var prefix = "EXCLUDED" ;
132+ return $ "{ prefix } .{ Escape ( m . Member . Name ) } ";
133+
134+ case BinaryExpression b :
135+ var left = ToSqlExpression ( b . Left ) ;
136+ var right = ToSqlExpression ( b . Right ) ;
137+ var op = b . NodeType switch
138+ {
139+ ExpressionType . Add => b . Type == typeof ( string ) ? "||" : "+" ,
140+ ExpressionType . Subtract => "-" ,
141+ ExpressionType . Multiply => "*" ,
142+ ExpressionType . Divide => "/" ,
143+ ExpressionType . Modulo => "%" ,
144+ ExpressionType . AndAlso => "AND" ,
145+ ExpressionType . OrElse => "OR" ,
146+ ExpressionType . Equal => "=" ,
147+ ExpressionType . NotEqual => "<>" ,
148+ ExpressionType . LessThan => "<" ,
149+ ExpressionType . LessThanOrEqual => "<=" ,
150+ ExpressionType . GreaterThan => ">" ,
151+ ExpressionType . GreaterThanOrEqual => ">=" ,
152+ _ => throw new NotSupportedException ( $ "Opérateur non supporté: { b . NodeType } ")
153+ } ;
154+ return $ "({ left } { op } { right } )";
155+
156+ case ConstantExpression c :
157+ if ( c . Type == typeof ( RawSqlValue ) && c . Value != null )
158+ {
159+ return ( ( RawSqlValue ) c . Value ! ) . Sql ;
160+ }
161+
162+ if ( c . Type == typeof ( string ) ||
163+ c . Type == typeof ( Guid ) )
164+ {
165+ return $ "'{ c . Value } '";
166+ }
167+
168+ if ( c . Type == typeof ( bool ) )
169+ {
170+ return ( bool ) c . Value ! ? "TRUE" : "FALSE" ;
171+ }
172+
173+ return c . Value ? . ToString ( ) ?? "NULL" ;
174+
175+ case UnaryExpression u :
176+ if ( u . NodeType == ExpressionType . Convert )
177+ {
178+ return ToSqlExpression ( u . Operand ) ;
179+ }
180+ if ( u . NodeType == ExpressionType . Not )
181+ {
182+ return $ "NOT ({ ToSqlExpression ( u . Operand ) } )";
183+ }
184+ throw new NotSupportedException ( $ "Unary operator not supported: { u . NodeType } ") ;
185+
186+ case MethodCallExpression mce :
187+ // Supporte quelques méthodes courantes (ToLower, ToUpper, Trim, etc.)
188+ var objSql = mce . Object != null ? ToSqlExpression ( mce . Object ) : null ;
189+ var argsSql = mce . Arguments . Select ( ToSqlExpression ) . ToArray ( ) ;
190+ switch ( mce . Method . Name )
191+ {
192+ case "ToLower" :
193+ return $ "LOWER({ objSql } )";
194+ case "ToUpper" :
195+ return $ "UPPER({ objSql } )";
196+ case "Trim" :
197+ return $ "BTRIM({ objSql } )";
198+ case "Contains" when mce is { Object : not null , Arguments . Count : 1 } :
199+ return $ "{ objSql } LIKE '%' || { argsSql [ 0 ] } || '%'";
200+ case "StartsWith" when mce is { Object : not null , Arguments . Count : 1 } :
201+ return $ "{ objSql } LIKE { argsSql [ 0 ] } || '%'";
202+ case "EndsWith" when mce is { Object : not null , Arguments . Count : 1 } :
203+ return $ "{ objSql } LIKE '%' || { argsSql [ 0 ] } ";
204+ default :
205+ throw new NotSupportedException ( $ "Method not supported: { mce . Method . Name } ") ;
206+ }
207+
208+ case ParameterExpression p :
209+ return Escape ( p . Name ?? "param" ) ;
210+
211+ default :
212+ throw new NotSupportedException ( $ "Expression not supported: { expr . NodeType } ") ;
213+ }
214+ }
215+
216+ private string [ ] GetColumns < T > ( Expression < Func < T , object > > columns )
217+ {
218+ return columns . Body switch
219+ {
220+ NewExpression newExpression => newExpression . Arguments . OfType < MemberExpression > ( )
221+ . Select ( m => m . Member . Name )
222+ . ToArray ( ) ,
223+ MemberExpression memberExpression => [
224+ memberExpression . Member . Name
225+ ] ,
226+ _ => throw new NotSupportedException ( "Unsupported expression type" )
227+ } ;
228+ }
229+
27230 protected override async Task BulkImport < T > ( DbContext context , DbConnection connection , IEnumerable < T > entities ,
28231 string tableName , PropertyAccessor [ ] properties , CancellationToken ctk ) where T : class
29232 {
0 commit comments