Skip to content
Snippets Groups Projects
program.go 10.49 KiB
package cmd

import (
	"errors"
	"fmt"
	"net/http"
	"os"
	"strings"

	"github.com/golang-migrate/migrate/v4/source"
	"github.com/jessevdk/go-flags"
	"github.com/jmoiron/sqlx"
	"github.com/justinas/alice"
	"github.com/rs/zerolog"

	"orus.io/orus-io/go-orusapi"
	"orus.io/orus-io/go-orusapi/auth"
	"orus.io/orus-io/go-orusapi/database"
)

// ErrInvalidConfiguration is the sentinel error reporting that the program
// configuration is invalid.
var (
	ErrInvalidConfiguration = errors.New("invalid configuration")
)

// ConfigFile holds the path of an optional configuration file, given with
// -c/--config or the CONFIG environment variable (prefixed by the parser's
// environment namespace, if any). When it is set, ParseArgs loads the file as
// an ini file into the main parser. The no-ini tag keeps this option itself
// out of the ini file.
type ConfigFile struct {
	ConfigFile string `long:"config" short:"c" env:"CONFIG" no-ini:"t" description:"A configuration file"`
}

// InfoOptions holds general information about the running instance: the base
// URL used to build links, and an environment name (per its option
// description, used in sentry and prometheus).
//
//nolint:lll
type InfoOptions struct {
	BaseURL     string `long:"baseurl" env:"BASEURL" ini-name:"baseurl" default:"" description:"The base url for building links"`
	Environment string `long:"environment" env:"ENVIRONMENT" ini-name:"environment" default:"default" description:"A environment name, used in sentry and prometheus"`
}

// XbusActorFactory associates an Xbus actor name with its factory. Factory is
// untyped (any); its expected concrete signature is defined by the code that
// consumes these values (presumably SetupXbusCmd, not visible here — confirm).
type XbusActorFactory struct {
	Name    string
	Factory any
}

// subcommand is an extra command queued by WithSubcommand and registered on the
// main parser by NewProgramWithExt. name and the two descriptions are the
// command's name and help texts. init is called with the program to build the
// command data value handed to flags.Parser.AddCommand.
type subcommand[E any] struct {
	name             string
	shortDescription string
	longDescription  string
	init             func(*Program[E]) any
}

// Program is a configurable command-line program built on go-flags. It holds
// the two flag parsers, the option groups they populate, the commands, and the
// state shared with them (logger, database connection). E is the type of an
// application-specific extension value stored in Ext.
//
// Create one with NewProgram or NewProgramWithExt and run it with ParseArgs.
type Program[E any] struct {
	// Identity and parsers. Name names both parsers. BootstrapParser is the
	// first-pass parser (it ignores unknown options) reading only the
	// configuration file path and the logging options. Parser is the main
	// parser holding every option group and command.
	Name            string
	Version         Version
	BootstrapParser *flags.Parser
	Parser          *flags.Parser

	// Logger is the current logger. It starts as the default stdout logger and
	// is rebuilt from LoggingOptions after the bootstrap parse, after loading
	// the configuration file, and before a command runs.
	Logger zerolog.Logger

	// Option groups filled by the parsers. DatabaseOptions is only registered
	// when WithDatabase is used, and XbusOptions only when WithXbusActors is
	// used. TokenOptions and RednerOptions are not registered in this file —
	// presumably elsewhere in the package.
	ConfigFileOption ConfigFile
	InfoOptions      InfoOptions
	LoggingOptions   *orusapi.LoggingOptions
	TokenOptions     *auth.TokenOptions
	DatabaseOptions  database.Options
	XbusOptions      XbusOptions
	RednerOptions    RednerOptions

	// Ext is the application-specific extension value given to
	// NewProgramWithExt.
	Ext E

	// DB is the database connection; EnsureDB sets it once the database is ready.
	DB *sqlx.DB

	// Built-in commands. MigrateCmd stays nil unless WithDatabase is used.
	ServeCmd          *ServeCmd[E]
	MigrateCmd        *MigrateCmd[E]
	GenerateConfigCmd *GenerateConfigCmd

	// Database setup. hasDB is set by WithDatabase. dbSchema holds the main
	// schema (named after the program) plus the extra schemas added with
	// WithSchema.
	hasDB    bool
	dbSchema database.SchemaSet

	// HTTP middlewares. middlewares is appended to by WithMiddleware;
	// authMiddlewares is not populated in this file.
	authMiddlewares []Middleware[E]
	middlewares     []Middleware[E]

	// Setup hooks collected from options:
	//   - Xbus actor factories, and whether the Xbus group/command were
	//     registered (xbusActorNames is not used in this file);
	//   - the HTTP handler factory;
	//   - extra subcommands.
	hasXbus          bool
	setupXbusActors  []func(*Program[E]) []XbusActorFactory
	xbusActorNames   []string
	setupHandler     func(*Program[E]) http.Handler
	setupSubcommands []subcommand[E]

	// postInit functions queued by PostInit. They run in order at the end of
	// NewProgramWithExt.
	postInit []func(*Program[E])
}

// Option configures a Program. NewProgramWithExt applies options first, before
// it creates LoggingOptions and registers the built-in option groups and
// commands; use PostInit to defer work until after that registration.
type Option[E any] func(program *Program[E])

// WithHandler stores the factory that builds the program's HTTP handler. The
// factory receives the program when it is eventually invoked (presumably by
// the serve command, not visible here). A later WithHandler replaces an earlier
// one.
func WithHandler[E any](factory func(*Program[E]) http.Handler) Option[E] {
	setHandler := func(p *Program[E]) {
		p.setupHandler = factory
	}

	return setHandler
}

// Middleware builds an HTTP middleware (an alice.Constructor) given the
// program. It returns an error when the middleware cannot be built.
type Middleware[E any] interface {
	Middleware(*Program[E]) (alice.Constructor, error)
}

// middlewareFuncNoInit adapts a plain HTTP middleware function, which needs
// nothing from the program, to Middleware[E].
type middlewareFuncNoInit[E any] func(http.Handler) http.Handler

// Middleware ignores the program and returns the wrapped function as an
// alice.Constructor. It never fails.
func (fn middlewareFuncNoInit[E]) Middleware(_ *Program[E]) (alice.Constructor, error) {
	constructor := alice.Constructor(fn)

	return constructor, nil
}

// middlewareFunc adapts a function that builds a plain HTTP middleware from the
// program to Middleware[E].
type middlewareFunc[E any] func(*Program[E]) (func(http.Handler) http.Handler, error)

// Middleware calls the wrapped function with the program and returns both of
// its results unchanged, with the middleware typed as an alice.Constructor.
func (fn middlewareFunc[E]) Middleware(p *Program[E]) (alice.Constructor, error) {
	mw, err := fn(p)

	return mw, err
}

// middlewareAliceFunc adapts a function that builds an alice.Constructor from
// the program to Middleware[E].
type middlewareAliceFunc[E any] func(*Program[E]) (alice.Constructor, error)

// Middleware calls the wrapped function with the program and returns both of
// its results unchanged.
func (fn middlewareAliceFunc[E]) Middleware(p *Program[E]) (alice.Constructor, error) {
	constructor, err := fn(p)

	return constructor, err
}

// getMiddleware converts one of the supported middleware shapes into a
// Middleware[E]:
//   - func(http.Handler) http.Handler or alice.Constructor, which need nothing
//     from the program;
//   - func(*Program[E]) (func(http.Handler) http.Handler, error) and
//     func(*Program[E]) (alice.Constructor, error), which are built from the
//     program;
//   - any value that already implements Middleware[E].
//
// Any other type is a programming error. It panics with a message naming the
// rejected dynamic type.
func getMiddleware[E any](middleware any) Middleware[E] {
	switch f := middleware.(type) {
	case func(http.Handler) http.Handler:
		return middlewareFuncNoInit[E](f)
	// alice.Constructor is a named type, so its values never match the
	// unnamed func case above; accept it explicitly.
	case alice.Constructor:
		return middlewareFuncNoInit[E](f)
	case func(*Program[E]) (func(http.Handler) http.Handler, error):
		return middlewareFunc[E](f)
	case func(*Program[E]) (alice.Constructor, error):
		return middlewareAliceFunc[E](f)
	case Middleware[E]:
		return f
	default:
		// %T prints the dynamic type of the rejected value.
		panic(fmt.Errorf("Invalid middleware type: %T", middleware))
	}
}

// WithMiddleware appends an HTTP middleware to the program.
//
// The middleware value is converted by getMiddleware, which accepts several
// function shapes as well as any Middleware[E]. The conversion happens when the
// option is applied, so an unsupported type panics at that point.
func WithMiddleware[E any](middleware any) Option[E] {
	return func(p *Program[E]) {
		m := getMiddleware[E](middleware)
		p.middlewares = append(p.middlewares, m)
	}
}

// WithXbusActors queues a function producing Xbus actor factories for the
// program.
//
// The first time such an option is applied, it also:
//   - adds the "Xbus" option group to the main parser (namespace "xbus",
//     environment namespace "XBUS");
//   - calls SetupXbusCmd.
//
// Later applications only queue the factory. A group registration error
// panics.
func WithXbusActors[E any](factories func(*Program[E]) []XbusActorFactory) Option[E] {
	return func(p *Program[E]) {
		p.setupXbusActors = append(p.setupXbusActors, factories)

		if p.hasXbus {
			return
		}

		group, err := p.Parser.AddGroup("Xbus", "Xbus options", &p.XbusOptions)
		if err != nil {
			panic(err)
		}
		group.Namespace = "xbus"
		group.EnvNamespace = "XBUS"

		SetupXbusCmd(p)
		p.hasXbus = true
	}
}

// CombineOptions merges several options into one that applies them to the
// program in the given order.
func CombineOptions[E any](options ...Option[E]) Option[E] {
	return func(p *Program[E]) {
		for i := range options {
			options[i](p)
		}
	}
}

// WithOptionsGroup adds an option group, without namespace, to the main parser.
//
// Registration is deferred with PostInit: getgroup is called with the program
// once the built-in groups and commands exist. Its result is the group data
// given to flags.Parser.AddGroup. A registration error panics.
func WithOptionsGroup[E any](
	name string, description string, getgroup func(*Program[E]) any,
) Option[E] {
	register := func(p *Program[E]) {
		data := getgroup(p)
		if _, err := p.Parser.AddGroup(name, description, data); err != nil {
			panic(err)
		}
	}

	return PostInit(register)
}

// WithDatabase enables database support.
//
// migrateSource provides the migrations of the main schema, which is named
// after the program. With database support enabled, NewProgramWithExt registers
// the "Database" option group and the migrate command.
func WithDatabase[E any](migrateSource source.Driver) Option[E] {
	return func(p *Program[E]) {
		p.dbSchema.Main.Source = migrateSource
		p.hasDB = true
	}
}

// WithSchema adds an extra database schema named name, migrated from
// migrateSource.
//
// It does not enable database support by itself; only WithDatabase does that.
func WithSchema[E any](name string, migrateSource source.Driver) Option[E] {
	extra := database.Schema{Name: name, Source: migrateSource}

	return func(p *Program[E]) {
		p.dbSchema.Extra = append(p.dbSchema.Extra, extra)
	}
}

// WithSubcommand queues an extra command for the main parser.
//
// name, shortDescription and longDescription are kept as the command's name
// and help texts. init is called with the program after the built-in commands
// are set up. Its result is the command data passed to flags.Parser.AddCommand
// (presumably a pointer to a struct, typically implementing flags.Commander —
// see go-flags). NewProgramWithExt panics if the registration fails.
func WithSubcommand[E any](
	name string, shortDescription string, longDescription string,
	init func(*Program[E]) any,
) Option[E] {
	// Keep both descriptions: previously only name and init were stored, so
	// the help texts were silently dropped.
	cmd := subcommand[E]{
		name:             name,
		shortDescription: shortDescription,
		longDescription:  longDescription,
		init:             init,
	}

	return func(program *Program[E]) {
		program.setupSubcommands = append(program.setupSubcommands, cmd)
	}
}

// Version carries the program's version metadata. It is passed to NewProgram /
// NewProgramWithExt and stored in Program.Version.
//
// All fields are free-form strings filled by the caller. Presumably:
//   - Hash, Branch and HgTopic identify the source revision (HgTopic being a
//     Mercurial topic);
//   - Build identifies the build.
//
// Confirm against the code that fills them.
type Version struct {
	Version    string
	APIVersion string
	Hash       string
	Branch     string
	HgTopic    string
	Build      string
}

// NewProgram creates a Program that carries no extension value (E is
// struct{}). See NewProgramWithExt for how the program is assembled.
func NewProgram(name string, version Version, options ...Option[struct{}]) *Program[struct{}] {
	var noExt struct{}

	return NewProgramWithExt(name, version, noExt, options...)
}

// NewProgramWithExt creates a Program named name, carrying version and the
// extension value ext, and configured by options.
//
// Construction happens in this order:
//  1. options are applied (options built with PostInit are only queued);
//  2. LoggingOptions is created;
//  3. the bootstrap parser (ignoring unknown options) gets the
//     "Configuration" and "Logging" groups. It uses "-" / "_" namespace
//     delimiters and the upper-cased program name as environment namespace;
//  4. the main parser gets the "Configuration", "Info" and "Logging" groups;
//  5. the commands from SetupServeCmd, SetupGenerateConfigCmd and
//     SetupVersionCmd are set up;
//  6. when WithDatabase was used, the "Database" group and the migrate
//     command are added;
//  7. subcommands queued by WithSubcommand are registered;
//  8. queued PostInit functions run, in order.
//
// Finally, the main parser's CommandHandler is set to rebuild the logger from
// LoggingOptions before running the selected command. Any go-flags
// registration error panics; such errors are treated as programming errors.
func NewProgramWithExt[E any](name string, version Version, ext E, options ...Option[E]) *Program[E] {
	bootstrapParser := flags.NewNamedParser(name, flags.IgnoreUnknown)
	parser := flags.NewNamedParser(name, flags.HelpFlag|flags.PassDoubleDash)

	program := Program[E]{
		Name:            name,
		Ext:             ext,
		Version:         version,
		BootstrapParser: bootstrapParser,
		Parser:          parser,
		Logger:          orusapi.DefaultLogger(os.Stdout),
		dbSchema:        database.SchemaSet{Main: database.Schema{Name: name}},
	}

	for _, opt := range options {
		opt(&program)
	}

	// Created after the options ran, so options must not rely on it. The same
	// instance backs the "Logging" group of both parsers.
	program.LoggingOptions = orusapi.NewLoggingOptions(os.Stdout)

	// NOTE(review): only the bootstrap parser's delimiters and env namespace
	// are set here. The main parser keeps its go-flags defaults (presumably
	// "." as namespace delimiter and no env namespace) unless configured
	// elsewhere. Confirm that both parsers are meant to accept the same
	// option/env names.
	bootstrapParser.NamespaceDelimiter = "-"
	bootstrapParser.EnvNamespaceDelimiter = "_"
	bootstrapParser.EnvNamespace = strings.ToUpper(name)

	if _, err := bootstrapParser.AddGroup(
		"Configuration", "Configuration file", &program.ConfigFileOption,
	); err != nil {
		panic(err)
	}

	{
		g, err := bootstrapParser.AddGroup("Logging", "Logging options", program.LoggingOptions)
		if err != nil {
			panic(err)
		}
		g.Namespace = "log"
		g.EnvNamespace = "LOG"
	}

	if _, err := parser.AddGroup("Configuration", "Configuration file", &program.ConfigFileOption); err != nil {
		panic(err)
	}
	if _, err := parser.AddGroup("Info", "Info options", &program.InfoOptions); err != nil {
		panic(err)
	}

	{
		g, err := parser.AddGroup("Logging", "Logging options", program.LoggingOptions)
		if err != nil {
			panic(err)
		}
		g.Namespace = "log"
		g.EnvNamespace = "LOG"
	}

	program.ServeCmd = SetupServeCmd(&program)
	program.GenerateConfigCmd = SetupGenerateConfigCmd(&program)
	_ = SetupVersionCmd(&program)

	if program.hasDB {
		g, err := parser.AddGroup("Database", "Database options", &program.DatabaseOptions)
		if err != nil {
			panic(err)
		}
		g.Namespace = "db"
		g.EnvNamespace = "DB"

		program.MigrateCmd = SetupMigrateCmd(&program)
	}

	for _, subcmd := range program.setupSubcommands {
		// Pass the long description as the long help text (it was previously
		// replaced by a second copy of the short description).
		if _, err := parser.AddCommand(
			subcmd.name, subcmd.shortDescription, subcmd.longDescription,
			subcmd.init(&program),
		); err != nil {
			panic(err)
		}
	}

	for _, pi := range program.postInit {
		pi(&program)
	}

	// The main parse may have changed the logging options, so rebuild the
	// logger right before running the selected command.
	program.Parser.CommandHandler = func(command flags.Commander, args []string) error {
		program.LoggingOptions.BuildLogger()
		program.Logger = *program.LoggingOptions.Logger()

		if command == nil {
			return nil
		}

		return command.Execute(args)
	}

	return &program
}

// PostInit returns an option that queues pi instead of running it immediately.
// NewProgramWithExt runs queued functions in order, after the built-in option
// groups, commands and subcommands have been registered.
func PostInit[E any](pi func(program *Program[E])) Option[E] {
	queue := func(p *Program[E]) {
		p.postInit = append(p.postInit, pi)
	}

	return queue
}

// GetSchema returns the program's database schema set: the main schema plus
// any extra schemas added with WithSchema. The SchemaSet value is copied, but
// slices inside it still share their backing storage with the program.
func (program *Program[E]) GetSchema() database.SchemaSet {
	schemas := program.dbSchema

	return schemas
}

// ParseArgs parses the command line in stages and returns a process exit code.
//
//  1. The bootstrap parser parses args, ignoring unknown options, to pick up
//     the logging options and the configuration file path. The logger is then
//     rebuilt.
//  2. If a configuration file was given, it is loaded as an ini file into the
//     main parser, and the logger is rebuilt again.
//  3. The main parser runs and dispatches to the selected command through its
//     CommandHandler.
//
// It returns 0 on success or when help was requested (the help text is printed
// to stdout), and 1 on any error. Errors are logged.
func (program *Program[E]) ParseArgs(args []string) int {
	if _, err := program.BootstrapParser.ParseArgs(args); err != nil {
		program.Logger.Err(err).Msg("could not parse command line")

		return 1
	}

	// explicitlySet reports whether the bootstrap option with the given long
	// name was provided by the user rather than filled from its default. A
	// missing option counts as "not set" instead of being dereferenced as nil.
	explicitlySet := func(longName string) bool {
		opt := program.BootstrapParser.FindOptionByLongName(longName)

		return opt != nil && opt.IsSet() && !opt.IsSetDefault()
	}

	// NOTE(review): Lock is defined in orusapi. Presumably it keeps an
	// explicitly given level/format from being overridden by the ini file or
	// the main parse — confirm.
	program.LoggingOptions.Lock(explicitlySet("log-level"), explicitlySet("log-format"))

	program.LoggingOptions.BuildLogger()
	program.Logger = *program.LoggingOptions.Logger()

	if configFile := program.ConfigFileOption.ConfigFile; configFile != "" {
		program.Logger.Debug().Str("configfile", configFile).Msg("parsing configuration file")
		iniParser := flags.NewIniParser(program.Parser)
		if err := iniParser.ParseFile(configFile); err != nil {
			program.Logger.Err(err).Str("configfile", configFile).Msg("could not parse configuration file")

			return 1
		}
		// The ini file may change the logging options.
		program.LoggingOptions.BuildLogger()
		program.Logger = *program.LoggingOptions.Logger()
	}

	// NOTE(review): Parser.Parse reads os.Args[1:], not args. Confirm that
	// callers always pass the process arguments here.
	if _, err := program.Parser.Parse(); err != nil {
		code := 1
		var flagsErr *flags.Error
		if errors.As(err, &flagsErr) {
			if flagsErr.Type == flags.ErrHelp {
				code = 0
				// this error actually contains a help message for the user
				// so we print it on the console
				fmt.Println(err)
			} else {
				program.Logger.Error().Msg(err.Error())
			}
		} else {
			program.Logger.Err(err).Msg("")
		}

		return code
	}

	return 0
}

// EnsureDB makes sure program.DB holds a ready database connection.
//
// It does nothing when database support is disabled (see WithDatabase) or a
// connection is already open. Otherwise it:
//   - opens a connection with DatabaseOptions and pings it;
//   - checks that the schema set is up to date.
//
// If that check fails with database.ErrDBNeedUpgrade and automigrate is true,
// the schemas are migrated instead of failing. Errors are returned as is. On
// any error the freshly opened connection is closed and program.DB stays nil.
func (program *Program[E]) EnsureDB(automigrate bool) error {
	if !program.hasDB || program.DB != nil {
		return nil
	}

	program.Logger.Debug().Msg("Connecting to the database...")

	conn, err := program.DatabaseOptions.Open()
	if err != nil {
		return err
	}

	// Release the connection unless it was handed over to program.DB below.
	defer func() {
		if program.DB != nil {
			return
		}
		if closeErr := conn.Close(); closeErr != nil {
			program.Logger.Err(closeErr).Msg("could not close database connection")
		}
	}()

	if err = conn.Ping(); err != nil {
		return err
	}

	if err = program.dbSchema.IsUptodate(conn.DB); err != nil {
		if !automigrate || !errors.Is(err, database.ErrDBNeedUpgrade) {
			return err
		}

		program.Logger.Warn().Msg("Database is not up-to-date, will auto-migrate...")
		if err = program.dbSchema.Migrate(conn.DB, program.Logger); err != nil {
			return err
		}
		program.Logger.Info().Msg("Database successfully migrated")
	}

	program.DB = conn
	program.Logger.Info().Msg("Database is ready")

	return nil
}

// CloseDB closes the database connection, if any, and clears program.DB.
//
// Clearing the field lets a later EnsureDB open a fresh connection instead of
// returning early with a closed handle. A close error is logged, not returned.
// CloseDB is a no-op when database support is disabled or no connection is
// open.
func (program *Program[E]) CloseDB() {
	if !program.hasDB || program.DB == nil {
		return
	}

	db := program.DB
	// Clear the field even if Close reports an error: a handle on which Close
	// was called must not be reused.
	program.DB = nil

	if err := db.Close(); err != nil {
		program.Logger.Err(err).Msg("could not close database connection properly")
	}
}