package review

import (
	"errors"
	"math"
	"sort"
	"time"

	"gonum.org/v1/gonum/optimize"
)

// Estimator is the main interface for incremental, time-aware strength estimation.
//
// The method notes below describe the implementation returned by NewEstimator,
// which is the only implementation in this file.
type Estimator interface {
	// Fit takes parallel slices holding one weight, rep count and date per
	// observation. It returns an error when the slice lengths differ.
	Fit(weight, reps []float64, dates []time.Time) error

	// Point estimates. Each accepts at most one optional time; with none, the
	// date of the last stored data point is used, and passing more than one
	// panics. Params returns the two model parameters as []float64{a, b},
	// where the model is weight = a * reps^b.
	Estimate1RM(times ...time.Time) float64
	EstimateReps(targetWeight float64, times ...time.Time) float64
	EstimateMaxWeight(nReps float64, times ...time.Time) float64
	Params(times ...time.Time) []float64

	// Timeline estimates. Each returns parallel slices of sample times and
	// estimated values, stepping by interval from the optional start (default:
	// the first stored data point) up to the last stored data point. Passing
	// more than one start panics.
	TimelineEstimate1RM(interval time.Duration, starts ...time.Time) ([]time.Time, []float64)
	TimelineEstimateReps(targetWeight float64, interval time.Duration, starts ...time.Time) ([]time.Time, []float64)
	TimelineEstimateMaxWeight(nReps float64, interval time.Duration, starts ...time.Time) ([]time.Time, []float64)
}

// --- Functional Options ---

// estimatorConfig holds the settings that EstimatorOption functions modify.
//
//   - modelType is the model name stored by WithModel.
//     NOTE(review): nothing in the visible code reads it; Fit always calls
//     fitPowerLaw.
//   - halfLife is the half-life, in days, of the time weighting used during
//     curve fitting.
//   - smoothAlpha is the exponential smoothing factor applied to the fitted
//     parameter series. None of the visible option functions validate these
//     values.
type estimatorConfig struct {
	modelType   string
	halfLife    float64
	smoothAlpha float64 // Exponential smoothing factor (0 < alpha <= 1)
}

// EstimatorOption adjusts an estimatorConfig. NewEstimator applies the options
// it is given, in order, on top of its default configuration.
type EstimatorOption func(*estimatorConfig)

// WithModel returns an option that records model as the configured model type.
// Only "powerlaw" is implemented at present; the name is stored as given,
// without validation.
func WithModel(model string) EstimatorOption {
	apply := func(c *estimatorConfig) { c.modelType = model }
	return apply
}

// WithHalfLife returns an option that sets the half-life, in days, of the time
// weighting used during curve fitting. The value is stored as given, without
// validation.
func WithHalfLife(days float64) EstimatorOption {
	apply := func(c *estimatorConfig) { c.halfLife = days }
	return apply
}

// WithSmoothingAlpha returns an option that sets the exponential smoothing
// factor used when smoothing the fitted parameter series. The value is stored
// as given, without validation.
func WithSmoothingAlpha(alpha float64) EstimatorOption {
	apply := func(c *estimatorConfig) { c.smoothAlpha = alpha }
	return apply
}

// --- Estimator Implementation ---

// timePoint is one observation together with the model parameters fitted to
// every observation up to and including it, in date order.
type timePoint struct {
	date   time.Time
	a      float64 // fitted scale parameter; NaN until Fit has processed the point
	b      float64 // fitted exponent; NaN until Fit has processed the point
	weight float64 // observed weight
	reps   float64 // observed repetitions at that weight
}

// estimatorImpl is the Estimator implementation returned by NewEstimator.
type estimatorImpl struct {
	cfg       estimatorConfig
	data      []timePoint // sorted by date
	smoothedA []float64   // smoothed a for each timePoint
	smoothedB []float64   // smoothed b for each timePoint
}

// NewEstimator creates a new Estimator with the given options.
//
// Defaults before options are applied: model "powerlaw", a 30-day half-life
// and a smoothing factor of 0.3.
func NewEstimator(opts ...EstimatorOption) Estimator {
	cfg := estimatorConfig{
		modelType:   "powerlaw",
		halfLife:    30.0,
		smoothAlpha: 0.3,
	}
	for _, opt := range opts {
		opt(&cfg)
	}
	return &estimatorImpl{
		cfg: cfg,
	}
}

// Fit adds new data and updates the parameter time series and smoothing.
//
// weight, reps and dates are parallel slices; an error is returned when their
// lengths differ and the estimator is left untouched. The observations are
// stored alongside their dates, so Fit may be called repeatedly and with
// input in any date order. Values are not validated beyond the length check.
func (e *estimatorImpl) Fit(weight, reps []float64, dates []time.Time) error {
	if len(weight) != len(reps) || len(weight) != len(dates) {
		return errors.New("weight, reps, and dates must have the same length")
	}
	if len(weight) == 0 {
		return nil // nothing new to fit
	}

	// Keep the raw observation with its date so that sorting cannot separate
	// them and earlier calls' data is still available for refitting.
	for i := range weight {
		e.data = append(e.data, timePoint{
			date:   dates[i],
			a:      math.NaN(), // to be filled in
			b:      math.NaN(),
			weight: weight[i],
			reps:   reps[i],
		})
	}
	// Stable sort: same-date observations keep their insertion order, so
	// already-fitted points never move behind newly added ones.
	sort.SliceStable(e.data, func(i, j int) bool {
		return e.data[i].date.Before(e.data[j].date)
	})

	// Flatten the sorted observations once; each fit then uses a prefix.
	n := len(e.data)
	w := make([]float64, n)
	r := make([]float64, n)
	d := make([]time.Time, n)
	firstUnfitted := n
	for i := range e.data {
		p := &e.data[i]
		w[i], r[i], d[i] = p.weight, p.reps, p.date
		if firstUnfitted == n && (math.IsNaN(p.a) || math.IsNaN(p.b)) {
			firstUnfitted = i
		}
	}

	// Points ahead of the first unfitted one have an unchanged prefix, so
	// refitting them would only repeat the same computation.
	for i := firstUnfitted; i < n; i++ {
		e.data[i].a, e.data[i].b = fitPowerLaw(w[:i+1], r[:i+1], d[:i+1], e.cfg.halfLife)
	}

	// Smooth the parameter time series
	e.smoothedA = exponentialSmoothing(extractA(e.data), e.cfg.smoothAlpha)
	e.smoothedB = exponentialSmoothing(extractB(e.data), e.cfg.smoothAlpha)
	return nil
}

// --- Estimate Functions with Variadic Time Parameter ---

// Estimate1RM evaluates the model a * reps^b at one repetition, using the
// smoothed parameters selected for the resolved time (see resolveTime).
func (e *estimatorImpl) Estimate1RM(times ...time.Time) float64 {
	const oneRep = 1
	scale, exponent := e.smoothedParamsAt(e.resolveTime(times...))
	return scale * math.Pow(oneRep, exponent)
}

// EstimateReps inverts the model a * reps^b to find the repetitions at
// targetWeight, using the smoothed parameters selected for the resolved time.
// It returns 0 when either parameter is zero, because the inversion is
// undefined there.
func (e *estimatorImpl) EstimateReps(targetWeight float64, times ...time.Time) float64 {
	at := e.resolveTime(times...)
	scale, exponent := e.smoothedParamsAt(at)
	if scale != 0 && exponent != 0 {
		return math.Pow(targetWeight/scale, 1/exponent)
	}
	return 0
}

// EstimateMaxWeight evaluates the model a * nReps^b using the smoothed
// parameters selected for the resolved time.
func (e *estimatorImpl) EstimateMaxWeight(nReps float64, times ...time.Time) float64 {
	at := e.resolveTime(times...)
	scale, exponent := e.smoothedParamsAt(at)
	repFactor := math.Pow(nReps, exponent)
	return scale * repFactor
}

// Params returns the smoothed model parameters selected for the resolved time
// as a two-element slice {a, b}.
func (e *estimatorImpl) Params(times ...time.Time) []float64 {
	scale, exponent := e.smoothedParamsAt(e.resolveTime(times...))
	pair := []float64{scale, exponent}
	return pair
}

// --- Timeline Functions with Variadic Start Parameter ---

// TimelineEstimate1RM samples the one-repetition estimate (the model
// a * reps^b at reps = 1) along the timeline produced by timelineEstimate.
func (e *estimatorImpl) TimelineEstimate1RM(interval time.Duration, starts ...time.Time) ([]time.Time, []float64) {
	oneRepMax := func(scale, exponent float64) float64 {
		return scale * math.Pow(1, exponent)
	}
	return e.timelineEstimate(oneRepMax, interval, starts...)
}

// TimelineEstimateReps samples the estimated repetitions at targetWeight along
// the timeline produced by timelineEstimate. A sample is 0 whenever either
// model parameter is zero.
func (e *estimatorImpl) TimelineEstimateReps(targetWeight float64, interval time.Duration, starts ...time.Time) ([]time.Time, []float64) {
	repsAt := func(scale, exponent float64) float64 {
		if scale != 0 && exponent != 0 {
			return math.Pow(targetWeight/scale, 1/exponent)
		}
		return 0
	}
	return e.timelineEstimate(repsAt, interval, starts...)
}

// TimelineEstimateMaxWeight samples the model a * nReps^b along the timeline
// produced by timelineEstimate.
func (e *estimatorImpl) TimelineEstimateMaxWeight(nReps float64, interval time.Duration, starts ...time.Time) ([]time.Time, []float64) {
	weightAt := func(scale, exponent float64) float64 {
		return scale * math.Pow(nReps, exponent)
	}
	return e.timelineEstimate(weightAt, interval, starts...)
}

// timelineEstimate evaluates f on the smoothed parameters at evenly spaced
// times and returns the sample times with the corresponding values.
//
// Sampling starts at the resolved start time (see resolveTimelineStart, which
// panics when more than one start is given) and advances by interval up to
// and including the date of the last stored data point. Both results are nil
// when no data has been fitted or when interval is not positive: a zero or
// negative step would never move past the end date, so the loop could not
// terminate.
func (e *estimatorImpl) timelineEstimate(
	f func(a, b float64) float64,
	interval time.Duration,
	starts ...time.Time,
) ([]time.Time, []float64) {
	if len(e.data) == 0 {
		return nil, nil
	}
	// Resolve the start before checking interval so that passing too many
	// start times still panics, as it always has.
	start := e.resolveTimelineStart(starts...)
	if interval <= 0 {
		return nil, nil
	}
	end := e.data[len(e.data)-1].date

	var times []time.Time
	var values []float64
	for t := start; !t.After(end); t = t.Add(interval) {
		a, b := e.smoothedParamsAt(t)
		times = append(times, t)
		values = append(values, f(a, b))
	}
	return times, values
}

// --- Internal Helpers ---

// resolveTime picks the query time from an optional variadic argument: with
// no argument it is the date of the last stored data point (or the zero
// time.Time when there is no data), with one argument it is that argument.
// Panics if more than one time is provided.
func (e *estimatorImpl) resolveTime(times ...time.Time) time.Time {
	switch len(times) {
	case 0:
		if n := len(e.data); n > 0 {
			return e.data[n-1].date
		}
		return time.Time{}
	case 1:
		return times[0]
	default:
		panic("at most one time argument allowed")
	}
}

// resolveTimelineStart picks the timeline start from an optional variadic
// argument: with no argument it is the date of the first stored data point
// (or the zero time.Time when there is no data), with one argument it is that
// argument. Panics if more than one start time is provided.
func (e *estimatorImpl) resolveTimelineStart(starts ...time.Time) time.Time {
	switch len(starts) {
	case 0:
		if len(e.data) > 0 {
			return e.data[0].date
		}
		return time.Time{}
	case 1:
		return starts[0]
	default:
		panic("at most one start time argument allowed")
	}
}

// smoothedParamsAt returns the smoothed parameters (a, b) in effect at time t:
// those of the latest data point dated at or before t. A time earlier than
// every data point yields the first point's parameters, and (0, 0) is
// returned when no data has been fitted.
//
// A data point dated exactly t counts as "at or before", so querying the date
// of the last observation yields the last parameters rather than the ones
// from the observation before it.
func (e *estimatorImpl) smoothedParamsAt(t time.Time) (float64, float64) {
	if len(e.data) == 0 {
		return 0, 0
	}
	// firstAfter is the index of the first point strictly later than t, which
	// equals the number of points dated at or before t (data is date-sorted).
	firstAfter := sort.Search(len(e.data), func(i int) bool {
		return e.data[i].date.After(t)
	})
	if firstAfter == 0 {
		return e.smoothedA[0], e.smoothedB[0]
	}
	return e.smoothedA[firstAfter-1], e.smoothedB[firstAfter-1]
}

func fitPowerLaw(weight, reps []float64, dates []time.Time, halfLifeDays float64) (a, b float64) {
	now := dates[len(dates)-1]
	params := []float64{max(weight), -0.1}
	problem := optimize.Problem{
		Func: func(x []float64) float64 {
			return weightedResidualsPowerLaw(x, weight, reps, dates, now, halfLifeDays)
		},
	}
	result, err := optimize.Minimize(problem, params, nil, nil)
	if err != nil {
		return 0, 0
	}
	return result.X[0], result.X[1]
}

// weightedResidualsPowerLaw returns the time-weighted sum of squared residuals
// of the model weight = a * reps^b, with a = params[0] and b = params[1].
//
// Each observation is weighted by 0.5^(daysAgo/halfLifeDays), where daysAgo is
// the time from dates[i] to now in days. A half-life that is zero, negative or
// NaN disables the time weighting, so every observation weighs 1. Otherwise a
// zero half-life would turn the objective into NaN (0/0 for an observation
// dated now) and a negative one would weight old data more than recent data.
//
// params must have at least two elements, and reps and dates at least as many
// elements as weight.
func weightedResidualsPowerLaw(params, weight, reps []float64, dates []time.Time, now time.Time, halfLifeDays float64) float64 {
	a, b := params[0], params[1]
	decayEnabled := halfLifeDays > 0 // false for zero, negative and NaN
	var sum float64
	for i := range weight {
		weightDecay := 1.0
		if decayEnabled {
			daysAgo := now.Sub(dates[i]).Hours() / 24
			weightDecay = math.Exp(-math.Ln2 * daysAgo / halfLifeDays)
		}
		predicted := a * math.Pow(reps[i], b)
		residual := weight[i] - predicted
		sum += weightDecay * residual * residual
	}
	return sum
}

// max returns the largest element of slice. It panics when slice is empty.
//
// Note that within this package the name shadows Go's built-in max.
func max(slice []float64) float64 {
	largest := slice[0]
	for _, candidate := range slice[1:] {
		if candidate > largest {
			largest = candidate
		}
	}
	return largest
}

// extractA returns the a parameter of every time point, in order.
func extractA(data []timePoint) []float64 {
	scales := make([]float64, 0, len(data))
	for i := range data {
		scales = append(scales, data[i].a)
	}
	return scales
}

// extractB returns the b parameter of every time point, in order.
func extractB(data []timePoint) []float64 {
	exponents := make([]float64, 0, len(data))
	for i := range data {
		exponents = append(exponents, data[i].b)
	}
	return exponents
}

// exponentialSmoothing returns a new slice of the same length as series in
// which the first element is copied unchanged and every later element is
// alpha*series[i] + (1-alpha)*previous smoothed value. It returns nil for an
// empty series. alpha is used as given, without range checking.
func exponentialSmoothing(series []float64, alpha float64) []float64 {
	n := len(series)
	if n == 0 {
		return nil
	}
	out := make([]float64, n)
	prev := series[0]
	out[0] = prev
	for i, x := range series[1:] {
		prev = alpha*x + (1-alpha)*prev
		out[i+1] = prev
	}
	return out
}