package cmd

import (
	"io/fs"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strings"

	"github.com/timewasted/go-accept-headers"
	"orus.io/orus-io/go-orusapi"
)

// UIOptions holds the settings of the UI middleware installed by WithUI.
type UIOptions struct {
	// fs is the filesystem the UI assets are served from when External is
	// empty.
	fs fs.FS
	// paths is not read anywhere in this file.
	paths []string

	// External, when non-empty, is parsed as a URL and every UI request is
	// reverse-proxied to it instead of being served from fs.
	// NOTE(review): the ini-name carries a doubled "ui-" prefix — confirm it
	// is intentional before changing it, since it is user-facing.
	External string `long:"ui-external" ini-name:"ui-ui-external" description:"UI external server"`
}

// Compile-time check that the wrapper can be used wherever an
// http.ResponseWriter is expected.
var _ http.ResponseWriter = (*IgnoreNotFoundResponseWriter)(nil)

// NewIgnoreNotFoundResponseWriter wraps rw in an IgnoreNotFoundResponseWriter.
// The wrapper starts with a copy of the headers currently set on rw, so header
// changes made through the wrapper do not reach rw until they are flushed.
func NewIgnoreNotFoundResponseWriter(rw http.ResponseWriter) *IgnoreNotFoundResponseWriter {
	return &IgnoreNotFoundResponseWriter{
		header: rw.Header().Clone(),
		next:   rw,
	}
}

// IgnoreNotFoundResponseWriter is an http.ResponseWriter that forwards every
// response to the wrapped writer except those with status 404: a 404 is only
// recorded (see NotFound) and its headers and body are dropped, leaving the
// wrapped writer untouched so another handler can still answer the request.
type IgnoreNotFoundResponseWriter struct {
	// header buffers the response headers until the status is known.
	header http.Header
	// notfound is nil until the status has been decided (by WriteHeader or
	// by the first Write); afterwards it tells whether the status was 404.
	notfound *bool
	// next is the wrapped writer.
	next http.ResponseWriter
}

// NotFound reports whether a 404 status was written to (and swallowed by) rw.
// It returns false if nothing has been written yet.
func (rw *IgnoreNotFoundResponseWriter) NotFound() bool {
	return rw.notfound != nil && *rw.notfound
}

// Header returns the buffered header map, not the wrapped writer's one.
func (rw *IgnoreNotFoundResponseWriter) Header() http.Header {
	return rw.header
}

// flushHeader makes the wrapped writer's header map mirror the buffered one:
// keys missing from the buffer are deleted, all buffered entries are copied.
func (rw *IgnoreNotFoundResponseWriter) flushHeader() {
	nh := rw.next.Header()
	for k := range nh {
		if _, ok := rw.header[k]; !ok {
			nh.Del(k)
		}
	}
	for k, v := range rw.header {
		nh[k] = v
	}
}

// WriteHeader records whether statusCode is 404. Any other status is
// forwarded to the wrapped writer after flushing the buffered headers; a 404
// is not forwarded at all.
func (rw *IgnoreNotFoundResponseWriter) WriteHeader(statusCode int) {
	notFound := statusCode == http.StatusNotFound
	rw.notfound = &notFound
	if !notFound {
		rw.flushHeader()
		rw.next.WriteHeader(statusCode)
	}
}

// Write forwards data to the wrapped writer unless a 404 status was recorded,
// in which case data is discarded. A Write without a prior WriteHeader is
// treated as a non-404 response and flushes the buffered headers first.
func (rw *IgnoreNotFoundResponseWriter) Write(data []byte) (int, error) {
	if rw.notfound == nil {
		var value bool
		rw.notfound = &value
		rw.flushHeader()
	}
	if *rw.notfound {
		// Report the whole slice as consumed: io.Writer requires a non-nil
		// error whenever n < len(p), and returning (0, nil) makes callers
		// such as io.Copy or bufio.Writer fail with io.ErrShortWrite.
		return len(data), nil
	}
	return rw.next.Write(data)
}

// WithUI returns an Option that installs a middleware serving a user
// interface, and registers a "ui" option group on the "serve" command.
//
// The UI is served from uifs through orusapi.NewSPAFileServer, or
// reverse-proxied to the URL given by the "ui-external" option when it is set.
// If prefix is non-empty, only requests whose path starts with prefix are
// considered for the UI (the prefix is stripped before reaching the UI
// handler); all other requests go straight to the next handler.
//
// For a request in scope, the Accept header is negotiated between "text/html"
// and "application/json". When JSON wins, or negotiation fails, the next
// handler is tried first and the UI is only served if it answered 404.
//
// The post-init hook panics if the parser has no "serve" command or if the
// option group cannot be added.
func WithUI(uifs fs.FS, prefix string) Option {
	uiOptions := UIOptions{
		fs: uifs,
	}

	return func(program *Program) {
		middleware := func(next http.Handler) (http.Handler, error) {
			var uiHandler http.Handler
			if uiOptions.External == "" {
				uiHandler = orusapi.NewSPAFileServer(
					http.FS(uiOptions.fs),
					program.Version.Hash,
				)
			} else {
				u, err := url.Parse(uiOptions.External)
				if err != nil {
					return nil, err
				}
				uiHandler = httputil.NewSingleHostReverseProxy(u)
			}
			if prefix != "" {
				uiHandler = http.StripPrefix(prefix, uiHandler)
			}

			return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
				if prefix != "" && !strings.HasPrefix(r.URL.Path, prefix) {
					next.ServeHTTP(rw, r)
					// The request is fully handled; without this return it
					// would be served a second time below.
					return
				}

				accepted, err := accept.Negotiate(
					r.Header.Get("Accept"), "text/html", "application/json")
				if err != nil || accepted == "application/json" {
					// Give the next handler a chance first; its 404 answers
					// are swallowed so the UI can still respond on rw.
					rwWrapper := NewIgnoreNotFoundResponseWriter(rw)
					next.ServeHTTP(rwWrapper, r)
					if !rwWrapper.NotFound() {
						return
					}
				}
				uiHandler.ServeHTTP(rw, r)
			}), nil
		}
		WithMiddleware(middleware)(program)

		PostInit(func(program *Program) {
			var serveFound bool
			for _, cmd := range program.Parser.Commands() {
				if cmd.Name == "serve" {
					_, err := cmd.AddGroup("ui", "User Interface", &uiOptions)
					if err != nil {
						panic(err)
					}
					serveFound = true

					break
				}
			}
			if !serveFound {
				panic("serve command not found")
			}
		})
	}
}