5
5
6
6
namespace Luau . SourceGenerator ;
7
7
8
- // TODO: refactoring
9
-
10
- internal class CreateFunctionContext : IEquatable < CreateFunctionContext >
8
+ internal record CreateFunctionContext
11
9
{
12
10
public CreateFunctionMethod ? Method { get ; set ; }
13
- public required DiagnosticReporter DiagnosticReporter { get ; init ; }
14
- public required SemanticModel Model { get ; init ; }
15
-
16
- public bool Equals ( CreateFunctionContext other )
17
- {
18
- return Method == other . Method ;
19
- }
20
-
21
- public override bool Equals ( object obj )
22
- {
23
- return obj is CreateFunctionContext ctx && Equals ( ctx ) ;
24
- }
25
-
26
- public override int GetHashCode ( )
27
- {
28
- return Method == null ? 0 : Method . GetHashCode ( ) ;
29
- }
11
+ public required IgnoreEquality < DiagnosticReporter > DiagnosticReporter { get ; init ; }
12
+ public required IgnoreEquality < SemanticModel > Model { get ; init ; }
30
13
}
31
14
32
- internal class CreateFunctionMethod : IEquatable < CreateFunctionMethod >
15
+ internal record CreateFunctionMethod
33
16
{
34
- public required CreateFunctionMethodParameter [ ] Parameters { get ; init ; }
17
+ public required EquatableArray < CreateFunctionMethodParameter > Parameters { get ; init ; }
35
18
public required string ReturnTypeName { get ; init ; }
36
19
public required bool HasReturnValue { get ; init ; }
37
20
public required bool IsAsync { get ; init ; }
38
21
public required string FilePath { get ; init ; }
39
22
public required int LineNumber { get ; init ; }
40
23
41
- public bool Equals ( CreateFunctionMethod other )
42
- {
43
- return Parameters . SequenceEqual ( other . Parameters ) && ReturnTypeName == other . ReturnTypeName ;
44
- }
45
-
46
- public override bool Equals ( object obj )
47
- {
48
- return obj is CreateFunctionMethod other && Equals ( other ) ;
49
- }
50
-
51
- public override int GetHashCode ( )
52
- {
53
- if ( Parameters . Length == 0 ) return 0 ;
54
- var hashCode = Parameters [ 0 ] . GetHashCode ( ) ;
55
- for ( int i = 1 ; i < Parameters . Length ; i ++ )
56
- {
57
- hashCode ^= Parameters [ i ] . GetHashCode ( ) ;
58
- }
59
- hashCode ^= ReturnTypeName . GetHashCode ( ) ;
60
- return hashCode ;
61
- }
62
-
63
24
public static CreateFunctionMethod Create ( Location location , bool isAsync , ITypeSymbol ? returnType , CreateFunctionMethodParameter [ ] parameters )
64
25
{
65
26
var returnTypeName = returnType == null ? "void" : returnType . ToDisplayString ( SymbolDisplayFormat . FullyQualifiedFormat ) ;
@@ -212,12 +173,7 @@ internal static partial class GeneratedLuauStateExtensions
212
173
. ToArray ( ) ;
213
174
214
175
var returnType = methodSymbol . ReturnsVoid ? null : methodSymbol . ReturnType ;
215
- var isAsync = returnType != null && (
216
- returnType . ToDisplayString ( SymbolDisplayFormat . FullyQualifiedFormat ) == "global::System.Threading.Tasks.Task" ||
217
- returnType . ToDisplayString ( SymbolDisplayFormat . FullyQualifiedFormat ) == "global::System.Threading.Tasks.ValueTask" ||
218
- returnType . ToDisplayString ( SymbolDisplayFormat . FullyQualifiedFormat ) == "global::Cysharp.Threading.Tasks.UniTask" ||
219
- returnType . ToDisplayString ( SymbolDisplayFormat . FullyQualifiedFormat ) == "global::UnityEngine.Awaitable"
220
- ) ;
176
+ var isAsync = returnType . IsTaskType ( ) ;
221
177
222
178
result . Method = CreateFunctionMethod . Create ( actionExpression . GetLocation ( ) , isAsync , returnType , parameters ) ;
223
179
}
@@ -244,9 +200,9 @@ void EmitRegisterFunctionMethod(SourceProductionContext context, ImmutableArray<
244
200
foreach ( var methodContext in methodContexts )
245
201
{
246
202
// check compilation errors
247
- if ( methodContext . DiagnosticReporter . HasDiagnostics )
203
+ if ( methodContext . DiagnosticReporter . Value . HasDiagnostics )
248
204
{
249
- methodContext . DiagnosticReporter . ReportToContext ( context ) ;
205
+ methodContext . DiagnosticReporter . Value . ReportToContext ( context ) ;
250
206
continue ;
251
207
}
252
208
0 commit comments