Skip to content
Snippets Groups Projects
Commit c1b1ea2ca73f authored by Christophe de Vienne's avatar Christophe de Vienne
Browse files

cmd: improve & fix logging & auth & middleware api

parent 9c6409c5007b
No related branches found
No related tags found
No related merge requests found
Pipeline #120652 failed
......@@ -77,8 +77,10 @@
}
// NewTokenOptions creates a TokenOptions.
func NewTokenOptions(cookieBaseName string) *TokenOptions {
options := TokenOptions{}
func NewTokenOptions(cookieBasename string) *TokenOptions {
options := TokenOptions{
cookieBasename: cookieBasename,
}
options.ExpirationOpt = durationOption(&options.Expiration)
options.CachePurgeDelayOpt = durationOption(&options.CachePurgeDelay)
options.SecretOpt = hexBytesOption(32, &options.Secret)
......
package auth
import "net/http"
import (
"net/http"
"github.com/rs/zerolog"
)
type contextKey int
......@@ -11,6 +15,7 @@
func RequestAuthClaims[T any, PT interface {
*T
Claims
zerolog.LogObjectMarshaler
}](r *http.Request) *T {
value := r.Context().Value(contextAuthClaims)
if value == nil {
......@@ -22,3 +27,15 @@
return nil
}
func RequestAuthClaimsAsLogObjectMarshaler(r *http.Request) zerolog.LogObjectMarshaler {
value := r.Context().Value(contextAuthClaims)
if value == nil {
return nil
}
if claims, ok := value.(zerolog.LogObjectMarshaler); ok {
return claims
}
return nil
}
......@@ -9,7 +9,7 @@
// CookieMiddleware reads the authentication cookie and add a 'auth-claims' key
// in the request context. The cookie must be unique, if more than one cookie
// is found, an error is logged and a 401 error is returned
// is found, an error is logged and a 401 error is returned.
func CookieMiddleware[T any, PT interface {
*T
Claims
......@@ -25,8 +25,9 @@
if len(cookies) > 1 {
log.Warn().Int("count", len(cookies)).Msg("multiple auth cookies found")
rw.WriteHeader(http.StatusUnauthorized)
return
}
if len(cookies) == 0 {
h.ServeHTTP(rw, r)
......@@ -28,8 +29,9 @@
return
}
if len(cookies) == 0 {
h.ServeHTTP(rw, r)
return
}
......@@ -38,7 +40,7 @@
log.Err(err).Str("cookie-name", cookieName).Msg("could not parse auth cookie token")
}
r = r.WithContext(
context.WithValue(ctx, contextAuthClaims, &claims),
context.WithValue(ctx, contextAuthClaims, claims),
)
h.ServeHTTP(rw, r)
})
......
package auth
import (
"net/http"
"github.com/rs/zerolog"
)
func LogUserHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log := zerolog.Ctx(r.Context())
if o := RequestAuthClaimsAsLogObjectMarshaler(r); o != nil {
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Object("identity", o)
})
} else {
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str("identity", "anonymous")
})
}
next.ServeHTTP(w, r)
})
}
......@@ -10,6 +10,7 @@
"github.com/golang-migrate/migrate/v4/source"
"github.com/jessevdk/go-flags"
"github.com/jmoiron/sqlx"
"github.com/justinas/alice"
"github.com/rs/zerolog"
"orus.io/orus-io/go-orusapi"
......@@ -66,7 +67,8 @@
hasDB bool
dbSchema database.SchemaSet
middlewares []Middleware
authMiddlewares []Middleware[E]
middlewares []Middleware[E]
hasXbus bool
setupXbusActors []func(*Program[E]) []XbusActorFactory
......@@ -85,7 +87,19 @@
}
}
type Middleware interface {
Middleware(http.Handler) (http.Handler, error)
type Middleware[E any] interface {
Middleware(*Program[E]) (alice.Constructor, error)
}
type middlewareFuncNoInit[E any] func(http.Handler) http.Handler
func (f middlewareFuncNoInit[E]) Middleware(*Program[E]) (alice.Constructor, error) {
return alice.Constructor(f), nil
}
type middlewareFunc[E any] func(*Program[E]) (func(http.Handler) http.Handler, error)
func (f middlewareFunc[E]) Middleware(program *Program[E]) (alice.Constructor, error) {
return f(program)
}
......@@ -90,4 +104,4 @@
}
type middlewareFunc func(http.Handler) (http.Handler, error)
type middlewareAliceFunc[E any] func(*Program[E]) (alice.Constructor, error)
......@@ -93,5 +107,5 @@
func (f middlewareFunc) Middleware(next http.Handler) (http.Handler, error) {
return f(next)
func (f middlewareAliceFunc[E]) Middleware(program *Program[E]) (alice.Constructor, error) {
return f(program)
}
......@@ -96,9 +110,18 @@
}
type middlewareFuncNoErr func(http.Handler) http.Handler
func (f middlewareFuncNoErr) Middleware(next http.Handler) (http.Handler, error) {
return f(next), nil
func getMiddleware[E any](middleware any) Middleware[E] {
switch f := middleware.(type) {
case func(http.Handler) http.Handler:
return middlewareFuncNoInit[E](f)
case func(*Program[E]) (func(http.Handler) http.Handler, error):
return middlewareFunc[E](f)
case func(*Program[E]) (alice.Constructor, error):
return middlewareAliceFunc[E](f)
case Middleware[E]:
return f
default:
panic(fmt.Errorf("Invalid middleware type: %t", middleware))
}
}
func WithMiddleware[E any](middleware any) Option[E] {
......@@ -102,14 +125,4 @@
}
func WithMiddleware[E any](middleware any) Option[E] {
var m Middleware
switch f := middleware.(type) {
case func(http.Handler) http.Handler:
m = middlewareFuncNoErr(f)
case func(http.Handler) (http.Handler, error):
m = middlewareFunc(f)
default:
panic(fmt.Errorf("Invalid middleware type: %t", middleware))
}
return func(program *Program[E]) {
......@@ -115,5 +128,5 @@
return func(program *Program[E]) {
program.middlewares = append(program.middlewares, m)
program.middlewares = append(program.middlewares, getMiddleware[E](middleware))
}
}
......
package cmd
import (
"github.com/justinas/alice"
"orus.io/orus-io/go-orusapi/auth"
)
func WithAuthMiddleware[E any](middleware any) Option[E] {
return func(program *Program[E]) {
program.authMiddlewares = append(program.authMiddlewares, getMiddleware[E](middleware))
}
}
func WithCookieMiddleware[E any, T any, PT interface {
*T
auth.Claims
}]() Option[E] {
return WithAuthMiddleware[E](
func(program *Program[E]) (alice.Constructor, error) {
return auth.CookieMiddleware[T, PT](program.TokenOptions), nil
},
)
}
......@@ -94,18 +94,15 @@
uiOptions := UIOptions{
fs: uifs,
}
return func(program *Program[E]) {
middleware := func(next http.Handler) (http.Handler, error) {
var uiHandler http.Handler
if uiOptions.External == "" {
uiHandler = orusapi.NewSPAFileServer(
http.FS(uiOptions.fs),
program.Version.Hash,
)
} else {
u, err := url.Parse(uiOptions.External)
if err != nil {
return nil, err
}
uiHandler = httputil.NewSingleHostReverseProxy(u)
middleware := func(program *Program[E]) (func(next http.Handler) http.Handler, error) {
var uiHandler http.Handler
if uiOptions.External == "" {
uiHandler = orusapi.NewSPAFileServer(
http.FS(uiOptions.fs),
program.Version.Hash,
)
} else {
u, err := url.Parse(uiOptions.External)
if err != nil {
return nil, err
}
......@@ -111,5 +108,7 @@
}
if cfg.Prefix != "" {
uiHandler = http.StripPrefix(cfg.Prefix, uiHandler)
}
uiHandler = httputil.NewSingleHostReverseProxy(u)
}
if cfg.Prefix != "" {
uiHandler = http.StripPrefix(cfg.Prefix, uiHandler)
}
......@@ -115,4 +114,5 @@
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
if cfg.Prefix != "" {
if !strings.HasPrefix(r.URL.Path, cfg.Prefix) {
......@@ -137,7 +137,7 @@
}
}
uiHandler.ServeHTTP(rw, r)
}), nil
}
WithMiddleware[E](middleware)(program)
})
}, nil
}
......@@ -143,4 +143,6 @@
return CombineOptions(
WithMiddleware[E](middleware),
PostInit(func(program *Program[E]) {
var serveFound bool
for _, cmd := range program.Parser.Commands() {
......@@ -158,6 +160,6 @@
if !serveFound {
panic("serve command not found")
}
})(program)
}
}),
)
}
......@@ -3,4 +3,6 @@
import (
"net/http"
"github.com/justinas/alice"
"github.com/rs/zerolog"
"orus.io/orus-io/go-orusapi"
......@@ -6,4 +8,5 @@
"orus.io/orus-io/go-orusapi"
"orus.io/orus-io/go-orusapi/auth"
)
//nolint:lll
......@@ -41,5 +44,5 @@
defer cmd.program.CloseDB()
handler := cmd.program.setupHandler(cmd.program)
stack := alice.New(orusapi.LogStack(cmd.program.Logger, zerolog.WarnLevel)...)
......@@ -45,8 +48,6 @@
for i := len(cmd.program.middlewares) - 1; i >= 0; i-- {
//for i := range cmd.program.middlewares {
m := cmd.program.middlewares[i]
h, err := m.Middleware(handler)
for _, middleware := range cmd.program.authMiddlewares {
m, err := middleware.Middleware(cmd.program)
if err != nil {
return err
}
......@@ -50,5 +51,5 @@
if err != nil {
return err
}
handler = h
stack = stack.Append(m)
}
......@@ -54,4 +55,14 @@
}
stack = stack.Append(auth.LogUserHandler)
for _, middleware := range cmd.program.middlewares {
m, err := middleware.Middleware(cmd.program)
if err != nil {
return err
}
stack = stack.Append(m)
}
handler := stack.Then(cmd.program.setupHandler(cmd.program))
cmd.Server.SetLog(cmd.program.Logger)
cmd.Server.Environment = cmd.program.InfoOptions.Environment
......
......@@ -9,7 +9,31 @@
"github.com/rs/zerolog/hlog"
)
func accessHandler(status4XXLogLevel zerolog.Level) func(r *http.Request, status, size int, duration time.Duration) {
return func(r *http.Request, status, size int, duration time.Duration) {
log := hlog.FromRequest(r)
var event *zerolog.Event
//nolint:zerologlint
switch {
case status >= http.StatusInternalServerError:
event = log.Error()
case status >= http.StatusBadRequest:
event = log.WithLevel(status4XXLogLevel)
default:
event = log.Info()
}
event.
Dict("request", zerolog.Dict().
Str("method", r.Method).
Str("url", r.URL.String())).
Int("status", status).
Int("size", size).
Dur("duration", duration).
Msg(http.StatusText(status))
}
}
// LogStack ...
func LogStack(log zerolog.Logger, status4XXLogLevel zerolog.Level) []alice.Constructor {
return []alice.Constructor{
hlog.NewHandler(log),
......@@ -12,9 +36,9 @@
// LogStack ...
func LogStack(log zerolog.Logger, status4XXLogLevel zerolog.Level) []alice.Constructor {
return []alice.Constructor{
hlog.NewHandler(log),
CatchPanics,
hlog.AccessHandler(accessHandler(status4XXLogLevel)),
hlog.RemoteAddrHandler("ip"),
hlog.UserAgentHandler("user_agent"),
hlog.RefererHandler("referer"),
hlog.RequestIDHandler("req_id", "Request-Id"),
......@@ -17,27 +41,7 @@
hlog.RemoteAddrHandler("ip"),
hlog.UserAgentHandler("user_agent"),
hlog.RefererHandler("referer"),
hlog.RequestIDHandler("req_id", "Request-Id"),
hlog.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
log := hlog.FromRequest(r)
var event *zerolog.Event
//nolint:zerologlint
switch {
case status >= http.StatusInternalServerError:
event = log.Error()
case status >= http.StatusBadRequest:
event = log.WithLevel(status4XXLogLevel)
default:
event = log.Info()
}
event.
Dict("request", zerolog.Dict().
Str("method", r.Method).
Str("url", r.URL.String())).
Int("status", status).
Int("size", size).
Dur("duration", duration).
Msg(http.StatusText(status))
}),
CatchPanics,
}
}
......@@ -92,6 +92,8 @@
Environment string `no-flag:"t" description:"A environment name"`
AuthMiddlewares []alice.Constructor `no-flag:"t"`
api API
log zerolog.Logger
handler http.Handler
......@@ -456,6 +458,4 @@
// SetHandler allows for setting a http handler on this server.
func (s *Server) SetHandler(handler http.Handler) {
stack := LogStack(s.log, zerolog.ErrorLevel)
if s.Prometheus {
......@@ -461,4 +461,4 @@
if s.Prometheus {
stack = append(stack, Prometheus("/metrics"))
handler = Prometheus("/metrics")(handler)
}
......@@ -463,6 +463,6 @@
}
s.handler = alice.New(stack...).Then(handler)
s.handler = handler
}
// UnixListener returns the domain socket listener.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment