com.github.lothar.security.acl.jpa.query.AclJpaQuery.java Source code

Java tutorial

Introduction

Here is the source code for com.github.lothar.security.acl.jpa.query.AclJpaQuery.java

Source

/*******************************************************************************
 * Copyright 2002-2016 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/
package com.github.lothar.security.acl.jpa.query;

import java.io.PrintWriter;
import java.io.StringWriter;
import java.lang.reflect.Field;
import java.lang.reflect.Method;

import javax.persistence.EntityManager;
import javax.persistence.criteria.CriteriaQuery;
import javax.persistence.criteria.Predicate;
import javax.persistence.criteria.Root;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.aop.framework.Advised;
import org.springframework.aop.framework.ProxyFactoryBean;
import org.springframework.data.jpa.domain.Specification;
import org.springframework.data.jpa.repository.query.PartTreeJpaQuery;
import org.springframework.data.repository.query.QueryMethod;
import org.springframework.data.repository.query.RepositoryQuery;
import org.springframework.util.ReflectionUtils;

import com.github.lothar.security.acl.jpa.JpaSpecProvider;

public class AclJpaQuery implements RepositoryQuery {

    private Logger logger = LoggerFactory.getLogger(getClass());
    private RepositoryQuery query;
    private Class<?> domainType;
    private EntityManager em;
    private JpaSpecProvider<Object> jpaSpecProvider;
    private Method method;
    private CriteriaQuery<?> cachedCriteriaQuery;
    private AclPredicateTargetSource aclPredicateTargetSource;
    private Root<Object> root;

    public AclJpaQuery(Method method, RepositoryQuery query, Class<?> domainType, EntityManager em,
            JpaSpecProvider<Object> jpaSpecProvider) {
        this.method = method;
        this.query = query;
        this.domainType = domainType;
        this.em = em;
        this.jpaSpecProvider = jpaSpecProvider;
        installAclProxy(query);
    }

    @Override
    public Object execute(Object[] parameters) {
        if (cachedCriteriaQuery == null) {
            return query.execute(parameters);
        }

        // retrieve acl specification
        Specification<Object> aclJpaSpec = jpaSpecProvider.jpaSpecFor(domainType);

        synchronized (cachedCriteriaQuery) {
            installAclSpec(aclJpaSpec);
            try {
                return query.execute(parameters);
            } finally {
                uninstallAclSpec();
            }
        }
    }

    private void installAclSpec(Specification<Object> aclJpaSpec) {
        // force rerender by resetting alias
        root.alias(null);

        // build acl predicate
        Predicate aclPredicate = aclJpaSpec.toPredicate(root, cachedCriteriaQuery, em.getCriteriaBuilder());

        // install acl predicate
        aclPredicateTargetSource.installAcl(aclPredicate);

        logger.debug("ACL Jpa Specification installed for method '{}' and query {}: {}", method, query, aclJpaSpec);
    }

    private void uninstallAclSpec() {
        if (aclPredicateTargetSource != null) {
            aclPredicateTargetSource.uninstallAcl();
            logger.debug("ACL Jpa Specification uninstalled from method '{}' and query {}", method, query);
        }
    }

    private static String getStackTrace(Throwable throwable) {
        StringWriter sw = new StringWriter();
        PrintWriter pw = new PrintWriter(sw, true);
        throwable.printStackTrace(pw);
        return sw.getBuffer().toString();
    }

    private void installAclProxy(RepositoryQuery query) {
        CriteriaQuery<?> criteriaQuery = criteriaQuery();
        if (criteriaQuery == null) {
            logger.warn("Unable to install ACL Jpa Specification for method '" + method + "'and query: " + query
                    + " ; query methods with Pageable/Sort are not (yet) supported");
            return;
        }
        this.cachedCriteriaQuery = criteriaQuery;
        this.root = root(cachedCriteriaQuery);

        try {
            this.aclPredicateTargetSource = installAclPredicateTargetSource();
        } catch (Exception e) {
            logger.warn("Unable to install ACL Jpa Specification for method '" + method + "'and query: " + query
                    + " : " + getStackTrace(e));
        }
    }

    private AclPredicateTargetSource installAclPredicateTargetSource() {
        synchronized (cachedCriteriaQuery) {
            Predicate restriction = cachedCriteriaQuery.getRestriction();

            if (restriction instanceof Advised) {
                Advised advised = (Advised) restriction;
                if (advised.getTargetSource() instanceof AclPredicateTargetSource) {
                    return (AclPredicateTargetSource) advised.getTargetSource();
                }
            }

            AclPredicateTargetSource targetSource = new AclPredicateTargetSource(em.getCriteriaBuilder(),
                    restriction);
            ProxyFactoryBean factoryBean = new ProxyFactoryBean();
            factoryBean.setTargetSource(targetSource);
            factoryBean.setAutodetectInterfaces(true);
            Predicate enhancedPredicate = (Predicate) factoryBean.getObject();
            logger.debug("ACL Jpa Specification target source initialized for criteria {}", cachedCriteriaQuery);

            // install proxy inside criteria
            cachedCriteriaQuery.where(enhancedPredicate);
            return targetSource;
        }
    }

    @Override
    public QueryMethod getQueryMethod() {
        return query.getQueryMethod();
    }

    private CriteriaQuery<?> criteriaQuery() {
        Object queryPreparer = getField(PartTreeJpaQuery.class, query, "query");
        CriteriaQuery<?> criteriaQuery = getField(queryPreparer.getClass(), queryPreparer, "cachedCriteriaQuery");
        return criteriaQuery;
    }

    @SuppressWarnings("unchecked")
    private Root<Object> root(CriteriaQuery<?> criteriaQuery) {
        return (Root<Object>) criteriaQuery.getRoots().iterator().next();
    }

    @SuppressWarnings("unchecked")
    private static <T> T getField(Class<?> type, Object object, String fieldName) {
        Field field = ReflectionUtils.findField(type, fieldName);
        field.setAccessible(true);
        Object property = ReflectionUtils.getField(field, object);
        return (T) property;
    }

}