From a20355bc17931014d9de1b2227c7bc7f4404e793 Mon Sep 17 00:00:00 2001
From: Greg Pomerantz <gmp@wow.st>
Date: Fri, 9 May 2025 16:44:01 -0400
Subject: [PATCH] Improve and generalize estimator interface.

---
 cmd/lb/main.go     |  12 +-
 review/estimate.go | 356 +++++++++++++++++++++++++++++++++++++--------
 2 files changed, 307 insertions(+), 61 deletions(-)

diff --git a/cmd/lb/main.go b/cmd/lb/main.go
index 2807382..7d24271 100644
--- a/cmd/lb/main.go
+++ b/cmd/lb/main.go
@@ -232,13 +232,21 @@ func main() {
 				}
 			}
 		}
-		a, b := review.FitPowerLaw(weights, reps, date, 14.0)
+		//a, b := review.FitPowerLaw(weights, reps, date, 14.0)
+		est, err := review.Fit(weights, reps, date,
+				review.WithAllModels(),
+				review.WithHalfLife(14.0),
+			)
+		if err != nil {
+			fmt.Printf("Error estimating performance: %v\n", err)
+			os.Exit(1)
+		}
 		for i := 1; i<=10; i++ {
 			adj := 0.0
 			if isbw {
 				adj = ps[len(ps)-1].Bodyweight
 			}
-			fmt.Printf("%d: %0.0f\n", i, review.EstimateMaxWeight(a, b, float64(i)) - adj)
+			fmt.Printf("%d: %0.0f\n", i, est.EstimateMaxWeight(float64(i)) - adj)
 		}
 	case "predict":
 		if flag.NArg() != 3 {
diff --git a/review/estimate.go b/review/estimate.go
index f022784..f2d028c 100644
--- a/review/estimate.go
+++ b/review/estimate.go
@@ -1,75 +1,313 @@
 package review
 
 import (
-    "math"
-    "time"
+	"errors"
+	"fmt"
+	"math"
+	"time"
 
-    "gonum.org/v1/gonum/optimize"
+	"gonum.org/v1/gonum/optimize"
 )
 
-// PowerLawFunc models weight as a function of reps: w = a * reps^b
-func PowerLawFunc(a, b, reps float64) float64 {
-    return a * math.Pow(reps, b)
+// Estimator encapsulates a fitted model and exposes estimation methods.
+type Estimator interface {
+	Estimate1RM() float64
+	EstimateReps(targetWeight float64) float64
+	EstimateMaxWeight(nReps float64) float64
+	ModelType() string
+	Params() []float64
 }
 
-// WeightedResiduals computes weighted residuals for curve fitting
-func WeightedResiduals(params []float64, weight, reps []float64, dates []time.Time, now time.Time, halfLifeDays float64) float64 {
-    a, b := params[0], params[1]
-    var sum float64
-    for i := range weight {
-        // Exponential time decay weighting
-        daysAgo := now.Sub(dates[i]).Hours() / 24
-        weightDecay := math.Exp(-math.Ln2 * daysAgo / halfLifeDays) // Half-life decay
-        predicted := PowerLawFunc(a, b, reps[i])
-        residual := weight[i] - predicted
-        sum += weightDecay * residual * residual
-    }
-    return sum
+// Supported model types
+const (
+	ModelPowerLaw   = "powerlaw"
+	ModelLinear     = "linear"
+	ModelExponential = "exponential"
+)
+
+// FitOption is a functional option for configuring the Fit process.
+type FitOption func(*fitConfig)
+
+// fitConfig holds configuration for fitting.
+type fitConfig struct {
+	modelTypes   []string
+	halfLifeDays float64
 }
 
-// FitPowerLaw fits the power law curve with time weighting
-func FitPowerLaw(weight, reps []float64, dates []time.Time, halfLifeDays float64) (a, b float64) {
-    now := time.Now()
-    // Initial guess: a = max(weight), b = -0.1
-    params := []float64{max(weight), -0.1}
-    problem := optimize.Problem{
-        Func: func(x []float64) float64 {
-            return WeightedResiduals(x, weight, reps, dates, now, halfLifeDays)
-        },
-    }
-    result, err := optimize.Minimize(problem, params, nil, nil)
-    if err != nil {
-        panic(err)
-    }
-    return result.X[0], result.X[1]
+// WithModel specifies which model(s) to fit. If multiple, Fit selects the best.
+func WithModel(models ...string) FitOption {
+	return func(cfg *fitConfig) {
+		cfg.modelTypes = models
+	}
 }
 
+// WithAllModels configures Fit to try all built-in model types.
+func WithAllModels() FitOption {
+	return func(cfg *fitConfig) {
+		cfg.modelTypes = []string{ModelPowerLaw, ModelLinear, ModelExponential}
+	}
+}
+
+// WithHalfLife sets the half-life (in days) for time weighting.
+func WithHalfLife(days float64) FitOption {
+	return func(cfg *fitConfig) {
+		cfg.halfLifeDays = days
+	}
+}
+
+// WithModelSelection is an alias for WithModel, for clarity.
+func WithModelSelection(models []string) FitOption {
+	return WithModel(models...)
+}
+
+// Default settings
+const defaultHalfLife = 30.0
+var defaultModelTypes = []string{ModelPowerLaw}
+
+// Fit fits the specified model(s) to the data and returns an Estimator.
+// If multiple models are specified, Fit selects the best based on residual sum of squares.
+func Fit(weight, reps []float64, dates []time.Time, opts ...FitOption) (Estimator, error) {
+	if len(weight) != len(reps) || len(weight) != len(dates) {
+		return nil, errors.New("weight, reps, and dates must have the same length")
+	}
+	if len(weight) < 2 {
+		return nil, errors.New("at least two data points are required")
+	}
+
+	// Apply options
+	cfg := &fitConfig{
+		modelTypes:   defaultModelTypes,
+		halfLifeDays: defaultHalfLife,
+	}
+	for _, opt := range opts {
+		opt(cfg)
+	}
+	if len(cfg.modelTypes) == 0 {
+		cfg.modelTypes = defaultModelTypes
+	}
+
+	// Fit each model and select the best (lowest residual)
+	var best Estimator
+	var bestResidual float64 = math.Inf(1)
+	now := time.Now()
+
+	for _, model := range cfg.modelTypes {
+		var est Estimator
+		var residual float64
+		var err error
+
+		switch model {
+		case ModelPowerLaw:
+			est, residual, err = fitPowerLaw(weight, reps, dates, now, cfg.halfLifeDays)
+		case ModelLinear:
+			est, residual, err = fitLinear(weight, reps, dates, now, cfg.halfLifeDays)
+		case ModelExponential:
+			est, residual, err = fitExponential(weight, reps, dates, now, cfg.halfLifeDays)
+		default:
+			return nil, fmt.Errorf("unknown model type: %s", model)
+		}
+		if err != nil {
+			continue // Skip models that fail to fit
+		}
+		if residual < bestResidual {
+			best = est
+			bestResidual = residual
+		}
+	}
+	if best == nil {
+		return nil, errors.New("no model could be fitted to the data")
+	}
+	return best, nil
+}
+
+// --- Model Implementations ---
+
+// PowerLawEstimator: w = a * reps^b
+type PowerLawEstimator struct {
+	a, b        float64
+	halfLife    float64
+	modelType   string
+	residualSum float64
+}
+
+func (e *PowerLawEstimator) Estimate1RM() float64 {
+	return e.a * math.Pow(1, e.b)
+}
+func (e *PowerLawEstimator) EstimateReps(targetWeight float64) float64 {
+	if e.a == 0 || e.b == 0 {
+		return 0
+	}
+	return math.Pow(targetWeight/e.a, 1/e.b)
+}
+func (e *PowerLawEstimator) EstimateMaxWeight(nReps float64) float64 {
+	return e.a * math.Pow(nReps, e.b)
+}
+func (e *PowerLawEstimator) ModelType() string { return e.modelType }
+func (e *PowerLawEstimator) Params() []float64 { return []float64{e.a, e.b} }
+
+// LinearEstimator: w = a + b*reps
+type LinearEstimator struct {
+	a, b        float64
+	halfLife    float64
+	modelType   string
+	residualSum float64
+}
+
+func (e *LinearEstimator) Estimate1RM() float64 {
+	return e.a + e.b*1
+}
+func (e *LinearEstimator) EstimateReps(targetWeight float64) float64 {
+	if e.b == 0 {
+		return 0
+	}
+	return (targetWeight - e.a) / e.b
+}
+func (e *LinearEstimator) EstimateMaxWeight(nReps float64) float64 {
+	return e.a + e.b*nReps
+}
+func (e *LinearEstimator) ModelType() string { return e.modelType }
+func (e *LinearEstimator) Params() []float64 { return []float64{e.a, e.b} }
+
+// ExponentialEstimator: w = a * exp(b * reps)
+type ExponentialEstimator struct {
+	a, b        float64
+	halfLife    float64
+	modelType   string
+	residualSum float64
+}
+
+func (e *ExponentialEstimator) Estimate1RM() float64 {
+	return e.a * math.Exp(e.b*1)
+}
+func (e *ExponentialEstimator) EstimateReps(targetWeight float64) float64 {
+	if e.a == 0 || e.b == 0 {
+		return 0
+	}
+	return math.Log(targetWeight/e.a) / e.b
+}
+func (e *ExponentialEstimator) EstimateMaxWeight(nReps float64) float64 {
+	return e.a * math.Exp(e.b*nReps)
+}
+func (e *ExponentialEstimator) ModelType() string { return e.modelType }
+func (e *ExponentialEstimator) Params() []float64 { return []float64{e.a, e.b} }
+
+// --- Fitting Functions ---
+
+// fitPowerLaw fits w = a * reps^b
+func fitPowerLaw(weight, reps []float64, dates []time.Time, now time.Time, halfLifeDays float64) (Estimator, float64, error) {
+	params := []float64{max(weight), -0.1}
+	problem := optimize.Problem{
+		Func: func(x []float64) float64 {
+			return weightedResidualsPowerLaw(x, weight, reps, dates, now, halfLifeDays)
+		},
+	}
+	result, err := optimize.Minimize(problem, params, nil, nil)
+	if err != nil {
+		return nil, 0, err
+	}
+	residual := weightedResidualsPowerLaw(result.X, weight, reps, dates, now, halfLifeDays)
+	return &PowerLawEstimator{
+		a:          result.X[0],
+		b:          result.X[1],
+		halfLife:   halfLifeDays,
+		modelType:  ModelPowerLaw,
+		residualSum: residual,
+	}, residual, nil
+}
+
+func weightedResidualsPowerLaw(params, weight, reps []float64, dates []time.Time, now time.Time, halfLifeDays float64) float64 {
+	a, b := params[0], params[1]
+	var sum float64
+	for i := range weight {
+		daysAgo := now.Sub(dates[i]).Hours() / 24
+		weightDecay := math.Exp(-math.Ln2 * daysAgo / halfLifeDays)
+		predicted := a * math.Pow(reps[i], b)
+		residual := weight[i] - predicted
+		sum += weightDecay * residual * residual
+	}
+	return sum
+}
+
+// fitLinear fits w = a + b*reps
+func fitLinear(weight, reps []float64, dates []time.Time, now time.Time, halfLifeDays float64) (Estimator, float64, error) {
+	params := []float64{weight[0], 0.0}
+	problem := optimize.Problem{
+		Func: func(x []float64) float64 {
+			return weightedResidualsLinear(x, weight, reps, dates, now, halfLifeDays)
+		},
+	}
+	result, err := optimize.Minimize(problem, params, nil, nil)
+	if err != nil {
+		return nil, 0, err
+	}
+	residual := weightedResidualsLinear(result.X, weight, reps, dates, now, halfLifeDays)
+	return &LinearEstimator{
+		a:          result.X[0],
+		b:          result.X[1],
+		halfLife:   halfLifeDays,
+		modelType:  ModelLinear,
+		residualSum: residual,
+	}, residual, nil
+}
+
+func weightedResidualsLinear(params, weight, reps []float64, dates []time.Time, now time.Time, halfLifeDays float64) float64 {
+	a, b := params[0], params[1]
+	var sum float64
+	for i := range weight {
+		daysAgo := now.Sub(dates[i]).Hours() / 24
+		weightDecay := math.Exp(-math.Ln2 * daysAgo / halfLifeDays)
+		predicted := a + b*reps[i]
+		residual := weight[i] - predicted
+		sum += weightDecay * residual * residual
+	}
+	return sum
+}
+
+// fitExponential fits w = a * exp(b*reps)
+func fitExponential(weight, reps []float64, dates []time.Time, now time.Time, halfLifeDays float64) (Estimator, float64, error) {
+	params := []float64{max(weight), -0.01}
+	problem := optimize.Problem{
+		Func: func(x []float64) float64 {
+			return weightedResidualsExponential(x, weight, reps, dates, now, halfLifeDays)
+		},
+	}
+	result, err := optimize.Minimize(problem, params, nil, nil)
+	if err != nil {
+		return nil, 0, err
+	}
+	residual := weightedResidualsExponential(result.X, weight, reps, dates, now, halfLifeDays)
+	return &ExponentialEstimator{
+		a:          result.X[0],
+		b:          result.X[1],
+		halfLife:   halfLifeDays,
+		modelType:  ModelExponential,
+		residualSum: residual,
+	}, residual, nil
+}
+
+func weightedResidualsExponential(params, weight, reps []float64, dates []time.Time, now time.Time, halfLifeDays float64) float64 {
+	a, b := params[0], params[1]
+	var sum float64
+	for i := range weight {
+		daysAgo := now.Sub(dates[i]).Hours() / 24
+		weightDecay := math.Exp(-math.Ln2 * daysAgo / halfLifeDays)
+		predicted := a * math.Exp(b*reps[i])
+		residual := weight[i] - predicted
+		sum += weightDecay * residual * residual
+	}
+	return sum
+}
+
+// --- Utility Functions ---
+
 // max returns the maximum value in a slice
 func max(slice []float64) float64 {
-    m := slice[0]
-    for _, v := range slice {
-        if v > m {
-            m = v
-        }
-    }
-    return m
+	m := slice[0]
+	for _, v := range slice {
+		if v > m {
+			m = v
+		}
+	}
+	return m
 }
 
-// Estimate1RM estimates the current 1RM (reps=1) using fitted parameters
-func Estimate1RM(a, b float64) float64 {
-    return PowerLawFunc(a, b, 1)
-}
-
-// EstimateReps returns the predicted number of reps at a given weight
-func EstimateReps(a, b, targetWeight float64) float64 {
-    // Avoid division by zero or negative exponent issues
-    if a == 0 || b == 0 {
-        return 0
-    }
-    return math.Pow(targetWeight/a, 1/b)
-}
-
-// EstimateMaxWeight returns the predicted max weight for a given number of reps
-func EstimateMaxWeight(a, b, nReps float64) float64 {
-    return a * math.Pow(nReps, b)
-}