My source generator needs to find other classes that derive from the inspected class, so that it knows whether special code needs to be added. So far, I have only inspected the class being augmented and followed references from there, such as the types of its properties, but not any other classes that exist in the project that uses the source generator. I can’t find any methods, interfaces, or web documentation for this. Is this even possible, and how would it work?
I’m looking for something like this:
public static IEnumerable<string> GetDerivedTypes(this ITypeSymbol typeSymbol)
{
    // TODO: Find all available classes,
    // then I can proceed with inheritance checks and further tests
    // Following is made-up code:
    var derivedTypeNames = typeSymbol.ContainingAssembly.AllTypes
        .Where(t => t.IsDerivedFrom(typeSymbol))
        .Select(t => t.Name);
    return derivedTypeNames;
}
2
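// Name of the base class as it is written in source (e.g., Base)
var baseClassName = "Base";
// Create the syntax receiver (ClassSyntaxReceiver is defined further down)
var syntaxReceiver = new ClassSyntaxReceiver();
// Feed every node of the syntax tree to the receiver
foreach (var node in syntaxTree.GetRoot().DescendantNodesAndSelf())
    syntaxReceiver.OnVisitSyntaxNode(node);
// Keep the classes whose base list names the base class.
// This is a purely textual match: it misses qualified names such as
// MyNamespace.Base and classes that derive indirectly.
var derivedClasses = syntaxReceiver.Classes
    .Where(classDeclaration =>
        classDeclaration.BaseList?.Types.Any(baseType =>
            baseType.ToString() == baseClassName) ?? false)
    .ToList();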
or
// No semantic model is needed here; the compilation is enough.
// Get the base class symbol; baseClassName must be the fully qualified
// metadata name in this case (e.g., "MyNamespace.Base")
var baseClassSymbol = compilation.GetTypeByMetadataName(baseClassName);
// Find the types declared in source whose direct base type is that symbol
var derivedTypes = compilation.GetSymbolsWithName(_ => true, SymbolFilter.Type)
    .OfType<INamedTypeSymbol>()
    .Where(typeSymbol => typeSymbol.BaseType != null &&
        SymbolEqualityComparer.Default.Equals(typeSymbol.BaseType, baseClassSymbol))
    .ToList();
or
Finding all class declarations than inherit from another with Roslyn
In more detail, collect the class declarations with a syntax receiver:
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System.Collections.Generic;

public class ClassSyntaxReceiver : ISyntaxReceiver
{
    public List<ClassDeclarationSyntax> Classes { get; } = new List<ClassDeclarationSyntax>();

    // Called by Roslyn for every syntax node in the compilation
    public void OnVisitSyntaxNode(SyntaxNode syntaxNode)
    {
        if (syntaxNode is ClassDeclarationSyntax classDeclaration)
        {
            Classes.Add(classDeclaration);
        }
    }
}

[Generator]
public class MySourceGenerator : ISourceGenerator // to use in the source generator
{
    public void Initialize(GeneratorInitializationContext context)
    {
        // Register the receiver so that Roslyn creates and fills it before Execute runs
        context.RegisterForSyntaxNotifications(() => new ClassSyntaxReceiver());
    }

    public void Execute(GeneratorExecutionContext context)
    {
        if (!(context.SyntaxReceiver is ClassSyntaxReceiver syntaxReceiver))
            return;
        var classes = syntaxReceiver.Classes;
        // Now you can process the list of classes as needed.
        // For example, you can check if each class derives from a specific base class.
    }
}
In the Execute method, iterate through the collected classes and check if they derive from a specific base class (e.g., Base):
foreach (var classDeclaration in classes)
{
    var semanticModel = context.Compilation.GetSemanticModel(classDeclaration.SyntaxTree);
    var classSymbol = semanticModel.GetDeclaredSymbol(classDeclaration);
    // Compares the direct base type by simple name only; indirect
    // derivation and same-named types in other namespaces are not distinguished.
    if (classSymbol?.BaseType?.Name == "Base")
    {
        // This class derives from the specified base class.
        // Add your custom logic here.
    }
}
Roslyn, how to get all Classes
Roslyn’s SyntaxReceiver – Get Classes Implementing Interface
4
I found out how this works. The answer from rotabor had a hint hidden in it, and after some more trial and error I was able to build a solution to the described problem around that line.
public static string[] GetDerivedTypes(this ITypeSymbol typeSymbol)
{
    var derivedTypes = GetAllTypes(typeSymbol.ContainingAssembly.GlobalNamespace)
        .Where(t => IsDerivedFrom(t, typeSymbol));
    return derivedTypes
        .Select(dt => dt.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))
        .ToArray();

    static IEnumerable<INamedTypeSymbol> GetAllTypes(INamespaceSymbol nsSymbol) =>
        nsSymbol.GetTypeMembers()
            .Concat(nsSymbol.GetNamespaceMembers().SelectMany(GetAllTypes));

    static bool IsDerivedFrom(ITypeSymbol typeSymbol, ITypeSymbol baseTypeSymbol) =>
        typeSymbol.BaseType != null &&
        (typeSymbol.BaseType.Equals(baseTypeSymbol, SymbolEqualityComparer.Default) ||
            IsDerivedFrom(typeSymbol.BaseType, baseTypeSymbol));
}
The local helper function GetAllTypes recursively finds all types in all namespaces of the assembly. It will not find derived types in other assemblies, but I can live with this limitation. Is it efficient? I can’t tell.
The other local helper function, IsDerivedFrom, then determines whether a type is derived from a base type, possibly transitively. This is also done recursively. It is used to filter the set of all types in the assembly so that only derivations of the specified type remain.
I took some inspiration from functional programming here and I think it fits nicely.
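To try the approach outside a generator, here is a minimal sketch that runs the same idea against an in-memory compilation. It assumes the Microsoft.CodeAnalysis.CSharp package is referenced. The names DerivedTypeDemo, GetAllTypesIncludingNested and FindDerived are made up for this illustration and are not part of Roslyn. Unlike GetAllTypes above, this variant also descends into nested types, because GetTypeMembers on a namespace returns only the types declared directly in that namespace. Constructed generic base types (such as a class deriving from Base<int>) are not matched against the unbound Base<T>.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;

public static class DerivedTypeDemo
{
    // Hypothetical variant of GetAllTypes: yields the types declared in a
    // namespace or type, then recurses into nested types and child namespaces.
    public static IEnumerable<INamedTypeSymbol> GetAllTypesIncludingNested(
        INamespaceOrTypeSymbol symbol)
    {
        foreach (var type in symbol.GetTypeMembers())
        {
            yield return type;
            foreach (var nested in GetAllTypesIncludingNested(type))
                yield return nested;
        }

        if (symbol is INamespaceSymbol ns)
        {
            foreach (var child in ns.GetNamespaceMembers())
                foreach (var type in GetAllTypesIncludingNested(child))
                    yield return type;
        }
    }

    // Walks the base type chain iteratively instead of recursively.
    private static bool IsDerivedFrom(ITypeSymbol type, ITypeSymbol baseType)
    {
        for (var current = type.BaseType; current != null; current = current.BaseType)
        {
            if (SymbolEqualityComparer.Default.Equals(current, baseType))
                return true;
        }
        return false;
    }

    // Parses the given source, builds a compilation, and returns the fully
    // qualified names of all types deriving from baseMetadataName.
    public static string[] FindDerived(string source, string baseMetadataName)
    {
        var tree = CSharpSyntaxTree.ParseText(source);
        var compilation = CSharpCompilation.Create(
            "Demo",
            new[] { tree },
            new[] { MetadataReference.CreateFromFile(typeof(object).Assembly.Location) },
            new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));

        var baseSymbol = compilation.GetTypeByMetadataName(baseMetadataName);
        if (baseSymbol == null)
            return Array.Empty<string>();

        return GetAllTypesIncludingNested(compilation.Assembly.GlobalNamespace)
            .Where(t => IsDerivedFrom(t, baseSymbol))
            .Select(t => t.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))
            .OrderBy(name => name, StringComparer.Ordinal)
            .ToArray();
    }
}
```

Inside a generator, context.Compilation would take the place of the compilation built here. Calling FindDerived with a source text that declares Shop.Item, a class Shop.Book deriving from it, and a class nested inside Book that derives from Book should report both Book and the nested class.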