mirror of
https://github.com/MichaelMure/git-bug.git
synced 2024-12-15 10:12:06 +03:00
313 lines
7.1 KiB
Go
313 lines
7.1 KiB
Go
package codegen
|
|
|
|
import (
|
|
"fmt"
|
|
"go/types"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"github.com/pkg/errors"
|
|
"golang.org/x/tools/go/loader"
|
|
)
|
|
|
|
func findGoType(prog *loader.Program, pkgName string, typeName string) (types.Object, error) {
|
|
if pkgName == "" {
|
|
return nil, nil
|
|
}
|
|
fullName := typeName
|
|
if pkgName != "" {
|
|
fullName = pkgName + "." + typeName
|
|
}
|
|
|
|
pkgName, err := resolvePkg(pkgName)
|
|
if err != nil {
|
|
return nil, errors.Errorf("unable to resolve package for %s: %s\n", fullName, err.Error())
|
|
}
|
|
|
|
pkg := prog.Imported[pkgName]
|
|
if pkg == nil {
|
|
return nil, errors.Errorf("required package was not loaded: %s", fullName)
|
|
}
|
|
|
|
for astNode, def := range pkg.Defs {
|
|
if astNode.Name != typeName || def.Parent() == nil || def.Parent() != pkg.Pkg.Scope() {
|
|
continue
|
|
}
|
|
|
|
return def, nil
|
|
}
|
|
|
|
return nil, errors.Errorf("unable to find type %s\n", fullName)
|
|
}
|
|
|
|
func findGoNamedType(prog *loader.Program, pkgName string, typeName string) (*types.Named, error) {
|
|
def, err := findGoType(prog, pkgName, typeName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if def == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
namedType, ok := def.Type().(*types.Named)
|
|
if !ok {
|
|
return nil, errors.Errorf("expected %s to be a named type, instead found %T\n", typeName, def.Type())
|
|
}
|
|
|
|
return namedType, nil
|
|
}
|
|
|
|
func findGoInterface(prog *loader.Program, pkgName string, typeName string) (*types.Interface, error) {
|
|
namedType, err := findGoNamedType(prog, pkgName, typeName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if namedType == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
underlying, ok := namedType.Underlying().(*types.Interface)
|
|
if !ok {
|
|
return nil, errors.Errorf("expected %s to be a named interface, instead found %s", typeName, namedType.String())
|
|
}
|
|
|
|
return underlying, nil
|
|
}
|
|
|
|
// findMethod returns the exported method of typ whose name equals name when
// compared case-insensitively, or nil when there is none.
//
// Methods declared on typ itself are checked first. If typ is defined as a
// struct, the anonymous fields whose type is itself a named type are then
// searched recursively, in declaration order.
func findMethod(typ *types.Named, name string) *types.Func {
	for idx, count := 0, typ.NumMethods(); idx < count; idx++ {
		candidate := typ.Method(idx)
		if candidate.Exported() && strings.EqualFold(candidate.Name(), name) {
			return candidate
		}
	}

	structType, isStruct := typ.Underlying().(*types.Struct)
	if !isStruct {
		return nil
	}

	for idx, count := 0, structType.NumFields(); idx < count; idx++ {
		embedded := structType.Field(idx)
		if !embedded.Anonymous() {
			continue
		}

		embeddedNamed, isNamed := embedded.Type().(*types.Named)
		if !isNamed {
			continue
		}

		if promoted := findMethod(embeddedNamed, name); promoted != nil {
			return promoted
		}
	}

	return nil
}
|
|
|
|
// findField returns the exported field of typ whose name equals name when
// compared case-insensitively, or nil when there is none.
//
// Fields are visited in declaration order. An anonymous field whose underlying
// type is a struct is searched recursively before the anonymous field's own
// name is considered. Embedded fields with any other underlying type (for
// example pointers) are not descended into.
//
// When a match is found through an embedded struct but typ also declares a
// field with exactly the same name, the directly declared field is returned
// instead: a Go selector such as obj.Name resolves to the shallowest field, so
// the promoted one is unreachable under that name and must not be the one
// whose type gets validated.
func findField(typ *types.Struct, name string) *types.Var {
	for i := 0; i < typ.NumFields(); i++ {
		field := typ.Field(i)

		if field.Anonymous() {
			// Underlying() of an unnamed struct type is the struct itself,
			// so this single assertion covers both named and unnamed
			// embedded struct types.
			if embedded, ok := field.Type().Underlying().(*types.Struct); ok {
				if promoted := findField(embedded, name); promoted != nil {
					for j := 0; j < typ.NumFields(); j++ {
						if direct := typ.Field(j); direct.Name() == promoted.Name() {
							// Shadowed by a field declared on typ.
							return direct
						}
					}
					return promoted
				}
			}
		}

		if !field.Exported() {
			continue
		}

		if strings.EqualFold(field.Name(), name) {
			return field
		}
	}
	return nil
}
|
|
|
|
type BindError struct {
|
|
object *Object
|
|
field *Field
|
|
typ types.Type
|
|
methodErr error
|
|
varErr error
|
|
}
|
|
|
|
func (b BindError) Error() string {
|
|
return fmt.Sprintf(
|
|
"Unable to bind %s.%s to %s\n %s\n %s",
|
|
b.object.GQLType,
|
|
b.field.GQLName,
|
|
b.typ.String(),
|
|
b.methodErr.Error(),
|
|
b.varErr.Error(),
|
|
)
|
|
}
|
|
|
|
type BindErrors []BindError
|
|
|
|
func (b BindErrors) Error() string {
|
|
var errs []string
|
|
for _, err := range b {
|
|
errs = append(errs, err.Error())
|
|
}
|
|
return strings.Join(errs, "\n\n")
|
|
}
|
|
|
|
func bindObject(t types.Type, object *Object, imports *Imports) BindErrors {
|
|
var errs BindErrors
|
|
for i := range object.Fields {
|
|
field := &object.Fields[i]
|
|
|
|
// first try binding to a method
|
|
methodErr := bindMethod(imports, t, field)
|
|
if methodErr == nil {
|
|
continue
|
|
}
|
|
|
|
// otherwise try binding to a var
|
|
varErr := bindVar(imports, t, field)
|
|
|
|
if varErr != nil {
|
|
errs = append(errs, BindError{
|
|
object: object,
|
|
typ: t,
|
|
field: field,
|
|
varErr: varErr,
|
|
methodErr: methodErr,
|
|
})
|
|
}
|
|
}
|
|
return errs
|
|
}
|
|
|
|
func bindMethod(imports *Imports, t types.Type, field *Field) error {
|
|
namedType, ok := t.(*types.Named)
|
|
if !ok {
|
|
return fmt.Errorf("not a named type")
|
|
}
|
|
|
|
method := findMethod(namedType, field.GQLName)
|
|
if method == nil {
|
|
return fmt.Errorf("no method named %s", field.GQLName)
|
|
}
|
|
sig := method.Type().(*types.Signature)
|
|
|
|
if sig.Results().Len() == 1 {
|
|
field.NoErr = true
|
|
} else if sig.Results().Len() != 2 {
|
|
return fmt.Errorf("method has wrong number of args")
|
|
}
|
|
newArgs, err := matchArgs(field, sig.Params())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
result := sig.Results().At(0)
|
|
if err := validateTypeBinding(imports, field, result.Type()); err != nil {
|
|
return errors.Wrap(err, "method has wrong return type")
|
|
}
|
|
|
|
// success, args and return type match. Bind to method
|
|
field.GoMethodName = "obj." + method.Name()
|
|
field.Args = newArgs
|
|
return nil
|
|
}
|
|
|
|
func bindVar(imports *Imports, t types.Type, field *Field) error {
|
|
underlying, ok := t.Underlying().(*types.Struct)
|
|
if !ok {
|
|
return fmt.Errorf("not a struct")
|
|
}
|
|
|
|
structField := findField(underlying, field.GQLName)
|
|
if structField == nil {
|
|
return fmt.Errorf("no field named %s", field.GQLName)
|
|
}
|
|
|
|
if err := validateTypeBinding(imports, field, structField.Type()); err != nil {
|
|
return errors.Wrap(err, "field has wrong type")
|
|
}
|
|
|
|
// success, bind to var
|
|
field.GoVarName = structField.Name()
|
|
return nil
|
|
}
|
|
|
|
func matchArgs(field *Field, params *types.Tuple) ([]FieldArgument, error) {
|
|
var newArgs []FieldArgument
|
|
|
|
nextArg:
|
|
for j := 0; j < params.Len(); j++ {
|
|
param := params.At(j)
|
|
for _, oldArg := range field.Args {
|
|
if strings.EqualFold(oldArg.GQLName, param.Name()) {
|
|
oldArg.Type.Modifiers = modifiersFromGoType(param.Type())
|
|
newArgs = append(newArgs, oldArg)
|
|
continue nextArg
|
|
}
|
|
}
|
|
|
|
// no matching arg found, abort
|
|
return nil, fmt.Errorf("arg %s not found on method", param.Name())
|
|
}
|
|
return newArgs, nil
|
|
}
|
|
|
|
func validateTypeBinding(imports *Imports, field *Field, goType types.Type) error {
|
|
gqlType := normalizeVendor(field.Type.FullSignature())
|
|
goTypeStr := normalizeVendor(goType.String())
|
|
|
|
if goTypeStr == gqlType || "*"+goTypeStr == gqlType || goTypeStr == "*"+gqlType {
|
|
field.Type.Modifiers = modifiersFromGoType(goType)
|
|
return nil
|
|
}
|
|
|
|
// deal with type aliases
|
|
underlyingStr := normalizeVendor(goType.Underlying().String())
|
|
if underlyingStr == gqlType || "*"+underlyingStr == gqlType || underlyingStr == "*"+gqlType {
|
|
field.Type.Modifiers = modifiersFromGoType(goType)
|
|
pkg, typ := pkgAndType(goType.String())
|
|
imp := imports.findByPath(pkg)
|
|
field.CastType = &Ref{GoType: typ, Import: imp}
|
|
return nil
|
|
}
|
|
|
|
return fmt.Errorf("%s is not compatible with %s", gqlType, goTypeStr)
|
|
}
|
|
|
|
func modifiersFromGoType(t types.Type) []string {
|
|
var modifiers []string
|
|
for {
|
|
switch val := t.(type) {
|
|
case *types.Pointer:
|
|
modifiers = append(modifiers, modPtr)
|
|
t = val.Elem()
|
|
case *types.Array:
|
|
modifiers = append(modifiers, modList)
|
|
t = val.Elem()
|
|
case *types.Slice:
|
|
modifiers = append(modifiers, modList)
|
|
t = val.Elem()
|
|
default:
|
|
return modifiers
|
|
}
|
|
}
|
|
}
|
|
|
|
// modsRegex matches the run of "*" and "[]" modifiers at the start of a type
// string. It also matches the empty prefix, so it never fails to match.
var modsRegex = regexp.MustCompile(`^(\*|\[\])*`)

// normalizeVendor strips any vendor directory prefix from a type string.
//
// Leading "*" and "[]" modifiers are kept as they are; of the remainder, only
// the text after the last "/vendor/" separator is retained. Separators are
// found scanning left to right without overlap.
func normalizeVendor(pkg string) string {
	const vendorSep = "/vendor/"

	// The pattern is anchored at the start, so the match is always a prefix.
	prefix := modsRegex.FindString(pkg)
	path := pkg[len(prefix):]

	for {
		at := strings.Index(path, vendorSep)
		if at < 0 {
			break
		}
		path = path[at+len(vendorSep):]
	}

	return prefix + path
}
|