// File: sq/libsq/driver/registry.go (117 lines, 2.7 KiB).
// Last modified: 2020-08-06 20:58:47 +03:00.

package driver
import (
	"fmt"
	"log/slog"
	"sync"

	"github.com/neilotoole/sq/libsq/core/errz"
	"github.com/neilotoole/sq/libsq/core/lg/lga"
	"github.com/neilotoole/sq/libsq/source"
)
// NewRegistry returns a new Registry instance that provides
// access to driver implementations. Note that Registry
// implements Provider.
func NewRegistry(log *slog.Logger) *Registry {
2020-08-06 20:58:47 +03:00
return &Registry{
log: log,
providers: map[source.DriverType]Provider{},
2020-08-06 20:58:47 +03:00
}
}
// Registry provides access to driver implementations.
type Registry struct {
log *slog.Logger
2020-08-06 20:58:47 +03:00
mu sync.Mutex
providers map[source.DriverType]Provider
types []source.DriverType
2020-08-06 20:58:47 +03:00
}
// AddProvider registers the provider for the specified driver type.
// This method has no effect if there's already a provider for typ.
func (r *Registry) AddProvider(typ source.DriverType, p Provider) {
2020-08-06 20:58:47 +03:00
r.mu.Lock()
defer r.mu.Unlock()
if existingType, ok := r.providers[typ]; ok {
r.log.Warn("failed to add driver provider (%T) for driver type %s: provider (%T) already registered", p, typ,
existingType)
2020-08-06 20:58:47 +03:00
return
}
r.types = append(r.types, typ)
r.providers[typ] = p
}
// ProviderFor returns the provider for typ, or nil if no
// registered provider.
func (r *Registry) ProviderFor(typ source.DriverType) Provider {
2020-08-06 20:58:47 +03:00
r.mu.Lock()
defer r.mu.Unlock()
p := r.providers[typ]
return p
2020-08-06 20:58:47 +03:00
}
// DriverFor implements Provider.
func (r *Registry) DriverFor(typ source.DriverType) (Driver, error) {
2020-08-06 20:58:47 +03:00
r.mu.Lock()
defer r.mu.Unlock()
p, ok := r.providers[typ]
if !ok {
return nil, errz.Errorf("no registered driver for {%s}", typ)
2020-08-06 20:58:47 +03:00
}
return p.DriverFor(typ)
}
// SQLDriverFor is a convenience method for getting a SQLDriver. It
// returns the error from DriverFor when the lookup fails. It returns an
// error when the driver registered for typ does not implement SQLDriver.
func (r *Registry) SQLDriverFor(typ source.DriverType) (SQLDriver, error) {
	drvr, err := r.DriverFor(typ)
	if err != nil {
		return nil, err
	}

	sqlDrvr, ok := drvr.(SQLDriver)
	if !ok {
		// On a failed assertion sqlDrvr is a nil interface, and %T of it
		// prints "<nil>"; name the wanted type explicitly instead.
		return nil, errz.Errorf("driver %T is not of type driver.SQLDriver", drvr)
	}
	return sqlDrvr, nil
}
2020-08-06 20:58:47 +03:00
// DriversMetadata returns metadata for each registered driver type.
func (r *Registry) DriversMetadata() []Metadata {
var md []Metadata
for _, typ := range r.types {
drv, err := r.DriverFor(typ)
if err != nil {
// Should never happen
r.log.Error("Error getting driver", lga.Type, typ, lga.Err, err)
2020-08-06 20:58:47 +03:00
continue
}
md = append(md, drv.DriverMetadata())
}
return md
}
// Drivers returns the registered drivers.
func (r *Registry) Drivers() []Driver {
var drvrs []Driver
2020-08-06 20:58:47 +03:00
for _, typ := range r.types {
drvr, err := r.DriverFor(typ)
if err != nil {
// Should never happen
r.log.Error("Error getting driver", lga.Type, typ, lga.Err, err)
2020-08-06 20:58:47 +03:00
continue
}
drvrs = append(drvrs, drvr)
}
return drvrs
}