fix(auth): harden pending oauth and backend mode flows
This commit is contained in:
@ -678,6 +678,8 @@ func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
// 不影响登出流程
|
||||
}
|
||||
}
|
||||
h.consumePendingOAuthSessionOnLogout(c)
|
||||
clearOAuthLogoutCookies(c)
|
||||
|
||||
response.Success(c, LogoutResponse{
|
||||
Message: "Logged out successfully",
|
||||
|
||||
@ -469,6 +469,15 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
} else if handled {
|
||||
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
|
||||
return
|
||||
} else {
|
||||
session = updatedSession
|
||||
}
|
||||
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@ -757,6 +757,61 @@ func TestCompleteLinuxDoOAuthRegistrationRejectsAdoptExistingUserSession(t *test
|
||||
require.Nil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
|
||||
func TestCompleteLinuxDoOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
session, err := client.PendingAuthSession.Create().
|
||||
SetSessionToken("linuxdo-complete-choice-session").
|
||||
SetIntent("login").
|
||||
SetProviderType("linuxdo").
|
||||
SetProviderKey("linuxdo").
|
||||
SetProviderSubject("linuxdo-choice-subject-1").
|
||||
SetResolvedEmail("linuxdo-choice-subject-1@linuxdo-connect.invalid").
|
||||
SetBrowserSessionKey("linuxdo-choice-browser").
|
||||
SetUpstreamIdentityClaims(map[string]any{
|
||||
"username": "linuxdo_user",
|
||||
}).
|
||||
SetLocalFlowState(map[string]any{
|
||||
oauthCompletionResponseKey: map[string]any{
|
||||
"step": oauthPendingChoiceStep,
|
||||
"redirect": "/dashboard",
|
||||
"email": "fresh@example.com",
|
||||
"resolved_email": "fresh@example.com",
|
||||
"force_email_on_signup": true,
|
||||
},
|
||||
}).
|
||||
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-choice-browser")})
|
||||
c.Request = req
|
||||
|
||||
handler.CompleteLinuxDoOAuthRegistration(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
responseData := decodeJSONBody(t, recorder)
|
||||
require.Equal(t, "pending_session", responseData["auth_result"])
|
||||
require.Equal(t, oauthPendingChoiceStep, responseData["step"])
|
||||
require.Equal(t, true, responseData["force_email_on_signup"])
|
||||
require.Empty(t, responseData["access_token"])
|
||||
|
||||
userCount, err := client.User.Query().Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, userCount)
|
||||
|
||||
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
|
||||
func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler {
|
||||
t.Helper()
|
||||
handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg)
|
||||
|
||||
68
backend/internal/handler/auth_oauth_logout_test.go
Normal file
68
backend/internal/handler/auth_oauth_logout_test.go
Normal file
@ -0,0 +1,68 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLogoutClearsOAuthStateCookiesAndConsumesPendingSession(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
session, err := client.PendingAuthSession.Create().
|
||||
SetSessionToken("logout-pending-session-token").
|
||||
SetIntent("login").
|
||||
SetProviderType("oidc").
|
||||
SetProviderKey("https://issuer.example").
|
||||
SetProviderSubject("logout-subject-123").
|
||||
SetBrowserSessionKey("logout-browser-session-key").
|
||||
SetResolvedEmail("logout@example.com").
|
||||
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil)
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("logout-browser-session-key")})
|
||||
req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-access-token"})
|
||||
req.AddCookie(&http.Cookie{Name: linuxDoOAuthStateCookieName, Value: encodeCookieValue("linuxdo-state")})
|
||||
req.AddCookie(&http.Cookie{Name: oidcOAuthStateCookieName, Value: encodeCookieValue("oidc-state")})
|
||||
req.AddCookie(&http.Cookie{Name: wechatOAuthStateCookieName, Value: encodeCookieValue("wechat-state")})
|
||||
req.AddCookie(&http.Cookie{Name: wechatPaymentOAuthStateName, Value: encodeCookieValue("wechat-payment-state")})
|
||||
ginCtx.Request = req
|
||||
|
||||
handler.Logout(ginCtx)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
cookies := recorder.Result().Cookies()
|
||||
for _, name := range []string{
|
||||
oauthPendingSessionCookieName,
|
||||
oauthPendingBrowserCookieName,
|
||||
oauthBindAccessTokenCookieName,
|
||||
linuxDoOAuthStateCookieName,
|
||||
oidcOAuthStateCookieName,
|
||||
wechatOAuthStateCookieName,
|
||||
wechatPaymentOAuthStateName,
|
||||
} {
|
||||
cookie := findCookie(cookies, name)
|
||||
require.NotNil(t, cookie, name)
|
||||
require.Equal(t, -1, cookie.MaxAge, name)
|
||||
require.True(t, cookie.HttpOnly, name)
|
||||
}
|
||||
|
||||
storedSession, err := client.PendingAuthSession.Query().
|
||||
Where(pendingauthsession.IDEQ(session.ID)).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
@ -310,6 +310,78 @@ func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSes
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildLegacyCompleteRegistrationPendingResponse(
|
||||
session *dbent.PendingAuthSession,
|
||||
forceEmailOnSignup bool,
|
||||
emailVerificationRequired bool,
|
||||
) map[string]any {
|
||||
completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, map[string]any{
|
||||
"step": oauthPendingChoiceStep,
|
||||
"adoption_required": true,
|
||||
"create_account_allowed": true,
|
||||
"force_email_on_signup": forceEmailOnSignup,
|
||||
}))
|
||||
|
||||
if email := strings.TrimSpace(session.ResolvedEmail); email != "" {
|
||||
if _, exists := completionResponse["email"]; !exists {
|
||||
completionResponse["email"] = email
|
||||
}
|
||||
if _, exists := completionResponse["resolved_email"]; !exists {
|
||||
completionResponse["resolved_email"] = email
|
||||
}
|
||||
}
|
||||
if _, exists := completionResponse["choice_reason"]; !exists {
|
||||
switch {
|
||||
case forceEmailOnSignup:
|
||||
completionResponse["choice_reason"] = "force_email_on_signup"
|
||||
case emailVerificationRequired:
|
||||
completionResponse["choice_reason"] = "email_verification_required"
|
||||
default:
|
||||
completionResponse["choice_reason"] = "third_party_signup"
|
||||
}
|
||||
}
|
||||
return completionResponse
|
||||
}
|
||||
|
||||
func (h *AuthHandler) legacyCompleteRegistrationSessionStatus(
|
||||
c *gin.Context,
|
||||
session *dbent.PendingAuthSession,
|
||||
) (*dbent.PendingAuthSession, bool, error) {
|
||||
if session == nil {
|
||||
return nil, false, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
|
||||
}
|
||||
|
||||
payload := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil))
|
||||
if step := pendingSessionStringValue(payload, "step"); step != "" {
|
||||
return session, true, nil
|
||||
}
|
||||
|
||||
emailVerificationRequired := h != nil && h.authService != nil && h.authService.IsEmailVerifyEnabled(c.Request.Context())
|
||||
forceEmailOnSignup := h.isForceEmailOnThirdPartySignup(c.Request.Context())
|
||||
if !emailVerificationRequired && !forceEmailOnSignup {
|
||||
return session, false, nil
|
||||
}
|
||||
|
||||
client := h.entClient()
|
||||
if client == nil {
|
||||
return nil, false, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
|
||||
}
|
||||
|
||||
updatedSession, err := updatePendingOAuthSessionProgress(
|
||||
c.Request.Context(),
|
||||
client,
|
||||
session,
|
||||
strings.TrimSpace(session.Intent),
|
||||
strings.TrimSpace(session.ResolvedEmail),
|
||||
nil,
|
||||
buildLegacyCompleteRegistrationPendingResponse(session, forceEmailOnSignup, emailVerificationRequired),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, false, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err)
|
||||
}
|
||||
return updatedSession, true, nil
|
||||
}
|
||||
|
||||
func (r oauthAdoptionDecisionRequest) hasDecision() bool {
|
||||
return r.AdoptDisplayName != nil || r.AdoptAvatar != nil
|
||||
}
|
||||
@ -1272,6 +1344,59 @@ func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.Au
|
||||
return svc, session, clearCookies, nil
|
||||
}
|
||||
|
||||
func (h *AuthHandler) consumePendingOAuthSessionOnLogout(c *gin.Context) {
|
||||
if c == nil || c.Request == nil {
|
||||
return
|
||||
}
|
||||
|
||||
sessionToken, err := readOAuthPendingSessionCookie(c)
|
||||
if err != nil || strings.TrimSpace(sessionToken) == "" {
|
||||
return
|
||||
}
|
||||
browserSessionKey, err := readOAuthPendingBrowserCookie(c)
|
||||
if err != nil || strings.TrimSpace(browserSessionKey) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
svc, err := h.pendingIdentityService()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
|
||||
}
|
||||
|
||||
func clearOAuthLogoutCookies(c *gin.Context) {
|
||||
secureCookie := isRequestHTTPS(c)
|
||||
|
||||
clearOAuthPendingSessionCookie(c, secureCookie)
|
||||
clearOAuthPendingBrowserCookie(c, secureCookie)
|
||||
clearOAuthBindAccessTokenCookie(c, secureCookie)
|
||||
|
||||
clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
|
||||
clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
|
||||
clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)
|
||||
clearCookie(c, linuxDoOAuthIntentCookieName, secureCookie)
|
||||
clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie)
|
||||
|
||||
oidcClearCookie(c, oidcOAuthStateCookieName, secureCookie)
|
||||
oidcClearCookie(c, oidcOAuthVerifierCookie, secureCookie)
|
||||
oidcClearCookie(c, oidcOAuthRedirectCookie, secureCookie)
|
||||
oidcClearCookie(c, oidcOAuthNonceCookie, secureCookie)
|
||||
oidcClearCookie(c, oidcOAuthIntentCookieName, secureCookie)
|
||||
oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie)
|
||||
|
||||
wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie)
|
||||
wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie)
|
||||
wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie)
|
||||
wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie)
|
||||
wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie)
|
||||
|
||||
wechatPaymentClearCookie(c, wechatPaymentOAuthStateName, secureCookie)
|
||||
wechatPaymentClearCookie(c, wechatPaymentOAuthRedirect, secureCookie)
|
||||
wechatPaymentClearCookie(c, wechatPaymentOAuthContextName, secureCookie)
|
||||
wechatPaymentClearCookie(c, wechatPaymentOAuthScope, secureCookie)
|
||||
}
|
||||
|
||||
func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H {
|
||||
completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil))
|
||||
payload := gin.H{
|
||||
@ -1451,6 +1576,10 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
|
||||
response.BadRequest(c, "Pending oauth session provider mismatch")
|
||||
return
|
||||
|
||||
@ -1228,6 +1228,26 @@ func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T)
|
||||
require.Nil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
|
||||
func TestLogoutClearsPendingOAuthAndBindCookies(t *testing.T) {
|
||||
handler, _ := newOAuthPendingFlowTestHandler(t, false)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", bytes.NewBufferString(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue("pending-session-token")})
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("pending-browser-key")})
|
||||
req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-token"})
|
||||
ginCtx.Request = req
|
||||
|
||||
handler.Logout(ginCtx)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName).MaxAge)
|
||||
require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthPendingBrowserCookieName).MaxAge)
|
||||
require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName).MaxAge)
|
||||
}
|
||||
|
||||
func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, true, "fresh@example.com", "246810")
|
||||
ctx := context.Background()
|
||||
|
||||
@ -374,19 +374,19 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
|
||||
ProviderSubject: subject,
|
||||
}
|
||||
upstreamClaims := map[string]any{
|
||||
"email": email,
|
||||
"username": username,
|
||||
"subject": subject,
|
||||
"issuer": issuer,
|
||||
"email_verified": emailVerified != nil && *emailVerified,
|
||||
"provider_fallback": strings.TrimSpace(cfg.ProviderName),
|
||||
"email": email,
|
||||
"username": username,
|
||||
"subject": subject,
|
||||
"issuer": issuer,
|
||||
"email_verified": emailVerified != nil && *emailVerified,
|
||||
"provider_fallback": strings.TrimSpace(cfg.ProviderName),
|
||||
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, func() string {
|
||||
if idClaims != nil {
|
||||
return idClaims.Name
|
||||
}
|
||||
return ""
|
||||
}(), username),
|
||||
"suggested_avatar_url": userInfoClaims.AvatarURL,
|
||||
"suggested_avatar_url": userInfoClaims.AvatarURL,
|
||||
}
|
||||
if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) {
|
||||
upstreamClaims["compat_email"] = compatEmail
|
||||
@ -622,6 +622,15 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
} else if handled {
|
||||
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
|
||||
return
|
||||
} else {
|
||||
session = updatedSession
|
||||
}
|
||||
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@ -692,6 +692,62 @@ func TestCompleteOIDCOAuthRegistrationRejectsAdoptExistingUserSession(t *testing
|
||||
require.Nil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
|
||||
func TestCompleteOIDCOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
session, err := client.PendingAuthSession.Create().
|
||||
SetSessionToken("oidc-complete-choice-session").
|
||||
SetIntent("login").
|
||||
SetProviderType("oidc").
|
||||
SetProviderKey("https://issuer.example.com").
|
||||
SetProviderSubject("oidc-choice-subject-1").
|
||||
SetResolvedEmail("oidc-choice-subject-1@oidc-connect.invalid").
|
||||
SetBrowserSessionKey("oidc-choice-browser").
|
||||
SetUpstreamIdentityClaims(map[string]any{
|
||||
"username": "oidc_user",
|
||||
"issuer": "https://issuer.example.com",
|
||||
}).
|
||||
SetLocalFlowState(map[string]any{
|
||||
oauthCompletionResponseKey: map[string]any{
|
||||
"step": oauthPendingChoiceStep,
|
||||
"redirect": "/dashboard",
|
||||
"email": "fresh@example.com",
|
||||
"resolved_email": "fresh@example.com",
|
||||
"force_email_on_signup": true,
|
||||
},
|
||||
}).
|
||||
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-choice-browser")})
|
||||
c.Request = req
|
||||
|
||||
handler.CompleteOIDCOAuthRegistration(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
responseData := decodeJSONBody(t, recorder)
|
||||
require.Equal(t, "pending_session", responseData["auth_result"])
|
||||
require.Equal(t, oauthPendingChoiceStep, responseData["step"])
|
||||
require.Equal(t, true, responseData["force_email_on_signup"])
|
||||
require.Empty(t, responseData["access_token"])
|
||||
|
||||
userCount, err := client.User.Query().Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, userCount)
|
||||
|
||||
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
|
||||
type oidcProviderFixture struct {
|
||||
Subject string
|
||||
PreferredUsername string
|
||||
|
||||
@ -525,6 +525,15 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
} else if handled {
|
||||
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
|
||||
return
|
||||
} else {
|
||||
session = updatedSession
|
||||
}
|
||||
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@ -19,7 +19,6 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
@ -700,7 +699,7 @@ func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *tes
|
||||
require.Zero(t, count)
|
||||
}
|
||||
|
||||
func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) {
|
||||
func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSessionReturnsPendingSession(t *testing.T) {
|
||||
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
||||
originalUserInfoURL := wechatOAuthUserInfoURL
|
||||
t.Cleanup(func() {
|
||||
@ -773,27 +772,32 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
|
||||
|
||||
require.Equal(t, http.StatusOK, completeRecorder.Code)
|
||||
responseData := decodeJSONBody(t, completeRecorder)
|
||||
require.NotEmpty(t, responseData["access_token"])
|
||||
require.Equal(t, "pending_session", responseData["auth_result"])
|
||||
require.Equal(t, oauthPendingChoiceStep, responseData["step"])
|
||||
require.Equal(t, true, responseData["adoption_required"])
|
||||
require.Empty(t, responseData["access_token"])
|
||||
|
||||
userEntity, err := client.User.Query().
|
||||
Where(dbuser.EmailEQ("wechat-union-456@wechat-connect.invalid")).
|
||||
consumed, err := client.PendingAuthSession.Query().
|
||||
Where(pendingauthsession.IDEQ(pendingSession.ID)).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "WeChat Display", userEntity.Username)
|
||||
require.Nil(t, consumed.ConsumedAt)
|
||||
|
||||
identity, err := client.AuthIdentity.Query().
|
||||
userCount, err := client.User.Query().Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, userCount)
|
||||
|
||||
identityCount, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ("wechat"),
|
||||
authidentity.ProviderKeyEQ("wechat-main"),
|
||||
authidentity.ProviderSubjectEQ("union-456"),
|
||||
).
|
||||
Only(ctx)
|
||||
Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, userEntity.ID, identity.UserID)
|
||||
require.Equal(t, "WeChat Display", identity.Metadata["display_name"])
|
||||
require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"])
|
||||
require.Zero(t, identityCount)
|
||||
|
||||
channel, err := client.AuthIdentityChannel.Query().
|
||||
channelCount, err := client.AuthIdentityChannel.Query().
|
||||
Where(
|
||||
authidentitychannel.ProviderTypeEQ("wechat"),
|
||||
authidentitychannel.ProviderKeyEQ("wechat-main"),
|
||||
@ -801,25 +805,15 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
|
||||
authidentitychannel.ChannelAppIDEQ("wx-open-app"),
|
||||
authidentitychannel.ChannelSubjectEQ("openid-123"),
|
||||
).
|
||||
Only(ctx)
|
||||
Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, identity.ID, channel.IdentityID)
|
||||
require.Equal(t, "union-456", channel.Metadata["unionid"])
|
||||
require.Zero(t, channelCount)
|
||||
|
||||
decision, err := client.IdentityAdoptionDecision.Query().
|
||||
decisionCount, err := client.IdentityAdoptionDecision.Query().
|
||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)).
|
||||
Only(ctx)
|
||||
Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, decision.IdentityID)
|
||||
require.Equal(t, identity.ID, *decision.IdentityID)
|
||||
require.True(t, decision.AdoptDisplayName)
|
||||
require.True(t, decision.AdoptAvatar)
|
||||
|
||||
consumed, err := client.PendingAuthSession.Query().
|
||||
Where(pendingauthsession.IDEQ(pendingSession.ID)).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, consumed.ConsumedAt)
|
||||
require.Zero(t, decisionCount)
|
||||
}
|
||||
|
||||
func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) {
|
||||
@ -981,6 +975,62 @@ func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testi
|
||||
require.Nil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
|
||||
func TestCompleteWeChatOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
|
||||
handler, client := newWeChatOAuthTestHandler(t, false)
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
session, err := client.PendingAuthSession.Create().
|
||||
SetSessionToken("wechat-complete-choice-session").
|
||||
SetIntent("login").
|
||||
SetProviderType("wechat").
|
||||
SetProviderKey("wechat-main").
|
||||
SetProviderSubject("wechat-choice-subject-1").
|
||||
SetResolvedEmail("wechat-choice-subject-1@wechat-connect.invalid").
|
||||
SetBrowserSessionKey("wechat-choice-browser").
|
||||
SetUpstreamIdentityClaims(map[string]any{
|
||||
"username": "wechat_user",
|
||||
}).
|
||||
SetLocalFlowState(map[string]any{
|
||||
oauthCompletionResponseKey: map[string]any{
|
||||
"step": oauthPendingChoiceStep,
|
||||
"redirect": "/dashboard",
|
||||
"email": "fresh@example.com",
|
||||
"resolved_email": "fresh@example.com",
|
||||
"force_email_on_signup": true,
|
||||
},
|
||||
}).
|
||||
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
completeCtx, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-choice-browser")})
|
||||
completeCtx.Request = req
|
||||
|
||||
handler.CompleteWeChatOAuthRegistration(completeCtx)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
responseData := decodeJSONBody(t, recorder)
|
||||
require.Equal(t, "pending_session", responseData["auth_result"])
|
||||
require.Equal(t, oauthPendingChoiceStep, responseData["step"])
|
||||
require.Equal(t, true, responseData["force_email_on_signup"])
|
||||
require.Empty(t, responseData["access_token"])
|
||||
|
||||
userCount, err := client.User.Query().Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, userCount)
|
||||
|
||||
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
|
||||
func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) {
|
||||
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
||||
originalUserInfoURL := wechatOAuthUserInfoURL
|
||||
|
||||
@ -27,23 +27,50 @@ func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFun
|
||||
}
|
||||
}
|
||||
|
||||
func backendModeAllowsAuthPath(path string) bool {
|
||||
path = strings.ToLower(strings.TrimSpace(path))
|
||||
for _, suffix := range []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"} {
|
||||
if strings.HasSuffix(path, suffix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for _, suffix := range []string{
|
||||
"/auth/oauth/linuxdo/callback",
|
||||
"/auth/oauth/wechat/callback",
|
||||
"/auth/oauth/wechat/payment/callback",
|
||||
"/auth/oauth/oidc/callback",
|
||||
"/auth/oauth/linuxdo/complete-registration",
|
||||
"/auth/oauth/wechat/complete-registration",
|
||||
"/auth/oauth/oidc/complete-registration",
|
||||
"/auth/oauth/linuxdo/create-account",
|
||||
"/auth/oauth/wechat/create-account",
|
||||
"/auth/oauth/oidc/create-account",
|
||||
"/auth/oauth/linuxdo/bind-login",
|
||||
"/auth/oauth/wechat/bind-login",
|
||||
"/auth/oauth/oidc/bind-login",
|
||||
} {
|
||||
if strings.HasSuffix(path, suffix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Contains(path, "/auth/oauth/pending/")
|
||||
}
|
||||
|
||||
// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled.
|
||||
// Allows: login, login/2fa, logout, refresh (admin needs these).
|
||||
// Blocks: register, forgot-password, reset-password, OAuth, etc.
|
||||
// Allows the minimal auth surface admins still need in backend mode, including
|
||||
// OAuth callbacks and pending continuations. Handler-level backend mode checks
|
||||
// still enforce admin-only login and forbid self-service registration.
|
||||
func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
path := c.Request.URL.Path
|
||||
// Allow login, 2FA, logout, refresh, public settings
|
||||
allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"}
|
||||
for _, suffix := range allowedSuffixes {
|
||||
if strings.HasSuffix(path, suffix) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
if backendModeAllowsAuthPath(c.Request.URL.Path) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.")
|
||||
c.Abort()
|
||||
|
||||
@ -198,6 +198,96 @@ func TestBackendModeAuthGuard(t *testing.T) {
|
||||
path: "/api/v1/auth/refresh",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_linuxdo_oauth_start",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/linuxdo/start",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_linuxdo_oauth_callback",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/linuxdo/callback",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_wechat_oauth_start",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/wechat/start",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_wechat_oauth_callback",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/wechat/callback",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_wechat_payment_oauth_start",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/wechat/payment/start",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_wechat_payment_oauth_callback",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/wechat/payment/callback",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_oidc_oauth_start",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/oidc/start",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_oidc_oauth_callback",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/oidc/callback",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_oauth_pending_exchange",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/pending/exchange",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_oauth_pending_send_verify_code",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/pending/send-verify-code",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_oauth_pending_create_account",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/pending/create-account",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_oauth_pending_bind_login",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/pending/bind-login",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_provider_bind_login",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/oidc/bind-login",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_provider_create_account",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/wechat/create-account",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_legacy_complete_registration",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/linuxdo/complete-registration",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_register",
|
||||
enabled: "true",
|
||||
|
||||
Reference in New Issue
Block a user