Skip to content
Snippets Groups Projects
program_ui.go 2.88 KiB
Newer Older
package cmd

import (
	"io/fs"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strings"

	"github.com/timewasted/go-accept-headers"

	"orus.io/orus-io/go-orusapi"
)

// UIOptions holds the settings WithUI uses to serve the user interface. It is
// registered as the "ui" option group of the "serve" command.
type UIOptions struct {
	// fs is the filesystem the UI files are served from when External is
	// empty.
	fs    fs.FS
	// paths is not referenced anywhere in this file.
	// NOTE(review): confirm whether it is still used elsewhere in the package.
	paths []string

	// External, when non-empty, is parsed as a URL by WithUI and UI requests
	// are reverse-proxied to it instead of being served from fs.
	// NOTE(review): the ini-name "ui-ui-external" repeats the "ui-" prefix;
	// confirm this is intentional before changing it, since existing
	// configuration files may rely on the current key.
	External string `long:"ui-external" ini-name:"ui-ui-external" description:"UI external server"`
}

// NewIgnoreNotFoundResponseWriter wraps rw in an IgnoreNotFoundResponseWriter.
//
// The wrapper starts with a clone of rw's current headers, so header changes
// made through the wrapper stay staged and do not touch rw until a non-404
// response is actually forwarded. rw itself is kept as the forwarding target.
func NewIgnoreNotFoundResponseWriter(rw http.ResponseWriter) *IgnoreNotFoundResponseWriter {
	return &IgnoreNotFoundResponseWriter{
		header: rw.Header().Clone(),
		// Without the target every forwarded (non-404) response would
		// dereference a nil http.ResponseWriter.
		next: rw,
	}
}

// IgnoreNotFoundResponseWriter is an http.ResponseWriter that forwards a
// response to next unless its status is 404 Not Found. For a 404 neither the
// status, the headers nor the body reach next, which lets the caller answer
// the request with another handler afterwards.
type IgnoreNotFoundResponseWriter struct {
	// header holds the headers set by the wrapped handler; flushHeader copies
	// them to next once the response is known not to be a 404.
	header   http.Header
	// notfound is nil until WriteHeader or Write has been called; afterwards
	// it records whether the status was http.StatusNotFound.
	notfound *bool
	// next is the writer non-404 responses are forwarded to.
	next     http.ResponseWriter
}

// NotFound reports whether the wrapped handler answered with a 404 status.
// It is false while no status has been recorded yet.
func (rw *IgnoreNotFoundResponseWriter) NotFound() bool {
	if rw.notfound == nil {
		return false
	}

	return *rw.notfound
}

// Header returns the staged header map. Changes made to it are only copied to
// the underlying writer by flushHeader, which is not called for 404 responses.
func (rw *IgnoreNotFoundResponseWriter) Header() http.Header {
	return rw.header
}

// flushHeader makes the underlying writer's headers match the staged ones:
// keys that are not staged are deleted from it, then every staged entry is
// assigned to it.
func (rw *IgnoreNotFoundResponseWriter) flushHeader() {
	target := rw.next.Header()

	// Drop whatever the underlying writer carries that is no longer staged.
	for name := range target {
		if _, staged := rw.header[name]; staged {
			continue
		}
		target.Del(name)
	}

	// Publish the staged entries, overwriting any existing values.
	for name, values := range rw.header {
		target[name] = values
	}
}

// WriteHeader records whether statusCode is 404 Not Found. A 404 is withheld
// from the underlying writer; any other status first flushes the staged
// headers and is then forwarded.
func (rw *IgnoreNotFoundResponseWriter) WriteHeader(statusCode int) {
	missing := statusCode == http.StatusNotFound
	rw.notfound = &missing
	if missing {
		return
	}

	rw.flushHeader()
	rw.next.WriteHeader(statusCode)
}

// Write forwards data to the underlying writer, or discards it when the
// response has been recorded as 404 Not Found.
//
// If no status has been recorded yet, the response is marked as not being a
// 404 and the staged headers are flushed before the data is forwarded.
//
// Discarded data is reported as fully written (len(data), nil). Returning a
// short count with a nil error would break the io.Writer contract and make
// callers such as io.Copy or httputil.ReverseProxy fail with a short-write
// error instead of completing normally.
func (rw *IgnoreNotFoundResponseWriter) Write(data []byte) (int, error) {
	if rw.notfound == nil {
		// No WriteHeader call preceded this write, so it cannot be a 404.
		var value bool
		rw.notfound = &value
		rw.flushHeader()
	}
	if *rw.notfound {
		return len(data), nil
	}

	return rw.next.Write(data)
}

// WithUI returns an Option that puts a user interface in front of the other
// HTTP handlers.
//
// The option is wrapped in PostInit. When run, it adds a "ui" option group
// (backed by UIOptions) to the "serve" command and registers a middleware
// through WithMiddleware. It panics if there is no "serve" command or if the
// group cannot be added.
//
// The UI is served from uifs, or reverse-proxied to the URL given by the
// ui-external option when that option is set; an invalid URL makes the
// middleware constructor return an error. When prefix is non-empty it is the
// URL path prefix of the UI and is stripped before the UI handler runs.
//
// For every request the middleware:
//   - passes requests whose path is outside prefix to the next handler;
//   - passes the request to the next handler when Accept negotiation between
//     "text/html" and "application/json" fails or selects "application/json";
//   - otherwise lets the UI handler answer, and falls back to the next
//     handler if the UI handler responded with 404 Not Found.
func WithUI(uifs fs.FS, prefix string) Option {
	return PostInit(func(program *Program) {
		uiOptions := UIOptions{
			fs: uifs,
		}
		var serveFound bool
		for _, cmd := range program.Parser.Commands() {
			if cmd.Name != "serve" {
				continue
			}
			if _, err := cmd.AddGroup("ui", "User Interface", &uiOptions); err != nil {
				panic(err)
			}
			serveFound = true

			break
		}
		if !serveFound {
			panic("serve command not found")
		}
		middleware := func(next http.Handler) (http.Handler, error) {
			var uiHandler http.Handler
			if uiOptions.External == "" {
				uiHandler = orusapi.NewSPAFileServer(
					http.FS(uiOptions.fs),
					prefix,
				)
			} else {
				u, err := url.Parse(uiOptions.External)
				if err != nil {
					return nil, err
				}
				uiHandler = httputil.NewSingleHostReverseProxy(u)
			}
			if prefix != "" {
				uiHandler = http.StripPrefix(prefix, uiHandler)
			}

			return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
				if prefix != "" && !strings.HasPrefix(r.URL.Path, prefix) {
					next.ServeHTTP(rw, r)

					// Stop here: falling through would serve the same
					// request a second time on an already used writer.
					return
				}

				accepted, err := accept.Negotiate(
					r.Header.Get("Accept"),
					"text/html", "application/json")
				if err != nil || accepted == "application/json" {
					next.ServeHTTP(rw, r)

					return
				}

				// The wrapper withholds a 404 from rw so the next handler can
				// still answer on the untouched writer.
				rwWrapper := NewIgnoreNotFoundResponseWriter(rw)
				uiHandler.ServeHTTP(rwWrapper, r)
				if rwWrapper.NotFound() {
					next.ServeHTTP(rw, r)
				}
			}), nil
		}
		WithMiddleware(middleware)(program)
	})
}