// Java tutorial
/* * Copyright 2010 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package stirling.fix.session; import static stirling.fix.messages.fix42.MsgTypeValue.BUSINESS_MESSAGE_REJECT; import static stirling.fix.messages.fix42.MsgTypeValue.HEARTBEAT; import static stirling.fix.messages.fix42.MsgTypeValue.LOGON; import static stirling.fix.messages.fix42.MsgTypeValue.LOGOUT; import static stirling.fix.messages.fix42.MsgTypeValue.REJECT; import static stirling.fix.messages.fix42.MsgTypeValue.RESEND_REQUEST; import static stirling.fix.messages.fix42.MsgTypeValue.SEQUENCE_RESET; import static stirling.fix.messages.fix42.MsgTypeValue.TEST_REQUEST; import java.util.logging.Logger; import stirling.fix.messages.Value; import stirling.lang.DefaultTimeSource; import stirling.lang.Predicate; import stirling.lang.TimeSource; import org.joda.time.DateTime; import stirling.fix.Config; import stirling.fix.messages.DefaultMessageVisitor; import stirling.fix.messages.FixMessage; import stirling.fix.messages.Heartbeat; import stirling.fix.messages.Logon; import stirling.fix.messages.Logout; import stirling.fix.messages.Message; import stirling.fix.messages.MessageComparator; import stirling.fix.messages.MessageFactory; import stirling.fix.messages.MessageValidator; import stirling.fix.messages.MessageVisitor; import stirling.fix.messages.ParseException; import stirling.fix.messages.Parser; import stirling.fix.messages.Reject; import stirling.fix.messages.ResendRequest; 
import stirling.fix.messages.SequenceReset; import stirling.fix.messages.TestRequest; import stirling.fix.messages.Validator.ErrorHandler; import stirling.fix.messages.Validator.ErrorLevel; import stirling.fix.messages.fix42.BusinessMessageReject; import stirling.fix.session.store.SessionStore; import stirling.fix.tags.fix42.BeginSeqNo; import stirling.fix.tags.fix42.BusinessRejectReason; import stirling.fix.tags.fix42.EncryptMethod; import stirling.fix.tags.fix42.EndSeqNo; import stirling.fix.tags.fix42.GapFillFlag; import stirling.fix.tags.fix42.HeartBtInt; import stirling.fix.tags.fix42.NewSeqNo; import stirling.fix.tags.fix42.RefMsgType; import stirling.fix.tags.fix42.RefSeqNo; import stirling.fix.tags.fix42.TestReqID; import stirling.fix.tags.fix42.Text; import stirling.fix.tags.fix43.SessionRejectReason; import silvertip.Connection; /** * @author Karim Osman */ public class Session { public static final int MAX_CONSECUTIVE_RESEND_REQUESTS = 10; private static final long DEFAULT_LOGOUT_RESPONSE_TIMEOUT_MSEC = 10000; private static final Logger LOG = Logger.getLogger("Session"); protected MessageQueue<FixMessage> incomingQueue = new MessageQueue<FixMessage>(); protected MessageQueue<Message> outgoingQueue = new MessageQueue<Message>(); protected Sequence outgoingSeq = new Sequence(); protected final HeartBtIntValue heartBtInt; protected final Config config; protected final SessionStore store; protected final MessageFactory messageFactory; protected final MessageComparator messageComparator; private TimeSource timeSource = new DefaultTimeSource(); private long testReqId; private boolean initiatedLogout; private boolean authenticated; private boolean available = true; private DateTime prevTxTime = currentTime(); private DateTime prevRxTime = currentTime(); private boolean waitingForResponseToInitiatedLogout; private DateTime logoutInitiatedAt; public Session(HeartBtIntValue heartBtInt, Config config, SessionStore store, MessageFactory messageFactory, 
MessageComparator messageComparator) { this.heartBtInt = heartBtInt; this.config = config; this.store = store; this.messageFactory = messageFactory; this.messageComparator = messageComparator; store.load(this); } public void receive(final Connection conn, final FixMessage message, final MessageVisitor visitor) { prevRxTime = currentTime(); try { if (!parseMsgSeqNum(conn, message)) { return; } if (message.getMsgType().equals(SEQUENCE_RESET)) { Parser.parse(messageFactory, message, new Parser.Callback() { @Override public void message(Message sequenceReset) { if (!processSequenceReset(conn, (SequenceReset) sequenceReset)) processMessage(conn, message, visitor); } @Override public void invalidMessage(int msgSeqNum, Value<Integer> reason, String text) { } @Override public void unsupportedMsgType(String msgType, int msgSeqNum) { } @Override public void invalidMsgType(String msgType, int msgSeqNum) { } }); } else { processMessage(conn, message, visitor); } } finally { store.save(this); } } private void processMessage(final Connection conn, FixMessage message, MessageVisitor visitor) { message.setReceiveTime(currentTime()); int expectedMsgSeqNum = incomingQueue.nextSeqNum(); incomingQueue.enqueue(message); if (message.getMsgSeqNum() != expectedMsgSeqNum) { if (!authenticated) { processInSyncMessageQueue(conn, visitor); } processOutOfSyncMessageQueue(conn, message); } else if (!conn.isClosed()) { processInSyncMessageQueue(conn, visitor); } } private boolean processSequenceReset(Connection conn, SequenceReset message) { if (message.getBoolean(GapFillFlag.Tag())) { return processSequenceResetGapFill(conn, message); } return processSequenceResetReset(conn, message); } private boolean processSequenceResetGapFill(Connection conn, SequenceReset message) { if (message.getMsgSeqNum() < incomingQueue.nextSeqNum()) { if (!message.getPossDupFlag()) { String text = "MsgSeqNum too low, expecting " + incomingQueue.nextSeqNum() + " but received " + message.getMsgSeqNum(); 
getLogger().severe(text); terminate(conn, text); } return true; } return false; } private boolean processSequenceResetReset(Connection conn, SequenceReset message) { if (message.getNewSeqNo() == message.getMsgSeqNum()) { getLogger().warning("NewSeqNo(36)=" + message.getNewSeqNo() + " is equal to expected MsgSeqNum(34)=" + message.getMsgSeqNum()); } else if (message.getNewSeqNo() < message.getMsgSeqNum()) { String text = "Value is incorrect (out of range) for this tag, NewSeqNo(36)=" + message.getNewSeqNo(); getLogger().warning(text); sessionReject(conn, message.getMsgSeqNum(), SessionRejectReason.InvalidValue(), text); } else { incomingQueue.reset(message.getNewSeqNo()); } return true; } private void processOutOfSyncMessageQueue(final Connection conn, final FixMessage message) { if (message.getMsgSeqNum() > incomingQueue.nextSeqNum()) { if (incomingQueue.getOutOfOrderCount() > MAX_CONSECUTIVE_RESEND_REQUESTS) { terminate(conn, "Maximum resend requests (" + MAX_CONSECUTIVE_RESEND_REQUESTS + ") exceeded"); return; } sendResendRequest(conn, incomingQueue.nextSeqNum(), 0); } else { Parser.parse(messageFactory, message, new Parser.Callback() { @Override public void message(Message msg) { if (!msg.getPossDupFlag()) { terminateOnMsgSeqNumTooLow(conn, message); } } @Override public void invalidMessage(int msgSeqNum, Value<Integer> reason, String text) { terminateOnMsgSeqNumTooLow(conn, message); } @Override public void unsupportedMsgType(String msgType, int msgSeqNum) { terminateOnMsgSeqNumTooLow(conn, message); } @Override public void invalidMsgType(String msgType, int msgSeqNum) { terminateOnMsgSeqNumTooLow(conn, message); } }); } } private void terminateOnMsgSeqNumTooLow(Connection conn, FixMessage message) { String text = "MsgSeqNum too low, expecting " + incomingQueue.nextSeqNum() + " but received " + message.getMsgSeqNum(); getLogger().severe(text); terminate(conn, text); } private void processInSyncMessageQueue(final Connection conn, final MessageVisitor visitor) { 
while (!incomingQueue.isEmpty()) { Parser.parse(messageFactory, incomingQueue.dequeue(), new Parser.Callback() { @Override public void message(Message message) { if (validate(conn, message)) process(conn, message, visitor); else incomingQueue.skip(message.getMsgSeqNum()); } @Override public void invalidMessage(int msgSeqNum, Value<Integer> reason, String text) { incomingQueue.skip(msgSeqNum); if (authenticated) { getLogger().severe(text); sessionReject(conn, msgSeqNum, reason, text); } else { getLogger().severe(text); logout(conn); } } @Override public void unsupportedMsgType(String msgType, int msgSeqNum) { getLogger().warning("MsgType(35): Unknown message type: " + msgType); incomingQueue.skip(msgSeqNum); businessReject(conn, msgType, msgSeqNum, BusinessRejectReason.UnknownMessageType(), "MsgType(35): Unknown message type: " + msgType); } @Override public void invalidMsgType(String msgType, int msgSeqNum) { getLogger().warning("MsgType(35): Invalid message type: " + msgType); incomingQueue.skip(msgSeqNum); sessionReject(conn, msgSeqNum, SessionRejectReason.InvalidMsgType(), "MsgType(35): Invalid message type: " + msgType); } }); } } private boolean parseMsgSeqNum(Connection conn, FixMessage message) { try { int msgSeqNum = Parser.parseMsgSeqNum(message); message.setMsgSeqNum(msgSeqNum); return true; } catch (ParseException e) { getLogger().severe(e.getMessage()); terminate(conn, e.getMessage()); return false; } } private boolean validate(final Connection conn, final Message message) { return MessageValidator.validate(this, message, new ErrorHandler() { @Override public void sessionReject(Value<Integer> reason, String text, ErrorLevel level, boolean terminate) { logError(text, level); Session.this.sessionReject(conn, message, reason, text); if (terminate) Session.this.terminate(conn, text); } @Override public void businessReject(Value<Integer> reason, String text, ErrorLevel level) { logError(text, level); Session.this.businessReject(conn, message.getMsgType(), 
message.getMsgSeqNum(), reason, text); } @Override public void terminate(String text) { logError(text, ErrorLevel.ERROR); Session.this.terminate(conn, text); } private void logError(String text, ErrorLevel level) { switch (level) { case WARNING: getLogger().warning(text); break; case ERROR: getLogger().severe(text); break; } } }); } private void sendGapFill(final Connection conn, int beginSeqNo, int endSeqNo) { SequenceReset gapFillMsg = (SequenceReset) messageFactory.create(SEQUENCE_RESET); gapFillMsg.setBoolean(GapFillFlag.Tag(), true); gapFillMsg.setPossDupFlag(true); gapFillMsg.setMsgSeqNum(beginSeqNo); gapFillMsg.setInteger(NewSeqNo.Tag(), endSeqNo); send(conn, gapFillMsg, false, false); } private void resendRange(final Connection conn, int beginSeqNo, int endSeqNo) { if (endSeqNo == 0) { endSeqNo = outgoingSeq.peek() - 1; } int nextSeqNo = beginSeqNo; for (Message msg : store.getOutgoingMessages(this, beginSeqNo, endSeqNo)) { if (msg.isAdminMessage() && !msg.getMsgType().equals(REJECT)) { continue; } int msgSeqNum = msg.getMsgSeqNum(); if (msgSeqNum > nextSeqNo) { sendGapFill(conn, nextSeqNo, msgSeqNum); } msg.setPossDupFlag(true); send(conn, msg, false, false); nextSeqNo = msgSeqNum + 1; } if (nextSeqNo <= endSeqNo) { sendGapFill(conn, nextSeqNo, endSeqNo + 1); } } private void process(final Connection conn, Message message, final MessageVisitor visitor) { if (authenticated) { message.apply(new DefaultMessageVisitor() { @Override public void visit(TestRequest message) { store.saveIncomingMessage(Session.this, message); incomingQueue.skip(message.getMsgSeqNum()); Heartbeat heartbeat = (Heartbeat) messageFactory.create(HEARTBEAT); heartbeat.setString(TestReqID.Tag(), message.getString(TestReqID.Tag())); send(conn, heartbeat); } @Override public void visit(ResendRequest message) { store.saveIncomingMessage(Session.this, message); if (outgoingQueue.isEmpty()) { incomingQueue.skip(message.getMsgSeqNum()); int beginSeqNo = message.getInteger(BeginSeqNo.Tag()); int 
endSeqNo = message.getInteger(EndSeqNo.Tag()); resendRange(conn, beginSeqNo, endSeqNo); } else { while (!outgoingQueue.isEmpty()) { Message msg = outgoingQueue.dequeue(); msg.setPossDupFlag(true); send(conn, msg, false, false); } } } @Override public void visit(SequenceReset message) { store.saveIncomingMessage(Session.this, message); processSeqReset(conn, message); } @Override public void visit(Logout message) { store.saveIncomingMessage(Session.this, message); incomingQueue.skip(message.getMsgSeqNum()); if (!initiatedLogout) send(conn, messageFactory.create(LOGOUT)); else waitingForResponseToInitiatedLogout = false; conn.close(); } @Override public void defaultAction(Message message) { if (message.getPossResend() && store.isDuplicate(Session.this, message)) { store.saveIncomingMessage(Session.this, message); } else { store.saveIncomingMessage(Session.this, message); message.apply(visitor); } } }); } else { message.apply(new DefaultMessageVisitor() { @Override public void visit(Logon message) { authenticated = true; store.saveIncomingMessage(Session.this, message); message.apply(visitor); } @Override public void defaultAction(Message message) { getLogger().severe("first message is not a logon"); logout(conn); } }); } } private void sendResendRequest(Connection conn, int beginSeqNo, int endSeqNo) { ResendRequest resendReq = (ResendRequest) messageFactory.create(RESEND_REQUEST); resendReq.setInteger(BeginSeqNo.Tag(), beginSeqNo); resendReq.setInteger(EndSeqNo.Tag(), endSeqNo); send(conn, resendReq); } private void processSeqReset(Connection conn, SequenceReset message) { int newSeqNo = message.getInteger(NewSeqNo.Tag()); if (newSeqNo <= message.getMsgSeqNum()) { sessionReject(conn, message.getMsgSeqNum(), SessionRejectReason.InvalidValue(), "Attempt to lower sequence number, invalid value NewSeqNo(36)=" + newSeqNo); } else { incomingQueue.reset(newSeqNo); } } private void sessionReject(Connection conn, Message message, Value<Integer> reason, String text) { 
sessionReject(conn, message.getMsgSeqNum(), reason, text); } private void sessionReject(Connection conn, int msgSeqNum, Value<Integer> reason, String text) { Reject reject = (Reject) messageFactory.create(REJECT); reject.setInteger(RefSeqNo.Tag(), msgSeqNum); reject.setEnum(SessionRejectReason.Tag(), reason); reject.setString(Text.Tag(), text); send(conn, reject); } private void businessReject(Connection conn, String msgType, int msgSeqNum, Value<Integer> reason, String text) { BusinessMessageReject reject = (BusinessMessageReject) messageFactory.create(BUSINESS_MESSAGE_REJECT); reject.setInteger(RefSeqNo.Tag(), msgSeqNum); reject.setString(RefMsgType.Tag(), msgType); reject.setEnum(BusinessRejectReason.Tag(), reason); reject.setString(Text.Tag(), text); send(conn, reject); } private void terminate(Connection conn, String text) { Logout logout = (Logout) messageFactory.create(LOGOUT); logout.setString(Text.Tag(), text); send(conn, logout); conn.close(); } public void logon(Connection conn) { authenticated = initiatedLogout = false; Logon message = (Logon) messageFactory.create(LOGON); message.setInteger(HeartBtInt.Tag(), heartBtInt.getSeconds()); message.setEnum(EncryptMethod.Tag(), EncryptMethod.None()); message.setMsgSeqNum(outgoingSeq.next()); send(conn, message, false, true); } public void logon(Connection conn, Logon logonMessage) { authenticated = initiatedLogout = false; logonMessage.setMsgSeqNum(outgoingSeq.next()); send(conn, logonMessage, false, true); } public void logout(final Connection conn) { send(conn, messageFactory.create(LOGOUT)); initiatedLogout = true; logoutInitiatedAt = currentTime(); waitingForResponseToInitiatedLogout = true; } public void sequenceReset(Connection conn, Sequence seq) { SequenceReset message = (SequenceReset) messageFactory.create(SEQUENCE_RESET); message.setMsgSeqNum(seq.peek()); message.setInteger(NewSeqNo.Tag(), seq.next()); message.setBoolean(GapFillFlag.Tag(), false); send(conn, message, false, true); 
setOutgoingSeq(seq); } public void keepAlive(Connection conn) { if (isTimedOut(prevTxTime, heartBtInt.heartbeat().delayMsec())) { heartbeat(conn); prevTxTime = currentTime(); } if (isTimedOut(prevRxTime, heartBtInt.testRequest().delayMsec())) { testRequest(conn); prevRxTime = currentTime(); } } public void heartbeat(Connection conn) { send(conn, messageFactory.create(HEARTBEAT)); } private void testRequest(Connection conn) { TestRequest req = (TestRequest) messageFactory.create(TEST_REQUEST); req.setString(TestReqID.Tag(), Long.toString(++testReqId)); send(conn, req); } public void send(Connection conn, Message message) { send(conn, message, true, true); } public void send(Connection conn, Message message, boolean queue, boolean save) { message.setHeaderConfig(config); if (!queue) { message.setSendingTime(currentTime()); conn.send(FixMessage.fromString(message.format())); prevTxTime = currentTime(); } else { message.setMsgSeqNum(outgoingSeq.next()); outgoingQueue.enqueue(message); if (conn != null && !conn.isClosed()) { while (!outgoingQueue.isEmpty()) { Message msg = outgoingQueue.dequeue(); msg.setSendingTime(currentTime()); conn.send(FixMessage.fromString(msg.format())); prevTxTime = currentTime(); } } } if (save) { store.saveOutgoingMessage(this, message); } } public void processInitiatedLogout(Connection conn) { if (waitingForResponseToInitiatedLogout && isTimedOut(logoutInitiatedAt, getLogoutResponseTimeoutMsec())) { getLogger().warning("Response to logout not received in " + getLogoutResponseTimeoutMsec() / 1000 + " second(s), disconnecting"); waitingForResponseToInitiatedLogout = false; conn.close(); } } private boolean isTimedOut(DateTime dateTime, long timeoutMsec) { DateTime now = currentTime(); DateTime timeOutAt = dateTime.plusMillis((int) timeoutMsec); return now.isAfter(timeOutAt); } public Config getConfig() { return config; } public Sequence getOutgoingSeq() { return outgoingSeq; } public void setOutgoingSeq(Sequence seq) { outgoingSeq = seq; } 
public Sequence getIncomingSeq() { Sequence seq = new Sequence(); seq.reset(incomingQueue.nextSeqNum()); return seq; } public void setIncomingSeq(Sequence seq) { incomingQueue.reset(seq.peek()); } public DateTime currentTime() { return timeSource.currentTime(); } public boolean isAuthenticated() { return authenticated; } public void setAvailable(boolean available) { this.available = available; } public boolean isAvailable() { return available; } public MessageFactory getMessageFactory() { return messageFactory; } public MessageComparator getMessageComparator() { return messageComparator; } protected long getLogoutResponseTimeoutMsec() { return DEFAULT_LOGOUT_RESPONSE_TIMEOUT_MSEC; } protected Logger getLogger() { return LOG; } protected boolean checkSeqResetSeqNum() { return true; } }