com.cedac.security.oauth2.provider.token.store.MongoTokenStore.java Source code

Java tutorial

Introduction

Here is the source code for com.cedac.security.oauth2.provider.token.store.MongoTokenStore.java

Source

/*
 * Copyright 2012-2015 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.cedac.security.oauth2.provider.token.store;

import com.mongodb.BasicDBObject;
import com.mongodb.DB;
import com.mongodb.DBCollection;
import com.mongodb.DBCursor;
import com.mongodb.DBObject;
import com.mongodb.Mongo;
import com.mongodb.WriteConcern;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.Marker;
import org.slf4j.MarkerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.OAuth2RefreshToken;
import org.springframework.security.oauth2.common.util.SerializationUtils;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.token.AuthenticationKeyGenerator;
import org.springframework.security.oauth2.provider.token.DefaultAuthenticationKeyGenerator;
import org.springframework.security.oauth2.provider.token.TokenStore;
import org.springframework.util.Assert;

import java.io.UnsupportedEncodingException;
import java.math.BigInteger;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

/**
 * {@link TokenStore} implementation that persists OAuth2 access and refresh tokens in MongoDB.
 * <p>
 * Access tokens and refresh tokens live in two separate collections (by default {@code access_tokens} and
 * {@code refresh_tokens}). Every document is looked up through a token id, the hex MD5 digest of the token value
 * computed by {@link #extractTokenKey(String)}; the token itself and its {@link OAuth2Authentication} are stored as
 * byte arrays produced by the {@code serialize*} methods.
 * <p>
 * Collection names, field names, the {@link AuthenticationKeyGenerator} and the {@link WriteConcern} can be customized
 * through setters. Collection and field names must be set before {@link #afterPropertiesSet()} runs, because that
 * method creates the missing collections and their indexes using the names configured at that time.
 *
 * @author mauro.franceschini@cedac.com
 * @since 1.0.0
 */
public class MongoTokenStore implements TokenStore, InitializingBean {
    private static final Marker TOKEN = MarkerFactory.getDetachedMarker("token");
    private static final Logger LOG = LoggerFactory.getLogger(MongoTokenStore.class);

    private static final WriteConcern DEFAULT_WRITE_CONCERN = WriteConcern.NORMAL;

    private static final String DEFAULT_ACCESS_TOKEN_COLLECTION_NAME = "access_tokens";
    private static final String DEFAULT_REFRESH_TOKEN_COLLECTION_NAME = "refresh_tokens";

    private static final String DEFAULT_TOKEN_ID_FIELD_NAME = "tokenId";
    private static final String DEFAULT_TOKEN_FIELD_NAME = "token";
    private static final String DEFAULT_AUTHENTICATION_ID_FIELD_NAME = "authenticationId";
    private static final String DEFAULT_USERNAME_FIELD_NAME = "username";
    private static final String DEFAULT_CLIENT_ID_FIELD_NAME = "clientId";
    private static final String DEFAULT_AUTHENTICATION_FIELD_NAME = "authentication";
    private static final String DEFAULT_REFRESH_TOKEN_FIELD_NAME = "refreshToken";

    /** MongoDB primary key field; used to remove one specific document returned by a query. */
    private static final String ID_FIELD_NAME = "_id";

    private AuthenticationKeyGenerator authenticationKeyGenerator = new DefaultAuthenticationKeyGenerator();

    private final DB db;

    private String accessTokenCollectionName = DEFAULT_ACCESS_TOKEN_COLLECTION_NAME;
    private String refreshTokenCollectionName = DEFAULT_REFRESH_TOKEN_COLLECTION_NAME;

    private String tokenIdFieldName = DEFAULT_TOKEN_ID_FIELD_NAME;
    private String tokenFieldName = DEFAULT_TOKEN_FIELD_NAME;
    private String authenticationIdFieldName = DEFAULT_AUTHENTICATION_ID_FIELD_NAME;
    private String usernameFieldName = DEFAULT_USERNAME_FIELD_NAME;
    private String clientIdFieldName = DEFAULT_CLIENT_ID_FIELD_NAME;
    private String authenticationFieldName = DEFAULT_AUTHENTICATION_FIELD_NAME;
    private String refreshTokenFieldName = DEFAULT_REFRESH_TOKEN_FIELD_NAME;

    private WriteConcern writeConcern = DEFAULT_WRITE_CONCERN;

    /**
     * Creates a store backed by the database {@code databaseName} of the given client.
     *
     * @param mongo the MongoDB client
     * @param databaseName the name of the database holding the token collections
     */
    public MongoTokenStore(Mongo mongo, String databaseName) {
        this(mongo.getDB(databaseName));
    }

    /**
     * Creates a store backed by the given database.
     *
     * @param db the database holding the token collections; must not be {@code null}
     * @throws IllegalArgumentException if {@code db} is {@code null}
     */
    public MongoTokenStore(DB db) {
        Assert.notNull(db, "DB is required");
        this.db = db;
    }

    /**
     * Creates the access token and refresh token collections, together with their indexes, when they do not exist
     * yet. Collections that already exist are left untouched (no index is added to them).
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        if (!this.db.collectionExists(accessTokenCollectionName)) {
            LOG.trace(TOKEN, "Creating {} collection", accessTokenCollectionName);

            DBCollection collection = this.db.createCollection(accessTokenCollectionName, new BasicDBObject());
            collection.createIndex(new BasicDBObject(tokenIdFieldName, 1),
                    new BasicDBObject("name", indexName(accessTokenCollectionName, tokenIdFieldName))
                            .append("background", true));
            // Name the index after the field it actually indexes (authenticationId, not the serialized authentication).
            collection.createIndex(new BasicDBObject(authenticationIdFieldName, 1),
                    new BasicDBObject("name", indexName(accessTokenCollectionName, authenticationIdFieldName))
                            .append("background", true));
            // removeAccessTokenUsingRefreshToken() queries this collection by refresh token id; without an index
            // every refresh would scan the whole collection.
            collection.createIndex(new BasicDBObject(refreshTokenFieldName, 1),
                    new BasicDBObject("name", indexName(accessTokenCollectionName, refreshTokenFieldName))
                            .append("background", true));

            LOG.debug(TOKEN, "Collection {} successfully created and indexed", accessTokenCollectionName);
        }
        if (!this.db.collectionExists(refreshTokenCollectionName)) {
            LOG.trace(TOKEN, "Creating {} collection", refreshTokenCollectionName);

            DBCollection collection = this.db.createCollection(refreshTokenCollectionName, new BasicDBObject());
            collection.createIndex(new BasicDBObject(tokenIdFieldName, 1),
                    new BasicDBObject("name", refreshTokenCollectionName + "_ix"));

            LOG.debug(TOKEN, "Collection {} successfully created and indexed", refreshTokenCollectionName);
        }
    }

    /** Builds an index name of the form {@code <collection>_<field>_ix}. */
    private static String indexName(String collectionName, String fieldName) {
        return collectionName + "_" + fieldName + "_ix";
    }

    private DBCollection getAccessTokenCollection() {
        return db.getCollection(accessTokenCollectionName);
    }

    private DBCollection getRefreshTokenCollection() {
        return db.getCollection(refreshTokenCollectionName);
    }

    /**
     * Returns the access token previously stored for an authentication with the same key, or {@code null}.
     * <p>
     * If the stored token's own authentication no longer produces the same key (or cannot be read back), the token
     * is re-stored under {@code authentication} so the store stays consistent. A deserialization failure of the
     * token itself is logged and {@code null} is returned.
     *
     * @param authentication the authentication whose token is looked up
     * @return the stored access token, or {@code null}
     */
    public OAuth2AccessToken getAccessToken(OAuth2Authentication authentication) {
        OAuth2AccessToken accessToken = null;

        String key = authenticationKeyGenerator.extractKey(authentication);
        try {
            DBObject query = new BasicDBObject(authenticationIdFieldName, key);
            DBObject projection = new BasicDBObject(tokenFieldName, 1);
            DBObject token = getAccessTokenCollection().findOne(query, projection);
            if (token != null) {
                accessToken = deserializeAccessToken((byte[]) token.get(tokenFieldName));
            } else {
                LOG.debug("Failed to find access token for authentication {}", authentication);
            }
        } catch (IllegalArgumentException e) {
            LOG.error("Could not extract access token for authentication " + authentication, e);
        }

        if (accessToken != null) {
            // readAuthentication() returns null when the stored authentication is missing or could not be
            // deserialized (in which case it already removed the document); treat that like a key mismatch instead
            // of handing null to the key generator.
            OAuth2Authentication storedAuthentication = readAuthentication(accessToken.getValue());
            if (storedAuthentication == null
                    || !key.equals(authenticationKeyGenerator.extractKey(storedAuthentication))) {
                removeAccessToken(accessToken.getValue());
                // Keep the store consistent (maybe the same user is represented by this authentication but the
                // details have changed)
                storeAccessToken(accessToken, authentication);
            }
        }
        return accessToken;
    }

    /**
     * Stores {@code accessToken} for {@code authentication}, replacing any document already stored under the same
     * token id.
     *
     * @param accessToken the token to store
     * @param authentication the authentication the token was issued for
     */
    public void storeAccessToken(OAuth2AccessToken accessToken, OAuth2Authentication authentication) {
        String refreshToken = null;
        if (accessToken.getRefreshToken() != null) {
            refreshToken = accessToken.getRefreshToken().getValue();
        }

        // Drop any previous copy of this token. Removing is a no-op when nothing matches, so there is no need to read
        // and deserialize the existing document first (which also logged a "Failed to find" message for every new
        // token).
        removeAccessToken(accessToken.getValue());

        DBObject token = new BasicDBObject();
        token.put(tokenIdFieldName, extractTokenKey(accessToken.getValue()));
        token.put(tokenFieldName, serializeAccessToken(accessToken));
        token.put(authenticationIdFieldName, authenticationKeyGenerator.extractKey(authentication));
        // Client-only authentications have no user; the field is stored explicitly as null.
        token.put(usernameFieldName, authentication.isClientOnly() ? null : authentication.getName());
        token.put(clientIdFieldName, authentication.getOAuth2Request().getClientId());
        token.put(authenticationFieldName, serializeAuthentication(authentication));
        token.put(refreshTokenFieldName, extractTokenKey(refreshToken));

        getAccessTokenCollection().insert(token, writeConcern);
    }

    /**
     * Reads the access token stored for {@code tokenValue}.
     *
     * @param tokenValue the raw token value
     * @return the token, or {@code null} when none is stored or it cannot be deserialized (in which case the
     *         document is removed)
     */
    public OAuth2AccessToken readAccessToken(String tokenValue) {
        OAuth2AccessToken accessToken = null;

        try {
            DBObject query = new BasicDBObject(tokenIdFieldName, extractTokenKey(tokenValue));
            DBObject projection = new BasicDBObject(tokenFieldName, 1);
            DBObject token = getAccessTokenCollection().findOne(query, projection);
            if (token != null) {
                accessToken = deserializeAccessToken((byte[]) token.get(tokenFieldName));
            } else {
                LOG.info("Failed to find access token for token {}", tokenValue);
            }
        } catch (IllegalArgumentException e) {
            LOG.warn("Failed to deserialize access token for " + tokenValue, e);

            removeAccessToken(tokenValue);
        }

        return accessToken;
    }

    /** Removes the stored copy of {@code token}, if any. */
    public void removeAccessToken(OAuth2AccessToken token) {
        removeAccessToken(token.getValue());
    }

    /** Removes the access token document stored for the raw token value {@code tokenValue}, if any. */
    public void removeAccessToken(String tokenValue) {
        DBObject query = new BasicDBObject(tokenIdFieldName, extractTokenKey(tokenValue));

        getAccessTokenCollection().remove(query, writeConcern);
    }

    /** Reads the authentication stored with the given access token; see {@link #readAuthentication(String)}. */
    public OAuth2Authentication readAuthentication(OAuth2AccessToken token) {
        return readAuthentication(token.getValue());
    }

    /**
     * Reads the authentication stored with the access token {@code token}.
     *
     * @param token the raw access token value
     * @return the authentication, or {@code null} when no document is stored or the authentication cannot be
     *         deserialized (in which case the access token document is removed)
     */
    public OAuth2Authentication readAuthentication(String token) {
        OAuth2Authentication authentication = null;

        try {
            DBObject query = new BasicDBObject(tokenIdFieldName, extractTokenKey(token));
            DBObject projection = new BasicDBObject(authenticationFieldName, 1);
            DBObject accessToken = getAccessTokenCollection().findOne(query, projection);
            if (accessToken != null) {
                authentication = deserializeAuthentication((byte[]) accessToken.get(authenticationFieldName));
            } else {
                LOG.info("Failed to find access token for token {}", token);
            }
        } catch (IllegalArgumentException e) {
            LOG.warn("Failed to deserialize authentication for " + token, e);

            removeAccessToken(token);
        }

        return authentication;
    }

    /**
     * Stores {@code refreshToken} together with the authentication it was issued for.
     *
     * @param refreshToken the refresh token to store
     * @param authentication the authentication the token was issued for
     */
    public void storeRefreshToken(OAuth2RefreshToken refreshToken, OAuth2Authentication authentication) {
        DBObject token = new BasicDBObject();
        token.put(tokenIdFieldName, extractTokenKey(refreshToken.getValue()));
        token.put(tokenFieldName, serializeRefreshToken(refreshToken));
        token.put(authenticationFieldName, serializeAuthentication(authentication));

        getRefreshTokenCollection().insert(token, writeConcern);
    }

    /**
     * Reads the refresh token stored for {@code token}.
     *
     * @param token the raw refresh token value
     * @return the refresh token, or {@code null} when none is stored or it cannot be deserialized (in which case the
     *         document is removed)
     */
    public OAuth2RefreshToken readRefreshToken(String token) {
        OAuth2RefreshToken refreshToken = null;

        try {
            DBObject query = new BasicDBObject(tokenIdFieldName, extractTokenKey(token));
            DBObject projection = new BasicDBObject(tokenFieldName, 1);
            DBObject savedToken = getRefreshTokenCollection().findOne(query, projection);
            if (savedToken != null) {
                refreshToken = deserializeRefreshToken((byte[]) savedToken.get(tokenFieldName));
            } else {
                LOG.info("Failed to find refresh token for token {}", token);
            }
        } catch (IllegalArgumentException e) {
            LOG.warn("Failed to deserialize refresh token for token " + token, e);

            removeRefreshToken(token);
        }

        return refreshToken;
    }

    /** Removes the stored copy of {@code token}, if any. */
    public void removeRefreshToken(OAuth2RefreshToken token) {
        removeRefreshToken(token.getValue());
    }

    /** Removes the refresh token document stored for the raw token value {@code token}, if any. */
    public void removeRefreshToken(String token) {
        DBObject query = new BasicDBObject(tokenIdFieldName, extractTokenKey(token));

        getRefreshTokenCollection().remove(query, writeConcern);
    }

    /** Reads the authentication stored with the given refresh token; see the {@code String} overload. */
    public OAuth2Authentication readAuthenticationForRefreshToken(OAuth2RefreshToken token) {
        return readAuthenticationForRefreshToken(token.getValue());
    }

    /**
     * Reads the authentication stored with the refresh token {@code value}.
     *
     * @param value the raw refresh token value
     * @return the authentication, or {@code null} when no document is stored or the authentication cannot be
     *         deserialized (in which case the refresh token document is removed)
     */
    public OAuth2Authentication readAuthenticationForRefreshToken(String value) {
        OAuth2Authentication authentication = null;

        try {
            DBObject query = new BasicDBObject(tokenIdFieldName, extractTokenKey(value));
            DBObject projection = new BasicDBObject(authenticationFieldName, 1);
            DBObject savedToken = getRefreshTokenCollection().findOne(query, projection);
            if (savedToken != null) {
                authentication = deserializeAuthentication((byte[]) savedToken.get(authenticationFieldName));
            } else {
                LOG.info("Failed to find refresh token for token {}", value);
            }
        } catch (IllegalArgumentException e) {
            LOG.warn("Failed to deserialize authentication for refresh token " + value, e);

            removeRefreshToken(value);
        }

        return authentication;
    }

    /** Removes every access token that was issued together with {@code refreshToken}. */
    public void removeAccessTokenUsingRefreshToken(OAuth2RefreshToken refreshToken) {
        removeAccessTokenUsingRefreshToken(refreshToken.getValue());
    }

    /** Removes every access token whose stored refresh token id matches the raw value {@code refreshToken}. */
    public void removeAccessTokenUsingRefreshToken(String refreshToken) {
        DBObject query = new BasicDBObject(refreshTokenFieldName, extractTokenKey(refreshToken));

        getAccessTokenCollection().remove(query, writeConcern);
    }

    /**
     * Returns all access tokens issued to {@code clientId}. Tokens that cannot be deserialized are removed and
     * skipped.
     *
     * @param clientId the client id
     * @return a mutable, possibly empty list of tokens
     */
    public Collection<OAuth2AccessToken> findTokensByClientId(String clientId) {
        List<OAuth2AccessToken> accessTokens = findAccessTokens(new BasicDBObject(clientIdFieldName, clientId));
        if (accessTokens.isEmpty()) {
            LOG.info("Failed to find access token for clientId {}", clientId);
        }
        return accessTokens;
    }

    /**
     * Returns all access tokens issued to user {@code userName}. Tokens that cannot be deserialized are removed and
     * skipped.
     *
     * @param userName the user name
     * @return a mutable, possibly empty list of tokens
     */
    public Collection<OAuth2AccessToken> findTokensByUserName(String userName) {
        List<OAuth2AccessToken> accessTokens = findAccessTokens(new BasicDBObject(usernameFieldName, userName));
        if (accessTokens.isEmpty()) {
            LOG.info("Failed to find access token for username {}.", userName);
        }
        return accessTokens;
    }

    /**
     * Returns all access tokens issued to user {@code userName} through client {@code clientId}. Tokens that cannot
     * be deserialized are removed and skipped.
     *
     * @param clientId the client id
     * @param userName the user name
     * @return a mutable, possibly empty list of tokens
     */
    public Collection<OAuth2AccessToken> findTokensByClientIdAndUserName(String clientId, String userName) {
        DBObject query = new BasicDBObject(clientIdFieldName, clientId).append(usernameFieldName, userName);
        List<OAuth2AccessToken> accessTokens = findAccessTokens(query);
        if (accessTokens.isEmpty()) {
            LOG.info("Failed to find access token for clientId {} and username {}.", clientId, userName);
        }
        return accessTokens;
    }

    /**
     * Runs {@code query} against the access token collection and deserializes every matching token, skipping (and
     * removing) documents that cannot be deserialized. The cursor is always closed.
     */
    private List<OAuth2AccessToken> findAccessTokens(DBObject query) {
        List<OAuth2AccessToken> accessTokens = new ArrayList<OAuth2AccessToken>();
        DBObject projection = new BasicDBObject(tokenFieldName, 1);
        DBCursor cursor = getAccessTokenCollection().find(query, projection);
        try {
            // Iterate directly: a preliminary cursor.count() would cost an extra count command on the server.
            while (cursor.hasNext()) {
                OAuth2AccessToken token = mapAccessToken(cursor.next());
                if (token != null) {
                    accessTokens.add(token);
                }
            }
        } finally {
            cursor.close();
        }
        return accessTokens;
    }

    /**
     * Computes the token id stored in MongoDB for a raw token value: the lowercase, zero-padded 32-character hex MD5
     * digest of its UTF-8 bytes.
     *
     * @param value the raw token value, may be {@code null}
     * @return the token id, or {@code null} when {@code value} is {@code null}
     * @throws IllegalStateException if MD5 or UTF-8 is unavailable in the JVM
     */
    protected String extractTokenKey(String value) {
        if (value == null) {
            return null;
        }
        MessageDigest digest;
        try {
            digest = MessageDigest.getInstance("MD5");
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("MD5 algorithm not available.  Fatal (should be in the JDK).", e);
        }

        try {
            byte[] bytes = digest.digest(value.getBytes("UTF-8"));
            return String.format("%032x", new BigInteger(1, bytes));
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException("UTF-8 encoding not available.  Fatal (should be in the JDK).", e);
        }
    }

    /**
     * Deserializes the token held by a document of the access token collection. A document that cannot be
     * deserialized is removed (by primary key, with the configured write concern) and {@code null} is returned.
     */
    private OAuth2AccessToken mapAccessToken(DBObject token) {
        try {
            return deserializeAccessToken((byte[]) token.get(tokenFieldName));
        } catch (IllegalArgumentException e) {
            Object id = token.get(ID_FIELD_NAME);
            LOG.warn("Failed to deserialize access token document " + id, e);

            getAccessTokenCollection().remove(new BasicDBObject(ID_FIELD_NAME, id), writeConcern);
            return null;
        }
    }

    protected byte[] serializeAccessToken(OAuth2AccessToken token) {
        return SerializationUtils.serialize(token);
    }

    protected byte[] serializeRefreshToken(OAuth2RefreshToken token) {
        return SerializationUtils.serialize(token);
    }

    protected byte[] serializeAuthentication(OAuth2Authentication authentication) {
        return SerializationUtils.serialize(authentication);
    }

    protected OAuth2AccessToken deserializeAccessToken(byte[] token) {
        return SerializationUtils.deserialize(token);
    }

    protected OAuth2RefreshToken deserializeRefreshToken(byte[] token) {
        return SerializationUtils.deserialize(token);
    }

    protected OAuth2Authentication deserializeAuthentication(byte[] authentication) {
        return SerializationUtils.deserialize(authentication);
    }

    /*
     * Collection and field name customization.
     */

    public void setAuthenticationKeyGenerator(AuthenticationKeyGenerator authenticationKeyGenerator) {
        this.authenticationKeyGenerator = authenticationKeyGenerator;
    }

    public void setAccessTokenCollectionName(String accessTokenCollectionName) {
        this.accessTokenCollectionName = accessTokenCollectionName;
    }

    public void setRefreshTokenCollectionName(String refreshTokenCollectionName) {
        this.refreshTokenCollectionName = refreshTokenCollectionName;
    }

    public void setTokenIdFieldName(String tokenIdFieldName) {
        this.tokenIdFieldName = tokenIdFieldName;
    }

    public void setTokenFieldName(String tokenFieldName) {
        this.tokenFieldName = tokenFieldName;
    }

    public void setAuthenticationIdFieldName(String authenticationIdFieldName) {
        this.authenticationIdFieldName = authenticationIdFieldName;
    }

    public void setUsernameFieldName(String usernameFieldName) {
        this.usernameFieldName = usernameFieldName;
    }

    public void setClientIdFieldName(String clientIdFieldName) {
        this.clientIdFieldName = clientIdFieldName;
    }

    public void setAuthenticationFieldName(String authenticationFieldName) {
        this.authenticationFieldName = authenticationFieldName;
    }

    public void setRefreshTokenFieldName(String refreshTokenFieldName) {
        this.refreshTokenFieldName = refreshTokenFieldName;
    }

    public void setWriteConcern(WriteConcern writeConcern) {
        this.writeConcern = writeConcern;
    }
}