// This file is part of tdir, the Taler Directory implementation.
// Copyright (C) 2025 Martin Schanzenbach
//
// Taldir is free software: you can redistribute it and/or modify it
// under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License,
// or (at your option) any later version.
//
// Taldir is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <http://www.gnu.org/licenses/>.
//
// SPDX-License-Identifier: AGPL-3.0-or-later

package taldir

import (
	"crypto/rand"
	"encoding/json"
	"fmt"
	"html/template"
	"net/http"
	"net/url"
	"regexp"
	"strings"
)

// AuthorizationsState holds the registration data remembered for one pending
// OIDC authorization. Instances are stored in
// OidcValidator.authorizationsState under the random "state" value generated
// by RegistrationStart.
//
// NOTE: RegistrationStart builds this struct with a positional literal
// (AuthorizationsState{alias, challenge}), so the field order matters.
type AuthorizationsState struct {

	// alias is the alias passed to RegistrationStart. ProcessOidcCallback
	// compares it with the alias claim returned by the userinfo endpoint.
	alias string

	// challenge is the challenge passed to RegistrationStart. It is handed
	// back by ProcessOidcCallback once the alias has been validated.
	challenge string
}

// RelevantUserClaims carries the user claims of interest.
//
// NOTE(review): nothing in the part of the file shown here references this
// type (ValidateAliasSubject decodes into a generic map instead) — confirm
// it is still used elsewhere before relying on it.
type RelevantUserClaims struct {
	// Sub is the subject claim.
	Sub string
}

// OidcValidator validates an alias by sending the user through an OIDC
// authorization code flow and comparing a claim from the userinfo endpoint
// with the alias that was requested. Values are created by
// makeOidcValidator from the "taldir-validator-<name>" configuration section.
type OidcValidator struct {

	// name identifies the validator. It selects the configuration section
	// "taldir-validator-<name>" and is the last path segment of redirectURI.
	name string

	// config is the Taldir configuration; ChallengeFee reads the fee from it.
	config *TaldirConfig

	// clientID is sent as client_id in the authorization redirect and in the
	// token request, and is the basic-auth user name of the token request.
	clientID string

	// clientSecret is the basic-auth password of the token request.
	clientSecret string

	// scope is the value of the scope parameter in the authorization
	// redirect (configuration key "scope", default "profile").
	scope string

	// aliasClaimName is the key looked up in the userinfo response whose
	// value must equal the alias (configuration key "alias_claim",
	// default "sub").
	aliasClaimName string

	// redirectURI is "<base_url>/oidc_validator/<name>"; it is sent as
	// redirect_uri in both the authorization redirect and the token request.
	redirectURI string

	// userinfoEndpoint is the URL ValidateAliasSubject queries with the
	// access token as bearer token.
	userinfoEndpoint string

	// tokenEndpoint is the URL ProcessOidcCallback POSTs the authorization
	// code to.
	tokenEndpoint string

	// authorizationEndpoint is the base of the URL returned by
	// RegistrationStart.
	authorizationEndpoint string

	// landingPageTpl is the registration/lookup page template returned by
	// LandingPageTpl.
	landingPageTpl *template.Template

	// validAliasRegex is the pattern IsAliasValid matches aliases against;
	// the empty string disables the check.
	validAliasRegex string

	// authorizationsState maps the "state" value of a pending authorization
	// to the alias and challenge it was started for. Entries are added by
	// RegistrationStart and removed by ProcessOidcCallback.
	// NOTE(review): the code shown here accesses this map without any
	// synchronization and never expires entries of abandoned flows — confirm
	// whether callers serialize access.
	authorizationsState map[string]*AuthorizationsState
}

// OAuthTokenResponse is the subset of the token endpoint's JSON response
// that ProcessOidcCallback decodes.
type OAuthTokenResponse struct {
	// AccessToken is the "access_token" member; it is passed to
	// ValidateAliasSubject as the bearer token for the userinfo request.
	AccessToken string `json:"access_token"`

	// TokenType is the "token_type" member. It is decoded but not inspected
	// by the code in this file section.
	TokenType string `json:"token_type"`

	// ExpiresIn is the "expires_in" member (presumably the lifetime in
	// seconds as in RFC 6749 — the code shown here never reads it).
	ExpiresIn int `json:"expires_in"`
}

// LandingPageTpl returns the registration/lookup page template the validator
// was constructed with.
func (t OidcValidator) LandingPageTpl() *template.Template { return t.landingPageTpl }

// Type reports the kind of this validator, which is always ValidatorTypeOIDC.
func (t OidcValidator) Type() ValidatorType { return ValidatorTypeOIDC }

// Name returns the name the validator was configured under.
func (t OidcValidator) Name() string { return t.name }

// ChallengeFee returns the fee for a challenge issued by this validator. It
// is read from the "challenge_fee" key of the validator's
// "taldir-validator-<name>" configuration section, with "KUDOS:0" passed to
// MustString as the default.
func (t OidcValidator) ChallengeFee() string {
	section := t.config.Ini.Section("taldir-validator-" + t.name)
	feeKey := section.Key("challenge_fee")
	return feeKey.MustString("KUDOS:0")
}

// IsAliasValid checks alias against the validator's configured alias regex.
//
// An empty regex accepts every alias. Otherwise the result is nil when the
// alias matches and an "alias ... invalid" error when it does not. A
// configured regex that does not compile is reported as a regex error rather
// than being mistaken for an invalid alias.
func (t OidcValidator) IsAliasValid(alias string) (err error) {
	if t.validAliasRegex == "" {
		return nil
	}
	matched, err := regexp.MatchString(t.validAliasRegex, alias)
	if err != nil {
		// MatchString only fails when the pattern itself is malformed.
		return fmt.Errorf("invalid alias regex %q: %w", t.validAliasRegex, err)
	}
	if !matched {
		return fmt.Errorf("alias `%s' invalid", alias) // TODO i18n
	}
	return nil
}



// ValidateAliasSubject fetches the userinfo document with tokenString as
// bearer token and checks that the claim named by t.aliasClaimName equals
// expectedAlias.
//
// It returns nil only when the userinfo endpoint answers 200 OK with a JSON
// object whose alias claim is a string equal to expectedAlias. Any request
// failure, other status code, unparsable body, missing or non-string claim,
// or mismatch yields an error.
//
// NOTE(review): the HTTP client has no timeout, so a stalled userinfo
// endpoint blocks the caller indefinitely.
func (t OidcValidator) ValidateAliasSubject(tokenString string, expectedAlias string) error {
	req, err := http.NewRequest("GET", t.userinfoEndpoint, nil)
	if err != nil {
		return fmt.Errorf("failed to create userinfo request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+tokenString)
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("failed to execute userinfo request: %w", err)
	}
	// Release the connection on every path below, including early returns.
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected response code %d", resp.StatusCode)
	}
	var claims map[string]any
	if err := json.NewDecoder(resp.Body).Decode(&claims); err != nil {
		return fmt.Errorf("unable to parse userinfo response: %w", err)
	}
	// Only a string claim can ever equal the alias; reject anything else
	// explicitly instead of formatting a non-string value with %s.
	aliasClaim, ok := claims[t.aliasClaimName].(string)
	if !ok {
		return fmt.Errorf("claim %q missing from userinfo response or not a string", t.aliasClaimName)
	}
	if aliasClaim != expectedAlias {
		return fmt.Errorf("subject in ID token (%s) does not match state (%s)", aliasClaim, expectedAlias)
	}
	return nil
}

// ProcessOidcCallback handles the redirect back from the OIDC provider.
//
// It looks up the pending authorization by the "state" query parameter,
// exchanges the "code" query parameter for an access token at the token
// endpoint, and validates the alias through ValidateAliasSubject. On success
// it returns the alias and challenge that were stored by RegistrationStart.
//
// The state entry is removed as soon as it has been found, so each state
// value can be used for at most one callback, whether or not the remaining
// steps succeed. An error is returned when the state or code parameter is
// missing, the state is unknown, the token request fails or does not answer
// 200 OK, the token response cannot be parsed, or the alias does not
// validate.
//
// NOTE(review): the HTTP client has no timeout, so a stalled token endpoint
// blocks the caller indefinitely.
func (t OidcValidator) ProcessOidcCallback(r *http.Request) (string, string, error) {
	// Process authorization code
	query := r.URL.Query()
	stateParam := query.Get("state")
	if stateParam == "" {
		return "", "", fmt.Errorf("no state query parameter provided")
	}
	state, ok := t.authorizationsState[stateParam]
	if !ok {
		return "", "", fmt.Errorf("state invalid")
	}
	alias := state.alias
	challenge := state.challenge
	delete(t.authorizationsState, stateParam)

	code := query.Get("code")
	if code == "" {
		// Without a code there is nothing to exchange; skip the round trip.
		return "", "", fmt.Errorf("no code query parameter provided")
	}
	data := url.Values{}
	data.Set("client_id", t.clientID)
	data.Set("grant_type", "authorization_code")
	data.Set("redirect_uri", t.redirectURI)
	data.Set("code", code)

	req, err := http.NewRequest("POST", t.tokenEndpoint, strings.NewReader(data.Encode()))
	if err != nil {
		return "", "", fmt.Errorf("failed to create token request: %w", err)
	}
	req.SetBasicAuth(t.clientID, t.clientSecret)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return "", "", fmt.Errorf("failed to execute token request: %w", err)
	}
	// Release the connection on every path below, including early returns.
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return "", "", fmt.Errorf("unexpected response code %d", resp.StatusCode)
	}
	var tokenResponse OAuthTokenResponse
	if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
		return "", "", fmt.Errorf("unable to parse token response: %w", err)
	}
	if err := t.ValidateAliasSubject(tokenResponse.AccessToken, alias); err != nil {
		return "", "", fmt.Errorf("unable to validate token: %w", err)
	}
	return alias, challenge, nil
}

// RegistrationStart begins an OIDC authorization for alias and returns the
// URL of the provider's authorization endpoint the user must be sent to.
//
// A fresh random state value is generated and the alias/challenge pair is
// stored under it in t.authorizationsState, where ProcessOidcCallback finds
// it again. The topic, link and message parameters are not used by this
// validator. The returned error is always nil.
//
// Every query parameter value is URL-escaped, so redirect URIs containing
// their own query string and scope lists separated by spaces survive intact.
func (t OidcValidator) RegistrationStart(topic string, link string, message string, alias string, challenge string) (string, error) {
	state := rand.Text()
	t.authorizationsState[state] = &AuthorizationsState{alias: alias, challenge: challenge}
	redirectURI := fmt.Sprintf("%s?response_type=code&redirect_uri=%s&client_id=%s&scope=%s&state=%s",
		t.authorizationEndpoint,
		url.QueryEscape(t.redirectURI),
		url.QueryEscape(t.clientID),
		url.QueryEscape(t.scope),
		url.QueryEscape(state))
	return redirectURI, nil
}

// makeOidcValidator builds the OIDC validator called name from cfg.
//
// The redirect URI is derived from the "base_url" key of the "taldir"
// section as "<base_url>/oidc_validator/<name>". All other settings come
// from the "taldir-validator-<name>" section; "scope" defaults to "profile",
// "alias_claim" defaults to "sub", and every other key defaults to the empty
// string. The validator starts with an empty set of pending authorizations.
func makeOidcValidator(cfg *TaldirConfig, name string, landingPageTpl *template.Template) OidcValidator {
	baseURL := cfg.Ini.Section("taldir").Key("base_url").MustString("")
	validatorSec := cfg.Ini.Section("taldir-validator-" + name)
	// setting reads one key of the validator's own section.
	setting := func(key string, fallback string) string {
		return validatorSec.Key(key).MustString(fallback)
	}

	var v OidcValidator
	v.name = name
	v.config = cfg
	v.landingPageTpl = landingPageTpl
	// The keys are read in a fixed order, in case the configuration backend
	// records lookups.
	v.clientID = setting("client_id", "")
	v.clientSecret = setting("client_secret", "")
	v.scope = setting("scope", "profile")
	v.tokenEndpoint = setting("token_endpoint", "")
	v.userinfoEndpoint = setting("userinfo_endpoint", "")
	v.authorizationEndpoint = setting("authorization_endpoint", "")
	v.validAliasRegex = setting("valid_alias_regex", "")
	v.aliasClaimName = setting("alias_claim", "sub")
	// FIXME escape URI?
	v.redirectURI = fmt.Sprintf("%s/oidc_validator/%s", baseURL, name)
	v.authorizationsState = map[string]*AuthorizationsState{}
	return v
}
