// ml-platform/inference/main.go — HTTP inference service for a linear regression model.
package main
import (
	"encoding/json"
	"fmt"
	"log"
	"math"
	"net/http"
	"os"
	"path/filepath"
	"sync"
	"sync/atomic"
	"time"
)
// ModelMetadata describes the trained model, loaded from the metadata JSON
// file produced by the training stage.
type ModelMetadata struct {
	Name     string   `json:"name"`
	Version  string   `json:"version"`
	Type     string   `json:"model_type"`
	Features []string `json:"features"` // feature names, in the order the model expects them
	// Metrics are the evaluation scores reported by the training pipeline.
	Metrics struct {
		MAE  float64 `json:"mae"`
		RMSE float64 `json:"rmse"`
		R2   float64 `json:"r2_score"`
	} `json:"metrics"`
	// Timestamp is the training timestamp as an opaque string (format is
	// whatever the training stage wrote — not parsed here).
	Timestamp string `json:"timestamp"`
}
// ScalerParams holds the StandardScaler parameters (per-feature mean and
// standard deviation) used to normalize inputs before prediction.
type ScalerParams struct {
	Mean []float64 `json:"mean"`
	Std  []float64 `json:"std"`
}
// LinearModel holds the linear regression parameters: y = X·Weights + Intercept.
type LinearModel struct {
	Weights   []float64 `json:"weights"`
	Intercept float64   `json:"intercept"`
}
// ModelServer 推理服务器
type ModelServer struct {
model *LinearModel
scaler *ScalerParams
metadata *ModelMetadata
loaded atomic.Bool
mu sync.RWMutex
// Prometheus 风格计数
requestsTotal atomic.Int64
requestsErrors atomic.Int64
latencySum atomic.Int64 // 微秒
latencyCount atomic.Int64
predictionSum atomic.Int64 // 乘以10000后的整数
predictionCount atomic.Int64
startTime time.Time
}
// PredictRequest is the JSON body of POST /predict: one raw (unscaled)
// feature vector, which must match the model's feature count.
type PredictRequest struct {
	Features []float64 `json:"features"`
}
// PredictResponse is the JSON body returned by POST /predict.
type PredictResponse struct {
	Prediction   float64 `json:"prediction"`    // raw model output, rounded to 4 decimals
	PriceUSD     string  `json:"price_usd"`     // prediction * 100000, formatted as dollars
	ModelVersion string  `json:"model_version"`
	ModelType    string  `json:"model_type"`
	LatencyMs    float64 `json:"latency_ms"`
}
// NewModelServer constructs an empty inference server. Call LoadModel
// before serving traffic; until then only /health responds successfully.
func NewModelServer() *ModelServer {
	srv := &ModelServer{}
	srv.startTime = time.Now()
	return srv
}
// LoadModel loads the model weights, scaler parameters, and metadata from
// JSON files under modelDir, then marks the server ready.
//
// Why JSON instead of ONNX? For linear regression the model is just
// y = X·W + b, so storing W and b as JSON is the simplest approach: Go
// parses it natively, with no CGO or external C libraries. Production
// systems serving complex models would use ONNX Runtime, but understanding
// the essence comes first.
//
// The pointer writes are taken under s.mu so that LoadModel can be called
// again (hot reload) without racing Predict, which reads under RLock.
func (s *ModelServer) LoadModel(modelDir string) error {
	// Model parameters (required).
	data, err := os.ReadFile(filepath.Join(modelDir, "model_params.json"))
	if err != nil {
		return fmt.Errorf("读取模型文件失败: %w", err)
	}
	var model LinearModel
	if err := json.Unmarshal(data, &model); err != nil {
		return fmt.Errorf("解析模型参数失败: %w", err)
	}

	// Scaler parameters (required).
	data, err = os.ReadFile(filepath.Join(modelDir, "scaler_params.json"))
	if err != nil {
		return fmt.Errorf("读取 scaler 文件失败: %w", err)
	}
	var scaler ScalerParams
	if err := json.Unmarshal(data, &scaler); err != nil {
		return fmt.Errorf("解析 scaler 参数失败: %w", err)
	}

	// Validate up front so Predict can never index past the scaler arrays.
	if len(scaler.Mean) != len(model.Weights) || len(scaler.Std) != len(model.Weights) {
		return fmt.Errorf("scaler 参数与模型权重数量不匹配: weights=%d, mean=%d, std=%d",
			len(model.Weights), len(scaler.Mean), len(scaler.Std))
	}

	// Metadata (optional: fall back to a placeholder instead of failing).
	var metadata *ModelMetadata
	data, err = os.ReadFile(filepath.Join(modelDir, "metadata.json"))
	if err != nil {
		log.Printf("元数据文件不存在,使用默认值: %v", err)
		metadata = &ModelMetadata{Version: "unknown"}
	} else {
		var meta ModelMetadata
		if err := json.Unmarshal(data, &meta); err != nil {
			return fmt.Errorf("解析元数据失败: %w", err)
		}
		metadata = &meta
	}

	// Publish all three pointers together, synchronized with Predict's RLock.
	s.mu.Lock()
	s.model = &model
	s.scaler = &scaler
	s.metadata = metadata
	s.mu.Unlock()

	s.loaded.Store(true)
	log.Printf("模型加载成功: weights=%d, intercept=%.4f", len(model.Weights), model.Intercept)
	return nil
}
// Predict runs inference: y = scaler(X) · W + b.
//
// It returns an error (instead of panicking) when the model is not yet
// loaded, when the feature count does not match the weights, or when the
// scaler arrays are shorter than the weight vector.
func (s *ModelServer) Predict(features []float64) (float64, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	// Guard against calls before LoadModel succeeds (original code would
	// dereference a nil pointer here).
	if s.model == nil || s.scaler == nil {
		return 0, fmt.Errorf("模型尚未加载")
	}
	if len(features) != len(s.model.Weights) {
		return 0, fmt.Errorf("特征数量不匹配: 期望 %d, 收到 %d", len(s.model.Weights), len(features))
	}
	// Defensive bound check so malformed scaler files cannot cause an
	// out-of-range panic inside the loop below.
	if len(s.scaler.Mean) < len(features) || len(s.scaler.Std) < len(features) {
		return 0, fmt.Errorf("scaler 参数与模型权重数量不匹配")
	}
	// Standardize each feature ((x - mean) / std) and accumulate the dot
	// product with the weights in one pass.
	prediction := s.model.Intercept
	for i, x := range features {
		std := s.scaler.Std[i]
		if std == 0 {
			// Zero std means the feature was constant during training;
			// scikit-learn's StandardScaler treats such scales as 1 to
			// avoid division by zero — mirror that instead of producing Inf/NaN.
			std = 1
		}
		prediction += (x - s.scaler.Mean[i]) / std * s.model.Weights[i]
	}
	return prediction, nil
}
// writeJSONError sends an error response whose message is properly
// JSON-encoded. The original handlers interpolated err into a JSON literal
// with %s, which produced invalid JSON whenever the message contained a
// double quote; json.Marshal escapes quotes and control characters.
func writeJSONError(w http.ResponseWriter, msg string, code int) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	encoded, _ := json.Marshal(msg) // marshaling a string cannot fail
	fmt.Fprintf(w, `{"error":%s}`, encoded)
}

// handlePredict serves POST /predict: decode the feature vector, run the
// model, record latency/prediction metrics, and return a PredictResponse.
func (s *ModelServer) handlePredict(w http.ResponseWriter, r *http.Request) {
	start := time.Now()
	s.requestsTotal.Add(1)

	if r.Method != http.MethodPost {
		writeJSONError(w, "method not allowed", http.StatusMethodNotAllowed)
		s.requestsErrors.Add(1)
		return
	}

	// Cap the body at 1 MiB so an oversized request cannot exhaust memory.
	r.Body = http.MaxBytesReader(w, r.Body, 1<<20)
	var req PredictRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeJSONError(w, fmt.Sprintf("invalid request: %s", err), http.StatusBadRequest)
		s.requestsErrors.Add(1)
		return
	}

	prediction, err := s.Predict(req.Features)
	if err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		s.requestsErrors.Add(1)
		return
	}

	// Record metrics; predictions are stored fixed-point (×10000) so they
	// fit in an integer atomic.
	latency := time.Since(start)
	s.latencySum.Add(latency.Microseconds())
	s.latencyCount.Add(1)
	s.predictionSum.Add(int64(prediction * 10000))
	s.predictionCount.Add(1)

	// Snapshot metadata fields under the read lock (LoadModel may swap the
	// pointer during a hot reload).
	s.mu.RLock()
	version, modelType := s.metadata.Version, s.metadata.Type
	s.mu.RUnlock()

	resp := PredictResponse{
		Prediction:   math.Round(prediction*10000) / 10000,
		PriceUSD:     fmt.Sprintf("$%.0f", prediction*100000),
		ModelVersion: version,
		ModelType:    modelType,
		LatencyMs:    float64(latency.Microseconds()) / 1000.0,
	}
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		// Client likely disconnected; nothing else to do but log it.
		log.Printf("编码响应失败: %v", err)
	}
}
// handleHealth serves GET /health (liveness probe). It reports only that
// the process is up — model readiness is /ready's job.
func (s *ModelServer) handleHealth(w http.ResponseWriter, r *http.Request) {
	const body = `{"status":"alive"}`
	w.WriteHeader(http.StatusOK)
	w.Write([]byte(body))
}
// handleReady serves GET /ready (readiness probe): 200 once the model has
// been loaded, 503 with a reason until then.
func (s *ModelServer) handleReady(w http.ResponseWriter, r *http.Request) {
	if s.loaded.Load() {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"status":"ready"}`))
		return
	}
	http.Error(w, `{"status":"not ready","reason":"model not loaded"}`, http.StatusServiceUnavailable)
}
// handleModelInfo serves GET /model/info: a JSON summary of the loaded
// model (version, type, features, training metrics) plus server uptime.
//
// The original dereferenced s.metadata unconditionally, panicking if this
// endpoint was hit before LoadModel succeeded; now it answers 503 instead.
func (s *ModelServer) handleModelInfo(w http.ResponseWriter, r *http.Request) {
	// Snapshot the pointer under the read lock (LoadModel may swap it).
	s.mu.RLock()
	meta := s.metadata
	s.mu.RUnlock()

	w.Header().Set("Content-Type", "application/json")
	if meta == nil {
		w.WriteHeader(http.StatusServiceUnavailable)
		json.NewEncoder(w).Encode(map[string]interface{}{
			"loaded": false,
			"uptime": time.Since(s.startTime).String(),
		})
		return
	}

	info := map[string]interface{}{
		"loaded":     s.loaded.Load(),
		"version":    meta.Version,
		"model_type": meta.Type,
		"features":   meta.Features,
		"metrics":    meta.Metrics,
		"uptime":     time.Since(s.startTime).String(),
	}
	json.NewEncoder(w).Encode(info)
}
// handleMetrics serves GET /metrics in Prometheus text exposition format:
// request counters split by status, average latency and prediction value,
// a model-loaded gauge, and uptime.
func (s *ModelServer) handleMetrics(w http.ResponseWriter, r *http.Request) {
	total, errCount := s.requestsTotal.Load(), s.requestsErrors.Load()

	var avgLatency float64
	if n := s.latencyCount.Load(); n > 0 {
		avgLatency = float64(s.latencySum.Load()) / float64(n) / 1000.0 // microseconds → ms
	}
	var avgPrediction float64
	if n := s.predictionCount.Load(); n > 0 {
		avgPrediction = float64(s.predictionSum.Load()) / float64(n) / 10000.0 // undo fixed-point scaling
	}
	var loadedFlag int
	if s.loaded.Load() {
		loadedFlag = 1
	}

	w.Header().Set("Content-Type", "text/plain")
	fmt.Fprintf(w,
		"# HELP inference_requests_total Total inference requests\n"+
			"# TYPE inference_requests_total counter\n"+
			"inference_requests_total{status=\"success\"} %d\n"+
			"inference_requests_total{status=\"error\"} %d\n"+
			"# HELP inference_request_duration_ms Average request duration in milliseconds\n"+
			"# TYPE inference_request_duration_ms gauge\n"+
			"inference_request_duration_ms %.2f\n"+
			"# HELP inference_prediction_avg Average prediction value\n"+
			"# TYPE inference_prediction_avg gauge\n"+
			"inference_prediction_avg %.4f\n"+
			"# HELP inference_model_loaded Whether model is loaded\n"+
			"# TYPE inference_model_loaded gauge\n"+
			"inference_model_loaded %d\n"+
			"# HELP inference_uptime_seconds Server uptime\n"+
			"# TYPE inference_uptime_seconds gauge\n"+
			"inference_uptime_seconds %.0f\n",
		total-errCount, errCount, avgLatency, avgPrediction, loadedFlag,
		time.Since(s.startTime).Seconds())
}
// main wires up the inference service: load the model from MODEL_DIR
// (default /models), register the HTTP handlers, and listen on PORT
// (default 8080). Startup fails hard if the model cannot be loaded, so a
// running process always has a usable model.
func main() {
	modelDir := os.Getenv("MODEL_DIR")
	if modelDir == "" {
		modelDir = "/models"
	}
	port := os.Getenv("PORT")
	if port == "" {
		port = "8080"
	}

	server := NewModelServer()
	log.Printf("正在从 %s 加载模型...", modelDir)
	if err := server.LoadModel(modelDir); err != nil {
		log.Fatalf("模型加载失败: %v", err)
	}

	// Use a dedicated mux rather than http.DefaultServeMux so imported
	// packages cannot silently register extra routes (e.g. /debug/pprof).
	mux := http.NewServeMux()
	mux.HandleFunc("/predict", server.handlePredict)
	mux.HandleFunc("/health", server.handleHealth)
	mux.HandleFunc("/ready", server.handleReady)
	mux.HandleFunc("/model/info", server.handleModelInfo)
	mux.HandleFunc("/metrics", server.handleMetrics)

	// Explicit timeouts keep slow or stalled clients from pinning
	// connections forever (http.ListenAndServe sets none).
	srv := &http.Server{
		Addr:         ":" + port,
		Handler:      mux,
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 10 * time.Second,
		IdleTimeout:  60 * time.Second,
	}

	log.Printf("推理服务启动: :%s (模型: %s %s)", port, server.metadata.Type, server.metadata.Version)
	if err := srv.ListenAndServe(); err != nil {
		log.Fatalf("服务启动失败: %v", err)
	}
}