Skip to content
Snippets Groups Projects
program_ui.go 3.36 KiB
package cmd

import (
	"io/fs"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strings"

	"github.com/timewasted/go-accept-headers"

	"orus.io/orus-io/go-orusapi"
)

// UIOptions holds the options of the UI option group that WithUI registers
// on the "serve" command.
type UIOptions struct {
	// fs is the filesystem served through orusapi's SPA file server when
	// External is empty. It carries no flag tags; WithUI sets it directly.
	fs fs.FS

	// External, when non-empty, is parsed as a URL and the UI is served by
	// reverse-proxying to it instead of from fs.
	External string `long:"external" ini-name:"external" description:"UI external server"`
}

// NewIgnoreNotFoundResponseWriter wraps rw so that a 404 response written by
// a handler is swallowed instead of being sent. The handler works on a clone
// of rw's current headers; the clone is only copied back to rw when a
// non-404 response is committed.
func NewIgnoreNotFoundResponseWriter(rw http.ResponseWriter) *IgnoreNotFoundResponseWriter {
	w := new(IgnoreNotFoundResponseWriter)
	w.header = rw.Header().Clone()
	w.next = rw

	return w
}

// IgnoreNotFoundResponseWriter is an http.ResponseWriter wrapper that
// swallows 404 responses. Handlers write headers into a private copy, which is
// pushed to the wrapped writer only when a non-404 response is committed.
// After the handler returns, NotFound reports whether the response was
// swallowed, so the caller can answer the request another way.
type IgnoreNotFoundResponseWriter struct {
	// header is the header map exposed to the handler via Header; flushHeader
	// copies it onto next's headers.
	header   http.Header
	// notfound is nil until the status is decided (by WriteHeader or the
	// first Write); it then points to true if the status was 404.
	notfound *bool
	// next is the wrapped writer that receives non-404 responses.
	next     http.ResponseWriter
}

// NotFound reports whether the handler answered with 404 Not Found. It is
// false while no status has been decided yet.
func (rw *IgnoreNotFoundResponseWriter) NotFound() bool {
	if rw.notfound == nil {
		return false
	}

	return *rw.notfound
}

// Header returns the private header map the handler should populate. Its
// contents reach the wrapped writer only when a non-404 response is committed.
func (rw *IgnoreNotFoundResponseWriter) Header() http.Header { return rw.header }

// flushHeader makes the wrapped writer's header map mirror rw.header.
// Entries missing from rw.header are removed with Header.Del, then every
// entry of rw.header is copied over. The value slices are shared, not cloned.
func (rw *IgnoreNotFoundResponseWriter) flushHeader() {
	target := rw.next.Header()
	for key := range target {
		if _, keep := rw.header[key]; keep {
			continue
		}
		target.Del(key)
	}
	for key, values := range rw.header {
		target[key] = values
	}
}

// WriteHeader records the response status.
//
// A 404 is swallowed: nothing is sent to the wrapped writer, so the caller
// can check NotFound and answer the request another way. Any other final
// status commits the buffered headers and is forwarded.
//
// Informational statuses (1xx except 101 Switching Protocols) are forwarded
// without deciding the final status, because in net/http they may precede
// the real one.
//
// Once a final status has been recorded, later calls cannot change the
// decision. They are forwarded only if the response was not swallowed, so
// the wrapped writer handles the superfluous call. This prevents a late
// WriteHeader(404) after a committed response from silently discarding the
// rest of the body and making NotFound report true.
func (rw *IgnoreNotFoundResponseWriter) WriteHeader(statusCode int) {
	if rw.notfound != nil {
		if !*rw.notfound {
			rw.next.WriteHeader(statusCode)
		}

		return
	}
	if statusCode >= 100 && statusCode <= 199 && statusCode != http.StatusSwitchingProtocols {
		// Informational response: send the current headers with it, but keep
		// waiting for the final status.
		rw.flushHeader()
		rw.next.WriteHeader(statusCode)

		return
	}
	notFound := statusCode == http.StatusNotFound
	rw.notfound = &notFound
	if !notFound {
		rw.flushHeader()
		rw.next.WriteHeader(statusCode)
	}
}

// Write sends data to the wrapped writer unless the response has been
// swallowed as a 404.
//
// A Write with no prior WriteHeader decides the response is not a 404 and
// commits the buffered headers. The status itself is left to the wrapped
// writer, which in net/http implies 200.
//
// Swallowed writes report len(data) bytes written, like io.Discard.
// Returning 0 with a nil error would break the io.Writer contract
// (n < len(p) requires a non-nil error) and makes io.Copy-based handlers
// fail with io.ErrShortWrite.
func (rw *IgnoreNotFoundResponseWriter) Write(data []byte) (int, error) {
	if rw.notfound == nil {
		var value bool
		rw.notfound = &value
		rw.flushHeader()
	}
	if *rw.notfound {
		return len(data), nil
	}

	return rw.next.Write(data)
}

// UIConfig configures WithUI.
type UIConfig struct {
	// Name is the name and namespace of the option group added to the
	// "serve" command. WithUI defaults it to "ui".
	Name          string
	// Description is the option group's description. WithUI defaults it to
	// "User Interface".
	Description   string
	// Prefix, when non-empty, is the URL path prefix the UI is served under.
	// Requests whose path does not start with it go to the next handler, and
	// it is stripped (http.StripPrefix) before reaching the UI handler.
	Prefix        string
	// ExcludePrefix lists path prefixes that are always routed to the next
	// handler and never to the UI.
	ExcludePrefix []string
}

// WithUI returns an Option that serves a user interface in front of the
// program's HTTP handlers.
//
// The UI is served from uifs through orusapi's SPA file server, using the
// program version hash. When the group's "external" option is set, the UI is
// instead reverse-proxied to that URL.
//
// cfg may be nil. Empty Name and Description default to "ui" and
// "User Interface". The config is shallow-copied, so the caller's value is
// never modified.
//
// Each request is routed as follows:
//   - when cfg.Prefix is set, paths outside it go to the next handler only;
//   - paths matching any cfg.ExcludePrefix entry go to the next handler only;
//   - if the Accept header negotiates to application/json (or negotiation
//     fails), the next handler is tried first, and the UI answers only if
//     that handler responded 404;
//   - everything else is served by the UI.
//
// The middleware returns an error if the external URL cannot be parsed. The
// UI option group is added to the "serve" command in a PostInit hook, which
// panics if that command does not exist or the group cannot be added.
func WithUI[E any](uifs fs.FS, cfg *UIConfig) Option[E] {
	// Work on a copy so defaults are not written back into the caller's config.
	conf := UIConfig{}
	if cfg != nil {
		conf = *cfg
	}
	cfg = &conf
	if cfg.Name == "" {
		cfg.Name = "ui"
	}
	if cfg.Description == "" {
		cfg.Description = "User Interface"
	}
	uiOptions := UIOptions{
		fs: uifs,
	}

	return func(program *Program[E]) {
		middleware := func(next http.Handler) (http.Handler, error) {
			var uiHandler http.Handler
			if uiOptions.External == "" {
				uiHandler = orusapi.NewSPAFileServer(
					http.FS(uiOptions.fs),
					program.Version.Hash,
				)
			} else {
				u, err := url.Parse(uiOptions.External)
				if err != nil {
					return nil, err
				}
				uiHandler = httputil.NewSingleHostReverseProxy(u)
			}
			if cfg.Prefix != "" {
				uiHandler = http.StripPrefix(cfg.Prefix, uiHandler)
			}

			return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
				// Requests outside the UI prefix never reach the UI. The
				// return is essential: falling through would route the
				// request again and could answer it twice.
				if cfg.Prefix != "" && !strings.HasPrefix(r.URL.Path, cfg.Prefix) {
					next.ServeHTTP(rw, r)

					return
				}
				for _, p := range cfg.ExcludePrefix {
					if strings.HasPrefix(r.URL.Path, p) {
						next.ServeHTTP(rw, r)

						return
					}
				}
				accepted, err := accept.Negotiate(
					r.Header.Get("Accept"),
					"text/html", "application/json")
				if err != nil || accepted == "application/json" {
					// Give the API a chance first. A 404 from it is
					// swallowed by the wrapper so the UI can answer instead.
					rwWrapper := NewIgnoreNotFoundResponseWriter(rw)
					next.ServeHTTP(rwWrapper, r)
					if !rwWrapper.NotFound() {
						return
					}
				}
				uiHandler.ServeHTTP(rw, r)
			}), nil
		}
		WithMiddleware[E](middleware)(program)

		PostInit(func(program *Program[E]) {
			var serveFound bool
			for _, cmd := range program.Parser.Commands() {
				if cmd.Name == "serve" {
					// uiOptions is shared with the middleware closure, so
					// parsed values are visible when the middleware is built.
					g, err := cmd.AddGroup(cfg.Name, cfg.Description, &uiOptions)
					if err != nil {
						panic(err)
					}
					g.Namespace = cfg.Name
					serveFound = true

					break
				}
			}
			if !serveFound {
				panic("serve command not found")
			}
		})(program)
	}
}