logbook/predict.go
2025-05-09 14:37:57 -04:00

75 lines
2.1 KiB
Go

package logbook
import (
"errors"
"sort"
"time"
"github.com/aclements/go-moremath/fit"
)
// Predictable is implemented by types that can estimate the value of a
// variable at an arbitrary point in time.
type Predictable interface {
// Predict returns the estimated value of variable at date, or an error
// when no estimate can be produced. span is a smoothing parameter; the
// MeasurementRepository implementation in this file passes it to LOESS.
Predict(variable string, date time.Time, span float64) (float64, error)
}
// Compile-time check that *MeasurementRepository satisfies Predictable.
var _ Predictable = (*MeasurementRepository)(nil)
// Predict estimates the value of variable at date by fitting a LOESS curve
// (locally weighted polynomial regression of degree 2) through every stored
// measurement of that variable and evaluating the curve at date.
//
// span is the LOESS smoothing parameter: the fraction of the data points used
// for each local fit (typical values: 0.3-0.8). It must be greater than zero.
//
// An error is returned when the measurements cannot be loaded, when there are
// no measurements for variable, when span is not positive, or when the fit
// does not yield a finite value.
func (r *MeasurementRepository) Predict(variable string, date time.Time, span float64) (float64, error) {
	measurements, err := r.GetByVariable(variable)
	if err != nil {
		return 0, err
	}
	if len(measurements) == 0 {
		return 0, errors.New("no measurements found for variable")
	}

	// fit.LOESS panics on a non-positive span. Reject it here so that bad
	// caller input surfaces as an error. The negated form also rejects NaN,
	// which fails every comparison.
	if !(span > 0) {
		return 0, errors.New("span must be positive")
	}

	// Order chronologically so LOESS receives ascending x values.
	sort.Slice(measurements, func(i, j int) bool {
		return measurements[i].Date.Before(measurements[j].Date)
	})

	// x is the offset from the requested date in days, y the measured value.
	// LOESS is unaffected by shifting or rescaling x, but raw Unix seconds
	// (~1e9) make the degree-2 regression numerically ill-conditioned, so the
	// axis is centered on the requested date and expressed in days.
	const secondsPerDay = 24 * 60 * 60
	origin := date.Unix()
	xs := make([]float64, len(measurements))
	ys := make([]float64, len(measurements))
	for i, m := range measurements {
		xs[i] = float64(m.Date.Unix()-origin) / secondsPerDay
		ys[i] = m.Value
	}

	f := fit.LOESS(xs, ys, 2, span)
	if f == nil {
		return 0, errors.New("LOESS fitting failed")
	}

	// The requested date is the origin of the x axis.
	predY := f(0)

	// A degenerate local fit can produce NaN or ±Inf. Subtracting such a value
	// from itself gives NaN, which never equals zero, so this detects both.
	if predY-predY != 0 {
		return 0, errors.New("LOESS produced a non-finite prediction")
	}
	return predY, nil
}
// GetByVariable loads every measurement via FindAll and returns those whose
// Variable field equals variable, in the order FindAll produced them.
// The result is nil when nothing matches. An error from FindAll is returned
// unchanged.
func (r *MeasurementRepository) GetByVariable(variable string) ([]*Measurement, error) {
	records, err := r.FindAll()
	if err != nil {
		return nil, err
	}

	var matches []*Measurement
	for _, rec := range records {
		if rec.Variable != variable {
			continue
		}
		matches = append(matches, rec)
	}
	return matches, nil
}