package emitter

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"sync"
	"time"

	"github.com/Masterminds/squirrel"
	"github.com/jackc/pgtype"
	"github.com/jmoiron/sqlx"
	"github.com/m4rw3r/uuid"
	"github.com/nats-io/nats.go"
	nrpc "github.com/nats-rpc/nrpc"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/rs/zerolog"
	"orus.io/orus-io/go-orusapi/database"
	"xbus.io/go-xbus/v4"
	"xbus.io/go-xbus/v4/api"
	"xbus.io/go-xbus/v4/envelope"

	"orus.io/orus-io/go-orusapi/xbus/emissionqueue"
)

var (
	// ErrIntOverflow reports that an integer setting cannot be converted
	// safely, e.g. a negative buffer size that would wrap around when cast to
	// the uint64 a query LIMIT requires.
	ErrIntOverflow = errors.New("integer overflow")
)

// EmissionStatus is the value stored in the "status" column of an emission
// queue row.
type EmissionStatus string

// Statuses an emission queue row can be in. The emitter only loads pending
// rows for emission, marks them sent once handed to xbus, and then follows
// the process state xbus reports for the envelope.
const (
	// EmissionPending marks a message waiting to be emitted.
	EmissionPending EmissionStatus = "pending"
	// EmissionSent marks a message that was emitted to xbus.
	EmissionSent EmissionStatus = "sent"
	// EmissionRunning marks a message whose process is running or paused.
	EmissionRunning EmissionStatus = "running"
	// EmissionDone marks a message whose process completed.
	EmissionDone EmissionStatus = "done"
	// EmissionError marks a message whose process ended in error.
	EmissionError EmissionStatus = "error"
)

// NewEmitter instantiates an Emitter bound to the given actor and database.
//
// Settings read:
//   - "buffer-size": maximum number of queued rows loaded per query
//     (default 1000).
//   - "wait-delay": number of seconds to wait between queue polls and between
//     emit retries (default 10).
//
// The returned Emitter starts no goroutine; call Startup to begin emitting.
func NewEmitter(
	ctx context.Context,
	db *sqlx.DB,
	actor *xbus.Actor,
	settings xbus.Settings,
	log zerolog.Logger,
) *Emitter {
	bufferSize := settings.MustIntD("buffer-size", 1000)
	delay := settings.MustIntD("wait-delay", 10)

	e := &Emitter{
		ctx:        ctx,
		db:         db,
		actor:      actor,
		bufferSize: bufferSize,
		delay:      delay,
		log:        log,
	}

	return e
}

// Emitter consumes the emission queue and sends the messages to xbus.
//
// It loads pending rows from the queue table, emits them through actor, and
// updates the rows from the process-state notifications xbus publishes for
// this actor. Build it with NewEmitter, then call Startup and Shutdown.
type Emitter struct {
	// Collaborators.
	actor *xbus.Actor // emits the envelopes and provides the NATS connection
	db    *sqlx.DB    // holds the emission queue table

	// Lifecycle. ctx is the parent context passed to NewEmitter; Startup
	// derives runCtx/runStop from it and Shutdown calls runStop. wg counts
	// the goroutines started by Startup.
	ctx     context.Context
	runCtx  context.Context
	runStop func()
	wg      sync.WaitGroup

	// Settings read by NewEmitter.
	bufferSize int // max rows loaded from the queue per query ("buffer-size")
	delay      int // seconds between polls and emit retries ("wait-delay")

	log zerolog.Logger
}

// Startup starts taking messages from the queue and emit them.
//
// It first subscribes to the process-state notifications addressed to this
// actor. It then reconciles the queue rows already marked "sent" or
// "running" (at most buffer-size of them) with the state xbus currently
// reports for their envelope. Only then are the state observer and the
// emission loop started in the background.
//
// On failure the subscription is removed and the run context is cancelled
// before the error is returned.
func (e *Emitter) Startup() error {
	e.runCtx, e.runStop = context.WithCancel(e.ctx)

	// fail undoes what Startup has acquired so far and returns err unchanged.
	fail := func(s *nats.Subscription, err error) error {
		if s != nil {
			if uerr := s.Unsubscribe(); uerr != nil {
				e.log.Err(uerr).Msg("error unsubscribing")
			}
		}
		// Release the run context; otherwise it lives until e.ctx ends.
		e.runStop()

		return err
	}

	// Start listening to process change of state that concerns us
	stateChan, sub, err := envelopeStatesSubscribeChan(
		e.actor.Client.API.ProcessState,
		e.actor.Client.GetConn(),
		e.actor.ID.String(), "*",
		e.log,
	)
	if err != nil {
		return fail(nil, err)
	}

	sentMessages, err := e.listSent()
	if err != nil {
		return fail(sub, err)
	}
	for _, msg := range sentMessages {
		envelopeID := api.UUID(msg.EnvelopeID)
		processState, err := e.actor.Client.API.ProcessState.GetEnvelopeState(
			e.actor.ID.String(), &api.GetEnvelopeStateRequest{ID: envelopeID.ToBytes()},
		)
		if err != nil {
			e.log.Err(err).
				Str("envelope_id", envelopeID.String()).
				Msg("cannot load a sent envelope process state. Skip to the next one")

			continue
		}
		if err := e.setEmissionStatus(envelopeID, processState); err != nil {
			return fail(sub, err)
		}
	}

	e.wg.Add(1)
	go e.observeProcessStates(stateChan, sub)

	e.wg.Add(1)
	go e.run()

	return nil
}

// Shutdown stops the routine from emitting more messages.
//
// It cancels the run context created by Startup and waits for the background
// goroutines to return. Calling it on an Emitter whose Startup was never
// called is a no-op. It always returns nil.
func (e *Emitter) Shutdown() error {
	if e.runStop == nil {
		// Startup was never called: there is nothing to stop or wait for.
		return nil
	}
	e.runStop()
	e.wg.Wait()

	return nil
}

// load returns the next batch of messages waiting to be emitted: at most
// bufferSize rows whose status is pending, in ascending id order.
//
// It fails with ErrIntOverflow when bufferSize is negative, since the value
// is used as an unsigned LIMIT. The read runs in its own transaction, which
// is never committed and is released by tx.RollbackIfOpened on return.
func (e *Emitter) load() ([]emissionqueue.Message, error) {
	if e.bufferSize < 0 {
		// no upper boundary since uint64 max overflows int
		return nil, fmt.Errorf("%w: Quantity %d", ErrIntOverflow, e.bufferSize)
	}
	limit := uint64(e.bufferSize)

	tx, err := database.Begin(e.ctx, e.db)
	if err != nil {
		return nil, err
	}
	defer tx.RollbackIfOpened(e.log)

	pending := database.SQ.
		Select("id", "msgtype", "content", "chunks").
		From(emissionqueue.TableName).
		Where(squirrel.Eq{"status": string(EmissionPending)}).
		OrderBy("id ASC").
		Limit(limit)

	var batch []emissionqueue.Message
	if err := database.NewSQLHelper(e.ctx, tx, e.log).Select(&batch, pending); err != nil {
		return nil, err
	}

	return batch, nil
}

// listSent returns the messages that were emitted but not yet recorded as
// done or in error: at most bufferSize rows whose status is sent or running,
// in ascending id order, including their envelope_id and related_to columns.
//
// A negative bufferSize is rejected with ErrIntOverflow. The read runs in its
// own transaction, which is never committed and is released by
// tx.RollbackIfOpened on return.
func (e *Emitter) listSent() ([]emissionqueue.Message, error) {
	if e.bufferSize < 0 {
		// no upper boundary since uint64 max overflows int
		return nil, fmt.Errorf("%w: Quantity %d", ErrIntOverflow, e.bufferSize)
	}

	inFlight := []string{string(EmissionSent), string(EmissionRunning)}
	query := database.SQ.
		Select("id", "envelope_id", "related_to", "msgtype", "content", "chunks").
		From(emissionqueue.TableName).
		Where(squirrel.Eq{"status": inFlight}).
		OrderBy("id ASC").
		Limit(uint64(e.bufferSize))

	tx, err := database.Begin(e.ctx, e.db)
	if err != nil {
		return nil, err
	}
	defer tx.RollbackIfOpened(e.log)

	helper := database.NewSQLHelper(e.ctx, tx, e.log)

	var messages []emissionqueue.Message
	if err = helper.Select(&messages, query); err != nil {
		return nil, err
	}

	return messages, nil
}

// setEmissionStatus records in the emission queue the process state xbus
// reports for the envelope envelopeID.
//
// The row whose envelope_id matches is updated with:
//   - the process id;
//   - the EmissionStatus mapped from the state and a date_<status> timestamp.
//     States with no mapping (NOSTATUS, INITIAL) leave the current status
//     untouched;
//   - the newline-joined error texts, if any.
//
// When the state maps to EmissionError and the row has related_to entries,
// the failure is also logged. The done/error counter is incremented only
// after the update is committed.
//
// NOTE(review): when no row has this envelope_id, sqh.Get presumably fails
// (e.g. with sql.ErrNoRows) and that error is returned — confirm.
func (e *Emitter) setEmissionStatus(envelopeID api.UUID, state *api.EmitterEnvelopeState) error {
	var status EmissionStatus

	switch state.GetStatus() {
	case api.Process_NOSTATUS, api.Process_INITIAL:
		// No meaningful progress yet: keep the row's current status.
	case api.Process_RUNNING, api.Process_PAUSED:
		status = EmissionRunning
	case api.Process_DONE:
		status = EmissionDone
	case api.Process_ERROR:
		status = EmissionError
	}

	values := map[string]interface{}{
		"process_id": uuid.UUID(state.GetProcessID()),
	}
	if status != "" {
		// Only write a real status: an empty string is not an EmissionStatus
		// and would take the row out of both load (pending) and listSent
		// (sent/running).
		values["status"] = status
		values["date_"+string(status)] = time.Now().UTC()
	}
	logs := make([]string, 0, len(state.GetErrors()))
	for _, log := range state.GetErrors() {
		logs = append(logs, log.GetText())
	}
	if len(logs) != 0 {
		values["log"] = strings.Join(logs, "\n")
	}

	query := database.SQ.Update(emissionqueue.TableName).Where(
		squirrel.Eq{"envelope_id": uuid.UUID(envelopeID)},
	).SetMap(values).
		Suffix("RETURNING related_to")

	tx, err := database.Begin(e.ctx, e.db)
	if err != nil {
		return err
	}
	defer tx.RollbackIfOpened(e.log)
	sqh := database.NewSQLHelper(e.ctx, tx, e.log)

	var relatedTo pgtype.TextArray

	if err := sqh.Get(&relatedTo, query); err != nil {
		return err
	}

	if len(relatedTo.Elements) != 0 && status == EmissionError {
		errlogs := zerolog.Arr()
		for _, log := range state.GetErrors() {
			errlogs = errlogs.Str(log.GetText())
		}
		relatedToArr := zerolog.Arr()
		for _, element := range relatedTo.Elements {
			relatedToArr = relatedToArr.Str(element.String)
		}
		e.log.Error().
			Array("related_to", relatedToArr).
			Str("process_id", state.GetProcessIDAsUUID().String()).
			Str("envelope_id", envelopeID.String()).
			Array("logs", errlogs).
			Msg("xbus error while processing envelope")
	}

	if err := tx.Commit(); err != nil {
		return err
	}

	// Count outcomes only once they are committed, so a failed update does
	// not inflate the metric.
	if status == EmissionDone || status == EmissionError {
		emissionCounter.WithLabelValues(string(status)).Add(1)
	}

	return nil
}

// observeProcessStates applies every process-state notification received on
// stateChan to the emission queue until the run context is cancelled. It then
// removes the NATS subscription sub. Startup adds it to e.wg before starting
// it as a goroutine.
//
// Notifications with an unparsable envelope id are logged and skipped.
// Database errors are logged, and the loop goes on.
func (e *Emitter) observeProcessStates(
	stateChan <-chan *EnvelopeStateMsg, sub *nats.Subscription,
) {
	defer e.wg.Done()
	defer func() {
		if err := sub.Unsubscribe(); err != nil {
			e.log.Err(err).Msg("error unsubscribing")
		}
	}()
	for {
		select {
		case <-e.runCtx.Done():
			return
		case state := <-stateChan:
			envelopeID, err := api.UUIDFromString(state.EnvelopeID)
			if err != nil {
				e.log.Err(err).
					Str("envelope_id", state.EnvelopeID).
					Msg("received invalid envelope_id")

				// Without a valid id there is no row to update; going on
				// would run the UPDATE with whatever value the failed parse
				// returned.
				continue
			}
			if err := e.setEmissionStatus(
				envelopeID, state.State,
			); err != nil {
				e.log.Err(err).
					Str("envelope_id", state.EnvelopeID).
					Msg("error updating process status")
			}
		}
	}
}

// attemptEmit emits a single queued message to xbus and marks it as sent.
//
// Within one transaction it:
//  1. generates a new envelope id, stores it in msg.EnvelopeID and in the
//     row's envelope_id column;
//  2. builds the envelope, from the stored chunks when msg.Chunks is not
//     NULL, otherwise from msg.Content;
//  3. emits it through the actor;
//  4. stores the emitted envelope id, status "sent" and date_sent.
//
// Any error rolls back both updates. If Emit succeeds but step 4 or the
// commit fails, the row stays as it was, and the message is emitted again by
// the next attempt.
//
// NOTE(review): the envelope_id written in step 1 is not visible to other
// transactions (e.g. setEmissionStatus) until the commit, so a state
// notification processed in between presumably matches no row. Also, in the
// chunked branch the envelope comes from MustNewEnvelope rather than from the
// generated id, so its ID may differ from the pre-saved one. Step 4 stores
// env.ID() in both cases.
func (e *Emitter) attemptEmit(msg *emissionqueue.Message) error {
	tx, err := database.Begin(e.ctx, e.db)
	if err != nil {
		return err
	}
	defer tx.RollbackIfOpened(e.log)
	db := database.NewSQLHelper(e.ctx, tx, e.log)

	// we need to save the envelope id before emitting because the 'running'
	// state is very fast to update.
	id, err := uuid.V4()
	if err != nil {
		return err
	}

	msg.EnvelopeID = id

	// Use the shared builder, like every other query of this package, so the
	// placeholder format matches the database.
	query := database.SQ.Update(emissionqueue.TableName).
		Where(squirrel.Eq{"id": msg.ID}).
		Set("envelope_id", id)

	if _, err := db.Exec(query); err != nil {
		return err
	}

	var env envelope.Envelope
	if msg.Chunks.Status != pgtype.Null {
		m, w := envelope.MustNewMessageChunkWriter(msg.MsgType)
		env = envelope.MustNewEnvelope(m)
		for _, item := range msg.Chunks.Elements {
			if err := w.Write(item.Bytes); err != nil {
				return err
			}
		}
		if err := w.Close(); err != nil {
			return err
		}
	} else {
		env = envelope.NewEnvelopeWithID(
			api.UUID(msg.EnvelopeID),
			envelope.MustNewBytesMessage(msg.MsgType, msg.Content),
		)
	}

	// Emit is not bound to runCtx, so Shutdown does not interrupt an
	// emission in progress (presumably intentional).
	if _, err := e.actor.Emit(context.Background(), env); err != nil {
		return err
	}

	query = database.SQ.Update(emissionqueue.TableName).
		Where(squirrel.Eq{"id": msg.ID}).
		Set("envelope_id", uuid.UUID(env.ID())).
		Set("status", string(EmissionSent)).
		Set("date_sent", time.Now().UTC())

	if _, err := db.Exec(query); err != nil {
		return err
	}

	return tx.Commit()
}

// emit calls attemptEmit until it succeeds, logging each failure and waiting
// e.delay seconds between attempts. It gives up only when the run context is
// cancelled, in which case it returns the context's error.
func (e *Emitter) emit(msg *emissionqueue.Message) error {
	retryEvery := time.Duration(e.delay) * time.Second
	for {
		err := e.attemptEmit(msg)
		if err == nil {
			return nil
		}
		e.log.Err(err).Msg("failed to emit to xbus (will reattempt later)")

		select {
		case <-e.runCtx.Done():
			return e.runCtx.Err()
		case <-time.After(retryEvery):
		}
	}
}

// run is the emission loop started by Startup. It repeatedly loads a batch of
// pending messages and emits them one by one, in queue order.
//
// When the queue is empty or cannot be read, it sleeps e.delay seconds before
// polling again. After a non-empty batch it polls again immediately, so a
// backlog larger than buffer-size is drained without artificial pauses.
// It returns when the run context is cancelled; messages not yet emitted
// stay pending in the queue.
func (e *Emitter) run() {
	defer e.wg.Done()

	pause := time.Duration(e.delay) * time.Second
	wait := false
	for {
		if wait {
			select {
			case <-time.After(pause):
			case <-e.runCtx.Done():
				return
			}
		} else if e.runCtx.Err() != nil {
			// Not sleeping this round, so check for shutdown explicitly.
			return
		}

		buffer, err := e.load()
		if err != nil {
			e.log.Err(err).Msg("failed to load messages from the queue")
			wait = true

			continue
		}

		// Sleep only when there was nothing to do: a non-empty batch may mean
		// more messages are already waiting.
		wait = len(buffer) == 0
		if wait {
			continue
		}

		for i := range buffer {
			if e.runCtx.Err() != nil {
				// Shutdown requested: the remaining rows stay pending.
				return
			}
			if err := e.emit(&buffer[i]); err != nil {
				// an error on emit() is always blocking and is already logged
				return
			}
		}
	}
}

// EnvelopeStateMsg is one process-state notification for an envelope, as
// forwarded by envelopeStatesSubscribeChan.
type EnvelopeStateMsg struct {
	// EmitterID is the emitter id token read from the NATS subject.
	EmitterID string
	// EnvelopeID is the envelope id token read from the NATS subject; it is
	// kept as a string and not validated here.
	EnvelopeID string
	// State is the decoded notification payload.
	State *api.EmitterEnvelopeState
}

// envelopeStatesSubscribeChan subscribes to the envelope-state subject of the
// given emitter and envelope (Startup passes "*" as envelopeID). It forwards
// each decoded notification on the returned unbuffered channel.
//
// The emitter and envelope ids are read from tokens 4 and 5 of the received
// dot-separated subject. Messages whose subject has too few tokens, or whose
// payload cannot be decoded, are logged and dropped.
//
// The caller owns the returned subscription and must Unsubscribe it. The
// channel is never closed.
// NOTE(review): the NATS callback blocks on the unbuffered send, so a
// notification in flight when the consumer stops reading keeps that callback
// blocked.
func envelopeStatesSubscribeChan(
	client *api.ProcessStateClient, nc *nats.Conn, emitterID, envelopeID string,
	log zerolog.Logger,
) (<-chan *EnvelopeStateMsg, *nats.Subscription, error) {
	// Positions of the ids in the subject of a received notification.
	const (
		emitterIDToken  = 4
		envelopeIDToken = 5
	)

	ch := make(chan *EnvelopeStateMsg)
	subject := client.EnvelopeStatesSubject(emitterID, envelopeID)
	sub, err := nc.Subscribe(subject, func(msg *nats.Msg) {
		tokens := strings.Split(msg.Subject, ".")
		if len(tokens) <= envelopeIDToken {
			// Indexing below would panic inside the NATS callback goroutine.
			log.Error().
				Str("subject", msg.Subject).
				Msg("ProcessStateClient.EnvelopeStatesSubscribe: unexpected subject")

			return
		}
		stateMsg := EnvelopeStateMsg{
			EmitterID:  tokens[emitterIDToken],
			EnvelopeID: tokens[envelopeIDToken],
			State:      &api.EmitterEnvelopeState{},
		}
		if err := nrpc.Unmarshal(client.Encoding, msg.Data, stateMsg.State); err != nil {
			log.Err(err).Msg("ProcessStateClient.EnvelopeStatesSubscribe: Error decoding")

			return
		}
		ch <- &stateMsg
	})
	if err != nil {
		return nil, nil, err
	}

	return ch, sub, nil
}

var (
	// emissionCounter counts emission outcomes; its single "status" label
	// carries the EmissionStatus value recorded by setEmissionStatus
	// (done or error).
	emissionCounter = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "xbus_emission_count",
			Help: "Number of xbus emission so far",
		},
		[]string{"status"},
	)
)

// init registers emissionCounter with the default Prometheus registerer;
// MustRegister panics if the registration fails.
func init() {
	prometheus.MustRegister(emissionCounter)
}