package middleware

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"regexp"
	"strings"

	"devone.aplikasi.web.id/gitea/mario/go-ohif-proxy/internal/api/service"
	"devone.aplikasi.web.id/gitea/mario/go-ohif-proxy/internal/auth"
	"go.uber.org/zap"
)

// contextKey is an unexported key type for request-context values, so keys set
// by this package cannot collide with string keys set by other packages.
type contextKey string

// Context keys under which Auth stores details of the authenticated user.
const (
	UserIDKey    contextKey = "user_id"
	UserRoleKey  contextKey = "user_role"
	UserEmailKey contextKey = "user_email"
	ClaimsKey    contextKey = "auth_claims" // Use this same key everywhere
)

// WhitelistedEndpoints lists request targets (path plus "?" plus raw query)
// for which RoleRequired skips its role check.
//
// NOTE(review): Auth does not consult this list, so these targets are only
// reachable without a token on routes where Auth is not mounted in front of
// them — confirm against the router.
var WhitelistedEndpoints = []*regexp.Regexp{
	// Study by UID: a StudyInstanceUID query parameter with a non-empty value.
	// The (?:.*&)? prefix anchors the key to a parameter boundary so that keys
	// merely ending in "StudyInstanceUID" do not match.
	regexp.MustCompile(`^/dicomWeb/studies\?(?:.*&)?StudyInstanceUID=[^&]+`),
	// Frame endpoint
	regexp.MustCompile(`^/dicomWeb/studies/[^/]+/series/[^/]+/instances/[^/]+/frames/\d+$`),
}

// Auth returns middleware that authenticates requests with a JWT bearer token.
//
// The Authorization header must have the form "Bearer <token>". The scheme is
// matched case-insensitively and surrounding or repeated whitespace is
// tolerated. The token must pass authService.ValidateToken and carry
// TokenType "access". Any failure is answered with 401 and a JSON error body.
//
// On success, the user ID, role, email and the full claims value are stored in
// the request context under UserIDKey, UserRoleKey, UserEmailKey and ClaimsKey,
// and the request is passed to next.
func Auth(authService *service.AuthService, logger *zap.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			authHeader := r.Header.Get("Authorization")
			if authHeader == "" {
				logger.Warn("Missing Authorization header", zap.String("path", r.URL.Path))
				respondWithError(w, http.StatusUnauthorized, "missing authorization header")
				return
			}

			// Expect exactly two fields: the scheme and the token.
			fields := strings.Fields(authHeader)
			if len(fields) != 2 || !strings.EqualFold(fields[0], "bearer") {
				// Do not log the header value: it may contain a credential.
				logger.Warn("Invalid Authorization header format", zap.String("path", r.URL.Path))
				respondWithError(w, http.StatusUnauthorized, "invalid authorization format")
				return
			}
			token := fields[1]

			claims, err := authService.ValidateToken(token)
			if err != nil {
				logger.Warn("Invalid or expired token", zap.Error(err))
				respondWithError(w, http.StatusUnauthorized, "invalid or expired token")
				return
			}

			// Only access tokens may authorize requests (e.g. not refresh tokens).
			if claims.TokenType != "access" {
				logger.Warn("Invalid token type", zap.String("tokenType", claims.TokenType))
				respondWithError(w, http.StatusUnauthorized, "invalid token type")
				return
			}

			ctx := context.WithValue(r.Context(), UserIDKey, claims.UserID)
			ctx = context.WithValue(ctx, UserRoleKey, claims.Role)
			ctx = context.WithValue(ctx, UserEmailKey, claims.Email)
			ctx = context.WithValue(ctx, ClaimsKey, claims)

			logger.Debug("Auth middleware: Token validated",
				zap.String("userID", claims.UserID),
				zap.String("role", claims.Role))

			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// RoleRequired returns middleware that admits a request only if the role
// stored under UserRoleKey equals one of roles.
//
// Requests whose target (path plus "?" plus raw query, when a query is
// present) matches any pattern in WhitelistedEndpoints skip the check. A
// missing role yields 401, and a role outside roles yields 403, both with a
// JSON error body.
func RoleRequired(roles ...string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			target := r.URL.Path
			if r.URL.RawQuery != "" {
				target = target + "?" + r.URL.RawQuery
			}
			for _, pattern := range WhitelistedEndpoints {
				if pattern.MatchString(target) {
					next.ServeHTTP(w, r)
					return
				}
			}

			userRole, ok := r.Context().Value(UserRoleKey).(string)
			if !ok {
				respondWithError(w, http.StatusUnauthorized, "user context not found")
				return
			}

			for _, role := range roles {
				if userRole == role {
					next.ServeHTTP(w, r)
					return
				}
			}
			respondWithError(w, http.StatusForbidden, "insufficient permissions")
		})
	}
}

// PatientViewRestriction returns middleware that limits users with role
// "patient" to the studies listed in their claims' StudyIUIDs.
//
// It expects Auth to have stored *auth.CustomClaims under ClaimsKey. Missing
// or mistyped claims yield a plain-text 401. Non-patient roles pass through
// unchanged.
//
// For patients, every study UID named by the request is collected: the path
// segment after "studies", plus all values of the StudyInstanceUID / 0020000D
// query parameters. Each one must appear in StudyIUIDs, otherwise the request
// gets a 403 with a JSON body. A patient with no authorized studies is
// therefore denied any specific study.
//
// Requests naming no study, such as study lists, pass through.
// NOTE(review): presumably such lists are filtered downstream — confirm.
func PatientViewRestriction(logger *zap.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			claimsValue := r.Context().Value(ClaimsKey)
			if claimsValue == nil {
				logger.Error("Missing claims in context - PatientViewRestriction middleware",
					zap.String("path", r.URL.Path),
					zap.String("method", r.Method))
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}
			claims, ok := claimsValue.(*auth.CustomClaims)
			// A typed-nil pointer passes the assertion; reject it rather than
			// dereferencing it below.
			if !ok || claims == nil {
				logger.Error("Invalid claims type in context",
					zap.String("type", fmt.Sprintf("%T", claimsValue)))
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}

			logger.Debug("PatientViewRestriction: Got claims from context",
				zap.String("userID", claims.UserID),
				zap.String("role", claims.Role))

			if claims.Role != "patient" {
				next.ServeHTTP(w, r)
				return
			}

			requested := requestedStudyUIDs(r)
			if len(requested) == 0 {
				// No specific study requested (e.g. a study list).
				next.ServeHTTP(w, r)
				return
			}

			if denied, found := firstUnauthorizedStudy(requested, claims.StudyIUIDs); found {
				logger.Warn("Patient attempted to access unauthorized study",
					zap.String("userID", claims.UserID),
					zap.String("role", claims.Role),
					zap.String("requestedStudy", denied),
					zap.Strings("authorizedStudies", claims.StudyIUIDs))

				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusForbidden)
				// Headers are already sent; an encode failure cannot be reported.
				_ = json.NewEncoder(w).Encode(map[string]string{
					"error": "Access denied: You do not have permission to view this study",
					"code":  "forbidden_study_access",
				})
				return
			}

			logger.Debug("Patient authorized to access study",
				zap.String("userID", claims.UserID),
				zap.Strings("requestedStudies", requested))
			next.ServeHTTP(w, r)
		})
	}
}

// requestedStudyUIDs returns every non-empty study UID named by r.
//
// It takes the path segment that follows the first "studies" segment, then
// every value of query parameters whose name equals (case-insensitively)
// "StudyInstanceUID" or its DICOM tag "0020000D". Repeated parameters are all
// included, so that a second value cannot slip past the check.
func requestedStudyUIDs(r *http.Request) []string {
	var uids []string

	parts := strings.Split(r.URL.Path, "/")
	for i, part := range parts {
		if part == "studies" && i+1 < len(parts) {
			if parts[i+1] != "" {
				uids = append(uids, parts[i+1])
			}
			break
		}
	}

	for key, values := range r.URL.Query() {
		if !strings.EqualFold(key, "StudyInstanceUID") && !strings.EqualFold(key, "0020000D") {
			continue
		}
		for _, v := range values {
			if v != "" {
				uids = append(uids, v)
			}
		}
	}
	return uids
}

// firstUnauthorizedStudy reports the first UID in requested that is absent
// from allowed. With an empty allowed list, any non-empty requested list is
// unauthorized.
func firstUnauthorizedStudy(requested, allowed []string) (string, bool) {
	permitted := make(map[string]struct{}, len(allowed))
	for _, uid := range allowed {
		permitted[uid] = struct{}{}
	}
	for _, uid := range requested {
		if _, ok := permitted[uid]; !ok {
			return uid, true
		}
	}
	return "", false
}

// respondWithError writes statusCode with a JSON body {"error": message}.
func respondWithError(w http.ResponseWriter, statusCode int, message string) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)
	// Headers are already sent; an encode failure cannot be reported.
	_ = json.NewEncoder(w).Encode(map[string]string{"error": message})
}