com.baomidou.mybatisplus.plugins.OptimisticLockerInterceptor.java Source code

Java tutorial

Introduction

Here is the source code for com.baomidou.mybatisplus.plugins.OptimisticLockerInterceptor.java

Source

/**
 * Copyright (c) 2011-2014, hubin (jobob@qq.com).
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package com.baomidou.mybatisplus.plugins;

import java.lang.reflect.Field;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.sql.Connection;
import java.sql.Timestamp;
import java.util.Date;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.ibatis.binding.MapperMethod.ParamMap;
import org.apache.ibatis.exceptions.ExceptionFactory;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.type.TypeException;
import org.apache.ibatis.type.UnknownTypeHandler;

import com.baomidou.mybatisplus.annotations.TableField;
import com.baomidou.mybatisplus.annotations.Version;
import com.baomidou.mybatisplus.mapper.EntityWrapper;
import com.baomidou.mybatisplus.toolkit.PluginUtils;
import com.baomidou.mybatisplus.toolkit.StringUtils;

import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.update.Update;

/**
 * <p>
 * MyBatis???
 * </p>
 * 
 * <pre>
 * ?update user set name = ?, password = ? where id = ?
 * ?update user set name = ?, password = ?, version = version+1 where id = ? and version = ?
 * version{@link Version}
 * sql???version,??version
 * ?,int Integer, long Long, Date,Timestamp
 * ?,versionHandlers,?
 * </pre>
 *
 * @author TaoYu ?
 * @since 2017-04-08
 */
@Intercepts({
        @Signature(type = StatementHandler.class, method = "prepare", args = { Connection.class, Integer.class }) })
public final class OptimisticLockerInterceptor implements Interceptor {

    /**
     * ?version?
     */
    private static final Map<Class<?>, LockerCache> versionCache = new ConcurrentHashMap<>();

    /**
     * ?version?
     */
    private static final Map<Type, VersionHandler<?>> typeHandlers = new HashMap<>();

    private static final Expression RIGHT_EXPRESSION = new Column("?");

    static {
        IntegerTypeHandler integerTypeHandler = new IntegerTypeHandler();
        typeHandlers.put(int.class, integerTypeHandler);
        typeHandlers.put(Integer.class, integerTypeHandler);

        LongTypeHandler longTypeHandler = new LongTypeHandler();
        typeHandlers.put(long.class, longTypeHandler);
        typeHandlers.put(Long.class, longTypeHandler);

        typeHandlers.put(Date.class, new DateTypeHandler());
        typeHandlers.put(Timestamp.class, new TimestampTypeHandler());
    }

    public Object intercept(Invocation invocation) throws Exception {
        StatementHandler statementHandler = (StatementHandler) PluginUtils.realTarget(invocation.getTarget());
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        // ?UPDATE?
        MappedStatement ms = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        if (!ms.getSqlCommandType().equals(SqlCommandType.UPDATE)) {
            return invocation.proceed();
        }
        BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");
        // ?,?version?
        Class<?> parameterClass = ms.getParameterMap().getType();
        LockerCache lockerCache = versionCache.get(parameterClass);
        if (lockerCache != null) {
            if (lockerCache.lock) {
                processChangeSql(ms, boundSql, lockerCache);
            }
        } else {
            Field versionField = getVersionField(parameterClass);
            if (versionField != null) {
                Class<?> fieldType = versionField.getType();
                if (!typeHandlers.containsKey(fieldType)) {
                    throw new TypeException(
                            "????" + fieldType.getName() + ",");
                }
                final TableField tableField = versionField.getAnnotation(TableField.class);
                String versionColumn = versionField.getName();
                if (tableField != null) {
                    versionColumn = tableField.value();
                }
                LockerCache lc = new LockerCache(true, versionColumn, versionField, typeHandlers.get(fieldType));
                versionCache.put(parameterClass, lc);
                processChangeSql(ms, boundSql, lc);
            } else {
                versionCache.put(parameterClass, LockerCache.INSTANCE);
            }
        }
        return invocation.proceed();

    }

    private Field getVersionField(Class<?> parameterClass) {
        if (parameterClass != Object.class) {
            for (Field field : parameterClass.getDeclaredFields()) {
                if (field.isAnnotationPresent(Version.class)) {
                    field.setAccessible(true);
                    return field;
                }
            }
            return getVersionField(parameterClass.getSuperclass());
        }
        return null;

    }

    private void processChangeSql(MappedStatement ms, BoundSql boundSql, LockerCache lockerCache) throws Exception {
        Object parameterObject = boundSql.getParameterObject();
        if (parameterObject instanceof ParamMap) {
            ParamMap<?> paramMap = (ParamMap<?>) parameterObject;
            parameterObject = paramMap.get("et");
            EntityWrapper<?> entityWrapper = (EntityWrapper<?>) paramMap.get("ew");
            if (entityWrapper != null) {
                Object entity = entityWrapper.getEntity();
                if (entity != null && lockerCache.field.get(entity) == null) {
                    changSql(ms, boundSql, parameterObject, lockerCache);
                }
            }
        } else {
            changSql(ms, boundSql, parameterObject, lockerCache);
        }
    }

    @SuppressWarnings("unchecked")
    private void changSql(MappedStatement ms, BoundSql boundSql, Object parameterObject, LockerCache lockerCache)
            throws Exception {
        Field versionField = lockerCache.field;
        String versionColumn = lockerCache.column;
        final Object versionValue = versionField.get(parameterObject);
        if (versionValue != null) {// ???version,?
            Configuration configuration = ms.getConfiguration();
            // 
            lockerCache.versionHandler.plusVersion(parameterObject, versionField, versionValue);
            // ?where?,?
            Update jsqlSql = (Update) CCJSqlParserUtil.parse(boundSql.getSql());
            BinaryExpression expression = (BinaryExpression) jsqlSql.getWhere();
            if (expression != null && !expression.toString().contains(versionColumn)) {
                EqualsTo equalsTo = new EqualsTo();
                equalsTo.setLeftExpression(new Column(versionColumn));
                equalsTo.setRightExpression(RIGHT_EXPRESSION);
                jsqlSql.setWhere(new AndExpression(equalsTo, expression));
                List<ParameterMapping> parameterMappings = new LinkedList<>(boundSql.getParameterMappings());
                parameterMappings.add(jsqlSql.getExpressions().size(), getVersionMappingInstance(configuration));
                MetaObject boundSqlMeta = configuration.newMetaObject(boundSql);
                boundSqlMeta.setValue("sql", jsqlSql.toString());
                boundSqlMeta.setValue("parameterMappings", parameterMappings);
            }
            // ?
            boundSql.setAdditionalParameter("originVersionValue", versionValue);
        }
    }

    private volatile ParameterMapping parameterMapping;

    private ParameterMapping getVersionMappingInstance(Configuration configuration) {
        if (parameterMapping == null) {
            synchronized (OptimisticLockerInterceptor.class) {
                if (parameterMapping == null) {
                    parameterMapping = new ParameterMapping.Builder(configuration, "originVersionValue",
                            new UnknownTypeHandler(configuration.getTypeHandlerRegistry())).build();
                }
            }
        }
        return parameterMapping;
    }

    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    @Override
    public void setProperties(Properties properties) {
        String versionHandlers = properties.getProperty("versionHandlers");
        if (StringUtils.isNotEmpty(versionHandlers)) {
            for (String handlerClazz : versionHandlers.split(",")) {
                try {
                    registerHandler(Class.forName(handlerClazz));
                } catch (Exception e) {
                    throw ExceptionFactory.wrapException("????", e);
                }
            }
        }
    }

    /**
     * ?
     */
    private static void registerHandler(Class<?> versionHandlerClazz) throws Exception {
        ParameterizedType parameterizedType = (ParameterizedType) versionHandlerClazz.getGenericInterfaces()[0];
        Object versionInstance = versionHandlerClazz.newInstance();
        if (!(versionInstance instanceof VersionHandler)) {
            throw new TypeException("?VersionHandler,?");
        } else {
            Type[] actualTypeArguments = parameterizedType.getActualTypeArguments();
            if (actualTypeArguments.length == 0) {
                throw new IllegalArgumentException("?");
            } else if (Object.class.equals(actualTypeArguments[0])) {
                throw new IllegalArgumentException("??Object");
            } else {
                typeHandlers.put(actualTypeArguments[0], (VersionHandler<?>) versionInstance);
            }
        }
    }

    // *****************************?*****************************
    private static class IntegerTypeHandler implements VersionHandler<Integer> {

        public void plusVersion(Object paramObj, Field field, Integer versionValue) throws Exception {
            field.set(paramObj, versionValue + 1);
        }
    }

    private static class LongTypeHandler implements VersionHandler<Long> {

        public void plusVersion(Object paramObj, Field field, Long versionValue) throws Exception {
            field.set(paramObj, versionValue + 1);
        }
    }

    // ***************************** ?*****************************
    private static class DateTypeHandler implements VersionHandler<Date> {

        public void plusVersion(Object paramObj, Field field, Date versionValue) throws Exception {
            field.set(paramObj, new Date());
        }
    }

    private static class TimestampTypeHandler implements VersionHandler<Timestamp> {

        public void plusVersion(Object paramObj, Field field, Timestamp versionValue) throws Exception {
            field.set(paramObj, new Timestamp(new Date().getTime()));
        }
    }

    /**
     * 
     */
    @SuppressWarnings("rawtypes")
    private static class LockerCache {

        public static final LockerCache INSTANCE = new LockerCache();

        private boolean lock;
        private String column;
        private Field field;
        private VersionHandler versionHandler;

        public LockerCache() {
        }

        LockerCache(Boolean lock, String column, Field field, VersionHandler versionHandler) {
            this.lock = lock;
            this.column = column;
            this.field = field;
            this.versionHandler = versionHandler;
        }
    }

}