JWTAuthentication.java [src/csip] Revision:   Date:
/*
 * $Id$
 *
 * This file is part of the Cloud Services Integration Platform (CSIP),
 * a Model-as-a-Service framework, API and application suite.
 *
 * 2012-2022, Olaf David and others, OMSLab, Colorado State University.
 *
 * OMSLab licenses this file to you under the MIT license.
 * See the LICENSE file in the project root for more information.
 */
package csip;

import com.auth0.jwk.Jwk;
import com.auth0.jwk.JwkException;
import com.auth0.jwk.JwkProvider;
import com.auth0.jwk.UrlJwkProvider;
import com.auth0.jwt.JWT;
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.exceptions.JWTVerificationException;
import com.auth0.jwt.interfaces.DecodedJWT;
import java.security.PublicKey;
import java.security.interfaces.RSAPublicKey;
import java.util.Calendar;
import java.util.Date;
import java.util.Locale;

/**
 * JSON Web token authentication.
 *
 * <pre>
 * enable JWT authentication in config (as json):
 *   ...
 *  "csip.token.authentication" : "jwt",
 *  "csip.jwk.provider.url": "http://jwk.server.com:4444",
 *  "csip.jwt.alg": "RSA256"
 *   ...
 * </pre>
 *
 * @author od
 */
class JWTAuthentication implements TokenAuthentication {

  /** Resolves the public key (by key id) used to verify a token's signature. */
  JwkProvider provider;

  /**
   * Configured signature algorithm ({@code csip.jwt.alg}, default "RSA256"),
   * lower-cased with {@link Locale#ROOT} so matching does not depend on the
   * platform's default locale.
   */
  String alg = Config.getString("csip.jwt.alg", "RSA256").toLowerCase(Locale.ROOT);


  /**
   * Creates a JWT authenticator that fetches verification keys from a JWK
   * endpoint.
   *
   * @param jwkUrl the JWK provider URL (config key {@code csip.jwk.provider.url})
   * @throws RuntimeException if {@code jwkUrl} is {@code null}
   */
  JWTAuthentication(String jwkUrl) {
    if (jwkUrl == null) {
      throw new RuntimeException("Missing configuration:'csip.jwk.provider.url'");
    }
    provider = new UrlJwkProvider(jwkUrl);
  }


  /**
   * Validates a token as a JWT.
   *
   * <p>The token is decoded and its signing key is looked up by the token's key
   * id. The signature is then verified with the configured RSA algorithm, and
   * the expiration time is checked against the current time.
   *
   * @param token the JWT
   * @throws SecurityException if the token is missing or cannot be decoded; if
   * the key cannot be obtained or is not an RSA public key; if the configured
   * algorithm is unsupported or signature verification fails; or if the JWT
   * has no expiration or is expired.
   */
  @Override
  public void validate(String token) throws SecurityException {
    if (token == null || token.isEmpty()) {
      throw new SecurityException("JWT missing.");
    }
    try {
      DecodedJWT jwt = JWT.decode(token);
      Jwk jwk = provider.get(jwt.getKeyId());
      // verify the signature
      algorithmFor(jwk).verify(jwt);
      // check for expiration
      checkNotExpired(jwt);
    } catch (JWTVerificationException e) {
      // also covers decode failures (JWTDecodeException is a subtype)
      throw new SecurityException("Signature verification error.", e);
    } catch (JwkException e) {
      throw new SecurityException("JWK exception.", e);
    }
  }


  /**
   * Builds the verification algorithm for the configured {@link #alg}.
   *
   * <p>The algorithm name is checked before the key is read. An unsupported
   * name is therefore reported even when the key itself is unusable.
   *
   * @param jwk the key resolved for the token
   * @return the RSA verification algorithm (verify-only: no private key)
   * @throws JwkException if the public key cannot be built from the JWK
   * @throws SecurityException if the algorithm is unsupported or the key is not RSA
   */
  private Algorithm algorithmFor(Jwk jwk) throws JwkException, SecurityException {
    switch (alg) {
      case "rsa256":
      case "rs256":
        return Algorithm.RSA256(rsaPublicKey(jwk), null);
      case "rsa384":
      case "rs384":
        return Algorithm.RSA384(rsaPublicKey(jwk), null);
      case "rsa512":
      case "rs512":
        return Algorithm.RSA512(rsaPublicKey(jwk), null);
      default:
        throw new SecurityException("Invalid Algorithm: " + alg);
    }
  }


  /**
   * Extracts the RSA public key from a JWK.
   *
   * @param jwk the key resolved for the token
   * @return the key as an {@link RSAPublicKey}
   * @throws JwkException if the public key cannot be built from the JWK
   * @throws SecurityException if the JWK does not hold an RSA public key
   * (previously this surfaced as a ClassCastException)
   */
  private static RSAPublicKey rsaPublicKey(Jwk jwk) throws JwkException, SecurityException {
    PublicKey key = jwk.getPublicKey();
    if (!(key instanceof RSAPublicKey)) {
      throw new SecurityException("JWK is not an RSA public key.");
    }
    return (RSAPublicKey) key;
  }


  /**
   * Rejects tokens that are expired or carry no expiration at all.
   *
   * <p>A missing {@code exp} claim previously caused a NullPointerException.
   * Such tokens are now rejected explicitly with a SecurityException.
   *
   * @param jwt the decoded token
   * @throws SecurityException if the token has no expiration or is expired
   */
  private static void checkNotExpired(DecodedJWT jwt) throws SecurityException {
    Date expiresAt = jwt.getExpiresAt();
    if (expiresAt == null) {
      throw new SecurityException("JWT has no expiration.");
    }
    if (expiresAt.before(new Date())) {
      throw new SecurityException("JWT expired.");
    }
  }
}