K8s Lab 把当前仓库文档整理成一个可阅读的网页站点

Repository Reading Site

main.go

ml-platform/inference/main.go

Text Asset · ml-platform/inference/main.go · 8.6 KB · 2026年4月9日 13:58 · 查看原始内容
package main

import (
	"encoding/json"
	"fmt"
	"log"
	"math"
	"net/http"
	"os"
	"sync"
	"sync/atomic"
	"time"
)

// ModelMetadata describes a trained model. It is decoded from the
// metadata.json file produced by the training stage.
//
// The JSON tags must match the keys written by the training stage.
// Version and Type are echoed in every /predict response; /model/info
// additionally exposes Features and Metrics. Metrics holds the evaluation
// scores recorded at training time.
type ModelMetadata struct {
	Name     string   `json:"name"`
	Version  string   `json:"version"`
	Type     string   `json:"model_type"`
	Features []string `json:"features"`
	Metrics  struct {
		MAE  float64 `json:"mae"`
		RMSE float64 `json:"rmse"`
		R2   float64 `json:"r2_score"`
	} `json:"metrics"`
	Timestamp string `json:"timestamp"`
}

// ScalerParams carries the fitted StandardScaler parameters exported by the
// training stage. Predict standardizes feature i as (x - Mean[i]) / Std[i],
// so both slices are indexed in the same order as the model's features.
type ScalerParams struct {
	Mean []float64 `json:"mean"` // per-feature value subtracted first
	Std  []float64 `json:"std"`  // per-feature divisor applied after centering
}

// LinearModel holds fitted linear-regression parameters. Predict evaluates
// Intercept + Σ scaled[i]·Weights[i], where scaled is the standardized input.
type LinearModel struct {
	Weights   []float64 `json:"weights"`   // one coefficient per standardized feature
	Intercept float64   `json:"intercept"` // bias term added to the weighted sum
}

// ModelServer is the inference server. It holds the loaded linear model,
// its scaler and metadata, plus in-process counters that handleMetrics
// renders in Prometheus text format.
type ModelServer struct {
	// Model state, populated by LoadModel. Predict holds mu's read lock
	// while it uses model and scaler.
	mu       sync.RWMutex
	model    *LinearModel
	scaler   *ScalerParams
	metadata *ModelMetadata
	loaded   atomic.Bool // set to true at the end of a successful LoadModel

	// Prometheus-style request counters.
	requestsTotal   atomic.Int64
	requestsErrors  atomic.Int64
	latencySum      atomic.Int64 // accumulated latency, in microseconds
	latencyCount    atomic.Int64
	predictionSum   atomic.Int64 // accumulated predictions multiplied by 10000, as integers
	predictionCount atomic.Int64

	startTime time.Time // reference point for uptime reporting
}

// PredictRequest is the JSON body accepted by POST /predict. Predict pairs
// Features[i] with the model's i-th weight after standardizing it, so the
// values are raw (unscaled) and must follow the model's feature order.
type PredictRequest struct {
	Features []float64 `json:"features"` // raw feature values, in model order
}

// PredictResponse is the JSON body returned by a successful POST /predict.
// The field order below is also the key order of the encoded JSON.
type PredictResponse struct {
	Prediction   float64 `json:"prediction"`    // model output (handlePredict rounds it to 4 decimals)
	PriceUSD     string  `json:"price_usd"`     // prediction × 100000, formatted by handlePredict as "$<int>"
	ModelVersion string  `json:"model_version"` // copied from the loaded metadata
	ModelType    string  `json:"model_type"`    // copied from the loaded metadata
	LatencyMs    float64 `json:"latency_ms"`    // handler latency in milliseconds
}

// NewModelServer returns a ModelServer with no model loaded and records the
// current time as its start time, which the info and metrics endpoints use
// to report uptime. Call LoadModel before serving predictions.
func NewModelServer() *ModelServer {
	srv := new(ModelServer)
	srv.startTime = time.Now()
	return srv
}

// LoadModel loads the model parameters, scaler parameters and (optional)
// metadata from JSON files in modelDir and installs them on the server.
//
// Why JSON rather than ONNX? For linear regression the model is just
// y = X·W + b, so storing W and b as JSON is the simplest option: Go parses
// it natively, with no CGO or external C libraries. Complex production
// models would use ONNX Runtime, but this lab focuses on the fundamentals.
//
// Files read from modelDir:
//   - model_params.json  (required) decoded into LinearModel
//   - scaler_params.json (required) decoded into ScalerParams
//   - metadata.json      (optional) decoded into ModelMetadata; if it cannot
//     be read, a placeholder with Version "unknown" is used instead, but a
//     file that exists and fails to parse is an error.
//
// Before anything is installed, the parameters are validated:
//   - weights must be non-empty;
//   - mean and std must have exactly one entry per weight;
//   - no std may be zero.
//
// This guarantees Predict can neither index out of range nor divide by zero.
// On any error the server's previous state is left untouched. On success
// the new state is swapped in under the write lock, and then loaded is set.
func (s *ModelServer) LoadModel(modelDir string) error {
	// Model parameters (required).
	data, err := os.ReadFile(modelDir + "/model_params.json")
	if err != nil {
		return fmt.Errorf("读取模型文件失败: %w", err)
	}
	var model LinearModel
	if err := json.Unmarshal(data, &model); err != nil {
		return fmt.Errorf("解析模型参数失败: %w", err)
	}

	// Scaler parameters (required).
	data, err = os.ReadFile(modelDir + "/scaler_params.json")
	if err != nil {
		return fmt.Errorf("读取 scaler 文件失败: %w", err)
	}
	var scaler ScalerParams
	if err := json.Unmarshal(data, &scaler); err != nil {
		return fmt.Errorf("解析 scaler 参数失败: %w", err)
	}

	// Validate shapes. json.Unmarshal silently ignores unknown keys, so a
	// key mismatch with the training export shows up here as empty weights.
	n := len(model.Weights)
	if n == 0 {
		return fmt.Errorf("模型参数无效: weights 为空")
	}
	if len(scaler.Mean) != n || len(scaler.Std) != n {
		return fmt.Errorf("scaler 参数维度不匹配: weights=%d, mean=%d, std=%d", n, len(scaler.Mean), len(scaler.Std))
	}
	for i, sd := range scaler.Std {
		// A zero divisor would turn every prediction into ±Inf/NaN, which
		// cannot be JSON-encoded; fail at startup instead.
		if sd == 0 {
			return fmt.Errorf("scaler 参数无效: std[%d] 为 0", i)
		}
	}

	// Metadata (optional): any read error falls back to a placeholder.
	var meta ModelMetadata
	data, err = os.ReadFile(modelDir + "/metadata.json")
	if err != nil {
		log.Printf("元数据文件不存在,使用默认值: %v", err)
		meta = ModelMetadata{Version: "unknown"}
	} else if err := json.Unmarshal(data, &meta); err != nil {
		return fmt.Errorf("解析元数据失败: %w", err)
	}

	// Install everything together under the write lock so readers holding
	// the read lock never observe a half-updated model.
	s.mu.Lock()
	s.model = &model
	s.scaler = &scaler
	s.metadata = &meta
	s.mu.Unlock()
	s.loaded.Store(true)

	log.Printf("模型加载成功: weights=%d, intercept=%.4f", n, model.Intercept)
	return nil
}

// Predict computes y = scale(x) · W + b, where scale standardizes each
// feature as (x - mean) / std using the loaded scaler parameters.
//
// It returns an error in any of these cases:
//   - no model is loaded;
//   - len(features) differs from the number of weights;
//   - the scaler has fewer mean/std entries than there are weights;
//   - a feature is NaN or ±Inf.
//
// It holds the read lock on s.mu for the whole computation.
func (s *ModelServer) Predict(features []float64) (float64, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	// Called before a successful LoadModel: avoid a nil dereference.
	if s.model == nil || s.scaler == nil {
		return 0, fmt.Errorf("模型未加载")
	}

	weights := s.model.Weights
	if len(features) != len(weights) {
		return 0, fmt.Errorf("特征数量不匹配: 期望 %d, 收到 %d", len(weights), len(features))
	}
	// The scaler is loaded from a separate file; make sure it covers every
	// feature so the indexing below cannot panic.
	if len(s.scaler.Mean) < len(weights) || len(s.scaler.Std) < len(weights) {
		return 0, fmt.Errorf("scaler 参数维度不匹配: weights=%d, mean=%d, std=%d",
			len(weights), len(s.scaler.Mean), len(s.scaler.Std))
	}

	// y = b + Σ ((x_i - mean_i) / std_i) * w_i, computed in a single pass.
	prediction := s.model.Intercept
	for i, x := range features {
		if math.IsNaN(x) || math.IsInf(x, 0) {
			return 0, fmt.Errorf("特征 %d 不是有限数值: %v", i, x)
		}
		scaled := (x - s.scaler.Mean[i]) / s.scaler.Std[i]
		prediction += scaled * weights[i]
	}
	return prediction, nil
}

// handlePredict serves POST /predict.
//
// The request body is {"features": [f1, f2, ...]}. On success the handler
// responds 200 with a JSON PredictResponse. Errors are returned as
// {"error": "..."} with Content-Type application/json:
//   - 405 for non-POST requests (with an "Allow: POST" header);
//   - 400 for a malformed or oversized body, or features rejected by Predict;
//   - 500 if the model yields NaN/±Inf, which encoding/json cannot encode.
//
// Every request increments requestsTotal; failures also increment
// requestsErrors. Latency and prediction statistics are recorded only for
// successful requests.
func (s *ModelServer) handlePredict(w http.ResponseWriter, r *http.Request) {
	// Upper bound on the request body; a feature vector for a linear model
	// is tiny, so 1 MiB is generous.
	const maxBodyBytes = 1 << 20

	start := time.Now()
	s.requestsTotal.Add(1)

	// fail counts the error and writes a JSON error body. The message is
	// JSON-encoded rather than spliced into a format string, so quotes or
	// backslashes in decoder errors cannot produce invalid JSON.
	fail := func(status int, msg string) {
		s.requestsErrors.Add(1)
		w.Header().Set("Content-Type", "application/json")
		w.Header().Set("X-Content-Type-Options", "nosniff")
		w.WriteHeader(status)
		if err := json.NewEncoder(w).Encode(map[string]string{"error": msg}); err != nil {
			log.Printf("写入错误响应失败: %v", err)
		}
	}

	if r.Method != http.MethodPost {
		w.Header().Set("Allow", http.MethodPost)
		fail(http.StatusMethodNotAllowed, "method not allowed")
		return
	}

	// Stop a client from making the server buffer an arbitrarily large body.
	r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes)

	var req PredictRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		fail(http.StatusBadRequest, "invalid request: "+err.Error())
		return
	}

	prediction, err := s.Predict(req.Features)
	if err != nil {
		fail(http.StatusBadRequest, err.Error())
		return
	}
	// encoding/json rejects NaN/±Inf: without this check the client would
	// get an empty 200 response and the fixed-point metrics would be garbage.
	if math.IsNaN(prediction) || math.IsInf(prediction, 0) {
		fail(http.StatusInternalServerError, "prediction is not a finite number")
		return
	}

	latency := time.Since(start)
	s.latencySum.Add(latency.Microseconds())
	s.latencyCount.Add(1)
	// predictionSum stores predictions as fixed-point integers (×10000).
	// Converting an out-of-range float to int64 is implementation-defined in
	// Go, so values that do not fit comfortably are left out of the average.
	if scaled := math.Round(prediction * 10000); math.Abs(scaled) < 1<<62 {
		s.predictionSum.Add(int64(scaled))
		s.predictionCount.Add(1)
	}

	// Snapshot metadata under the read lock; tolerate it being unset.
	var version, modelType string
	s.mu.RLock()
	if s.metadata != nil {
		version, modelType = s.metadata.Version, s.metadata.Type
	}
	s.mu.RUnlock()

	resp := PredictResponse{
		Prediction:   math.Round(prediction*10000) / 10000,
		PriceUSD:     fmt.Sprintf("$%.0f", prediction*100000),
		ModelVersion: version,
		ModelType:    modelType,
		LatencyMs:    float64(latency.Microseconds()) / 1000.0,
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		log.Printf("写入预测响应失败: %v", err)
	}
}

// handleHealth serves GET /health, the liveness probe. It always responds
// 200 {"status":"alive"} and does not consult model state (readiness is
// reported separately by /ready).
func (s *ModelServer) handleHealth(w http.ResponseWriter, r *http.Request) {
	// Declare the body as JSON explicitly; otherwise net/http's content
	// sniffing labels it text/plain.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	fmt.Fprint(w, `{"status":"alive"}`)
}

// handleReady serves GET /ready, the readiness probe. It responds 503 with
// a JSON reason until LoadModel has succeeded (the loaded flag is set), and
// 200 {"status":"ready"} afterwards.
func (s *ModelServer) handleReady(w http.ResponseWriter, r *http.Request) {
	// http.Error would force Content-Type text/plain onto this JSON body,
	// so headers are set by hand on both paths.
	w.Header().Set("Content-Type", "application/json")
	if !s.loaded.Load() {
		w.Header().Set("X-Content-Type-Options", "nosniff")
		w.WriteHeader(http.StatusServiceUnavailable)
		// Fprintln keeps the trailing newline that http.Error used to emit.
		fmt.Fprintln(w, `{"status":"not ready","reason":"model not loaded"}`)
		return
	}
	w.WriteHeader(http.StatusOK)
	fmt.Fprint(w, `{"status":"ready"}`)
}

// handleModelInfo serves GET /model/info. It returns a JSON object with the
// loaded flag, the metadata's version, type, features and metrics, and the
// server uptime.
func (s *ModelServer) handleModelInfo(w http.ResponseWriter, r *http.Request) {
	// Snapshot the metadata pointer under the read lock, and fall back to an
	// empty value so this handler cannot panic before a model is loaded.
	s.mu.RLock()
	meta := s.metadata
	s.mu.RUnlock()
	if meta == nil {
		meta = &ModelMetadata{}
	}

	info := map[string]any{
		"loaded":     s.loaded.Load(),
		"version":    meta.Version,
		"model_type": meta.Type,
		"features":   meta.Features,
		"metrics":    meta.Metrics,
		"uptime":     time.Since(s.startTime).String(),
	}
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(info); err != nil {
		log.Printf("写入模型信息失败: %v", err)
	}
}

// handleMetrics serves GET /metrics in the Prometheus text exposition format.
//
// Exposed series:
//   - inference_requests_total{status="success"|"error"} (counter)
//   - inference_request_duration_ms: mean latency of successful requests
//     since start, in milliseconds (gauge)
//   - inference_prediction_avg: mean prediction since start (gauge)
//   - inference_model_loaded: 1 once LoadModel has succeeded, else 0 (gauge)
//   - inference_uptime_seconds: seconds since NewModelServer (gauge)
func (s *ModelServer) handleMetrics(w http.ResponseWriter, r *http.Request) {
	// handlePredict increments requestsTotal when a request starts, but
	// increments latencyCount only once Predict has succeeded. latencyCount
	// is therefore a monotonic success counter.
	//
	// Deriving success as requestsTotal-requestsErrors would count in-flight
	// requests as successes, and the value would drop when one of them later
	// failed. Prometheus treats any decrease as a counter reset.
	success := s.latencyCount.Load()
	errCount := s.requestsErrors.Load()
	latSum := s.latencySum.Load()
	predSum := s.predictionSum.Load()
	predCount := s.predictionCount.Load()

	avgLatencyMs := 0.0
	if success > 0 {
		avgLatencyMs = float64(latSum) / float64(success) / 1000.0 // µs -> ms
	}
	avgPrediction := 0.0
	if predCount > 0 {
		avgPrediction = float64(predSum) / float64(predCount) / 10000.0 // undo the ×10000 fixed point
	}

	loaded := 0
	if s.loaded.Load() {
		loaded = 1
	}

	const exposition = `# HELP inference_requests_total Total inference requests
# TYPE inference_requests_total counter
inference_requests_total{status="success"} %d
inference_requests_total{status="error"} %d
# HELP inference_request_duration_ms Average request duration in milliseconds
# TYPE inference_request_duration_ms gauge
inference_request_duration_ms %.2f
# HELP inference_prediction_avg Average prediction value
# TYPE inference_prediction_avg gauge
inference_prediction_avg %.4f
# HELP inference_model_loaded Whether model is loaded
# TYPE inference_model_loaded gauge
inference_model_loaded %d
# HELP inference_uptime_seconds Server uptime
# TYPE inference_uptime_seconds gauge
inference_uptime_seconds %.0f
`

	// Content type of the Prometheus text format, version 0.0.4.
	w.Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8")
	if _, err := fmt.Fprintf(w, exposition,
		success, errCount, avgLatencyMs, avgPrediction, loaded,
		time.Since(s.startTime).Seconds()); err != nil {
		log.Printf("写入 metrics 失败: %v", err)
	}
}

// main starts the inference service. It proceeds in four steps:
//  1. Read MODEL_DIR (default "/models") and PORT (default "8080") from the
//     environment.
//  2. Load the model, exiting on failure.
//  3. Register the HTTP routes.
//  4. Serve until the listener fails.
func main() {
	modelDir := os.Getenv("MODEL_DIR")
	if modelDir == "" {
		modelDir = "/models"
	}
	port := os.Getenv("PORT")
	if port == "" {
		port = "8080"
	}

	server := NewModelServer()

	// Load the model before accepting traffic.
	log.Printf("正在从 %s 加载模型...", modelDir)
	if err := server.LoadModel(modelDir); err != nil {
		log.Fatalf("模型加载失败: %v", err)
	}

	// A dedicated mux avoids exposing anything else that may register
	// itself on http.DefaultServeMux.
	mux := http.NewServeMux()
	mux.HandleFunc("/predict", server.handlePredict)
	mux.HandleFunc("/health", server.handleHealth)
	mux.HandleFunc("/ready", server.handleReady)
	mux.HandleFunc("/model/info", server.handleModelInfo)
	mux.HandleFunc("/metrics", server.handleMetrics)

	// http.ListenAndServe sets no timeouts, so a slow or idle client could
	// hold a connection (and its goroutine) open indefinitely. Bound each
	// phase of the connection explicitly.
	srv := &http.Server{
		Addr:              ":" + port,
		Handler:           mux,
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       10 * time.Second,
		WriteTimeout:      10 * time.Second,
		IdleTimeout:       60 * time.Second,
	}

	log.Printf("推理服务启动: :%s (模型: %s %s)", port, server.metadata.Type, server.metadata.Version)
	if err := srv.ListenAndServe(); err != nil {
		log.Fatalf("服务启动失败: %v", err)
	}
}