Java tutorial
/*
 * Copyright 2015 Kevin Herron
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.digitalpetri.opcua.stack.server.tcp;

import java.net.URI;
import java.net.URISyntaxException;
import java.security.cert.CertificateEncodingException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

import com.digitalpetri.opcua.stack.core.Stack;
import com.digitalpetri.opcua.stack.core.StatusCodes;
import com.digitalpetri.opcua.stack.core.UaException;
import com.digitalpetri.opcua.stack.core.application.CertificateManager;
import com.digitalpetri.opcua.stack.core.application.CertificateValidator;
import com.digitalpetri.opcua.stack.core.application.UaStackServer;
import com.digitalpetri.opcua.stack.core.application.services.AttributeServiceSet;
import com.digitalpetri.opcua.stack.core.application.services.DiscoveryServiceSet;
import com.digitalpetri.opcua.stack.core.application.services.MethodServiceSet;
import com.digitalpetri.opcua.stack.core.application.services.MonitoredItemServiceSet;
import com.digitalpetri.opcua.stack.core.application.services.NodeManagementServiceSet;
import com.digitalpetri.opcua.stack.core.application.services.QueryServiceSet;
import com.digitalpetri.opcua.stack.core.application.services.ServiceRequest;
import com.digitalpetri.opcua.stack.core.application.services.ServiceRequestHandler;
import com.digitalpetri.opcua.stack.core.application.services.ServiceResponse;
import com.digitalpetri.opcua.stack.core.application.services.SessionServiceSet;
import com.digitalpetri.opcua.stack.core.application.services.SubscriptionServiceSet;
import com.digitalpetri.opcua.stack.core.application.services.TestServiceSet;
import com.digitalpetri.opcua.stack.core.application.services.ViewServiceSet;
import com.digitalpetri.opcua.stack.core.channel.ChannelConfig;
import com.digitalpetri.opcua.stack.core.channel.ServerSecureChannel;
import com.digitalpetri.opcua.stack.core.security.SecurityPolicy;
import com.digitalpetri.opcua.stack.core.serialization.UaRequestMessage;
import com.digitalpetri.opcua.stack.core.serialization.UaResponseMessage;
import com.digitalpetri.opcua.stack.core.types.builtin.ByteString;
import com.digitalpetri.opcua.stack.core.types.enumerated.ApplicationType;
import com.digitalpetri.opcua.stack.core.types.enumerated.MessageSecurityMode;
import com.digitalpetri.opcua.stack.core.types.structured.ApplicationDescription;
import com.digitalpetri.opcua.stack.core.types.structured.EndpointDescription;
import com.digitalpetri.opcua.stack.core.types.structured.FindServersRequest;
import com.digitalpetri.opcua.stack.core.types.structured.FindServersResponse;
import com.digitalpetri.opcua.stack.core.types.structured.GetEndpointsRequest;
import com.digitalpetri.opcua.stack.core.types.structured.GetEndpointsResponse;
import com.digitalpetri.opcua.stack.core.types.structured.SignedSoftwareCertificate;
import com.digitalpetri.opcua.stack.core.types.structured.UserTokenPolicy;
import com.digitalpetri.opcua.stack.server.Endpoint;
import com.digitalpetri.opcua.stack.server.config.UaTcpStackServerConfig;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimaps;
import com.google.common.collect.Sets;
import io.netty.channel.Channel;
import io.netty.util.AttributeKey;
import io.netty.util.HashedWheelTimer;
import io.netty.util.Timeout;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static com.digitalpetri.opcua.stack.core.types.builtin.unsigned.Unsigned.ubyte;
import static com.digitalpetri.opcua.stack.core.util.ConversionUtil.a;
import static com.google.common.collect.Lists.newArrayList;
import static java.util.stream.Collectors.toList;

/**
 * A {@link UaStackServer} that serves its endpoints over TCP.
 * <p>
 * The server keeps track of the secure channels it has opened, dispatches incoming
 * {@link ServiceRequest}s to the {@link ServiceRequestHandler} registered for the request class,
 * and writes the resulting {@link ServiceResponse} to the Netty {@link Channel} currently bound to
 * the secure channel, or queues it while no channel is bound.
 */
public class UaTcpStackServer implements UaStackServer {

    /**
     * The {@link AttributeKey} that maps to the {@link Channel} bound to a {@link ServerSecureChannel}.
     */
    public static final AttributeKey<Channel> BoundChannelKey = AttributeKey.valueOf("bound-channel");

    private final Logger logger = LoggerFactory.getLogger(getClass());

    /** Source of secure channel ids; see {@link #openSecureChannel()}. */
    private final AtomicLong channelIds = new AtomicLong();

    /** Source of token ids; see {@link #nextTokenId()}. */
    private final AtomicLong tokenIds = new AtomicLong();

    /** Request handlers keyed by the concrete request class they serve. */
    private final Map<Class<? extends UaRequestMessage>, ServiceRequestHandler<UaRequestMessage, UaResponseMessage>>
        handlers = Maps.newConcurrentMap();

    /** Currently open secure channels keyed by channel id. */
    private final Map<Long, ServerSecureChannel> secureChannels = Maps.newConcurrentMap();

    /** Responses produced while a secure channel had no bound {@link Channel}, keyed by channel id. */
    private final ListMultimap<Long, ServiceResponse> responseQueues =
        Multimaps.synchronizedListMultimap(ArrayListMultimap.create());

    private final List<Endpoint> endpoints = Lists.newCopyOnWriteArrayList();
    private final Set<String> discoveryUrls = Sets.newConcurrentHashSet();

    private final HashedWheelTimer wheelTimer = Stack.sharedWheelTimer();

    /** Pending lifetime timeouts keyed by secure channel id. */
    private final Map<Long, Timeout> timeouts = Maps.newConcurrentMap();

    private final UaTcpStackServerConfig config;

    /**
     * Creates a server with the given configuration and registers the built-in discovery service
     * set plus a default (anonymous, nothing overridden) instance of every other service set.
     *
     * @param config the configuration this server reads its settings from.
     */
    public UaTcpStackServer(UaTcpStackServerConfig config) {
        this.config = config;

        addServiceSet(new DefaultDiscoveryServiceSet());

        addServiceSet(new AttributeServiceSet() {
        });
        addServiceSet(new MethodServiceSet() {
        });
        addServiceSet(new MonitoredItemServiceSet() {
        });
        addServiceSet(new NodeManagementServiceSet() {
        });
        addServiceSet(new QueryServiceSet() {
        });
        addServiceSet(new SessionServiceSet() {
        });
        addServiceSet(new SubscriptionServiceSet() {
        });
        addServiceSet(new TestServiceSet() {
        });
        addServiceSet(new ViewServiceSet() {
        });
    }

    /**
     * @return the configuration this server was created with.
     */
    public UaTcpStackServerConfig getConfig() {
        return config;
    }

    /**
     * Registers this server with a {@link SocketServer} for every configured endpoint and records
     * a discovery URL for each of them.
     * <p>
     * A failure for one endpoint is logged and does not prevent the remaining endpoints from being
     * processed.
     */
    @Override
    public void startup() {
        for (Endpoint endpoint : endpoints) {
            try {
                URI endpointUri = endpoint.getEndpointUri();

                // Fall back to the host of the endpoint URI when no explicit bind address is set.
                String bindAddress = endpoint.getBindAddress().orElse(endpointUri.getHost());

                SocketServer socketServer = SocketServer.boundTo(bindAddress, endpointUri.getPort());
                socketServer.setStrictEndpointUrlsEnabled(config.isStrictEndpointUrlsEnabled());

                logger.info("{} bound to {} [{}/{}]",
                    endpoint.getEndpointUri(), socketServer.getLocalAddress(),
                    endpoint.getSecurityPolicy(), endpoint.getMessageSecurity());

                addDiscoveryUrl(endpointUri);

                socketServer.addServer(this);
            } catch (Exception e) {
                logger.error("Error binding {}: {}.", endpoint, e.getMessage(), e);
            }
        }
    }

    /**
     * Builds {@code opc.tcp://host:port[/serverName]} from the given endpoint URI and adds it to
     * the set of discovery URLs. The server name segment is omitted when the configured name is
     * empty.
     *
     * @param endpointUri the endpoint URI supplying host and port.
     */
    private void addDiscoveryUrl(URI endpointUri) {
        String serverName = config.getServerName();

        StringBuilder discoveryUrl = new StringBuilder();

        discoveryUrl.append("opc.tcp://")
            .append(endpointUri.getHost())
            .append(":")
            .append(endpointUri.getPort());

        if (!serverName.isEmpty()) {
            discoveryUrl.append("/").append(serverName);
        }

        discoveryUrls.add(discoveryUrl.toString());
    }

    /**
     * Removes this server from the {@link SocketServer} of every configured endpoint and then
     * closes all secure channels that are still open.
     */
    @Override
    public void shutdown() {
        for (Endpoint endpoint : endpoints) {
            URI endpointUri = endpoint.getEndpointUri();
            String address = endpoint.getBindAddress().orElse(endpointUri.getHost());

            try {
                SocketServer socketServer = SocketServer.boundTo(address, endpointUri.getPort());
                socketServer.removeServer(this);
            } catch (Exception e) {
                logger.error("Error getting SocketServer for {}: {}.", endpoint, e.getMessage(), e);
            }
        }

        // Iterate over a snapshot: closeSecureChannel() removes entries from secureChannels.
        List<ServerSecureChannel> copy = newArrayList(secureChannels.values());
        copy.forEach(this::closeSecureChannel);
    }

    /**
     * Dispatches a request to the handler registered for its class.
     * <p>
     * Before dispatching, a completion callback is attached to the request's future. When the
     * future completes, the response (or a service fault built from the failure) is written to the
     * {@link Channel} bound to the request's secure channel, or queued if no channel is bound. If
     * the secure channel is no longer registered with this server the response is discarded.
     * <p>
     * A request with no registered handler is answered with {@code Bad_ServiceUnsupported}; a
     * {@link UaException} thrown by the handler becomes the service fault; any other throwable is
     * logged and answered with {@code Bad_InternalError}.
     *
     * @param serviceRequest the request to handle.
     */
    public void receiveRequest(ServiceRequest<UaRequestMessage, UaResponseMessage> serviceRequest) {
        logger.trace("Received {} on {}.", serviceRequest, serviceRequest.getSecureChannel());

        serviceRequest.getFuture().whenComplete((response, throwable) -> {
            long requestId = serviceRequest.getRequestId();

            ServiceResponse serviceResponse = response != null ?
                new ServiceResponse(serviceRequest.getRequest(), requestId, response) :
                new ServiceResponse(serviceRequest.getRequest(), requestId,
                    serviceRequest.createServiceFault(throwable));

            ServerSecureChannel secureChannel = serviceRequest.getSecureChannel();
            boolean secureChannelValid = secureChannels.containsKey(secureChannel.getChannelId());

            if (secureChannelValid) {
                Channel channel = secureChannel.attr(BoundChannelKey).get();

                if (channel != null) {
                    if (serviceResponse.isServiceFault()) {
                        logger.debug("Sending {} on {}.", serviceResponse, secureChannel);
                    } else {
                        logger.trace("Sending {} on {}.", serviceResponse, secureChannel);
                    }

                    channel.writeAndFlush(serviceResponse, channel.voidPromise());
                } else {
                    logger.trace("Queueing {} for unbound {}.", serviceResponse, secureChannel);

                    responseQueues.put(secureChannel.getChannelId(), serviceResponse);
                }
            }
        });

        Class<? extends UaRequestMessage> requestClass = serviceRequest.getRequest().getClass();
        ServiceRequestHandler<UaRequestMessage, UaResponseMessage> handler = handlers.get(requestClass);

        try {
            if (handler != null) {
                handler.handle(serviceRequest);
            } else {
                serviceRequest.setServiceFault(StatusCodes.Bad_ServiceUnsupported);
            }
        } catch (UaException e) {
            serviceRequest.setServiceFault(e);
        } catch (Throwable t) {
            logger.error("Uncaught Throwable executing ServiceRequestHandler: {}", handler, t);
            serviceRequest.setServiceFault(StatusCodes.Bad_InternalError);
        }
    }

    /**
     * @return a description of this server built from the configured application URI, product URI
     * and application name, listing every known discovery URL.
     */
    @Override
    public ApplicationDescription getApplicationDescription() {
        return new ApplicationDescription(
            config.getApplicationUri(),
            config.getProductUri(),
            config.getApplicationName(),
            ApplicationType.Server,
            null,
            null,
            a(newArrayList(this.discoveryUrls), String.class));
    }

    /**
     * @return the live list of endpoints added to this server (not a copy).
     */
    public List<Endpoint> getEndpoints() {
        return endpoints;
    }

    /**
     * @return an {@link EndpointDescription} for each endpoint added to this server.
     */
    @Override
    public EndpointDescription[] getEndpointDescriptions() {
        return getEndpoints().stream()
            .map(this::mapEndpoint)
            .toArray(EndpointDescription[]::new);
    }

    /**
     * @return the configured software certificates as an array.
     */
    @Override
    public SignedSoftwareCertificate[] getSoftwareCertificates() {
        List<SignedSoftwareCertificate> softwareCertificates = config.getSoftwareCertificates();

        return softwareCertificates.toArray(new SignedSoftwareCertificate[softwareCertificates.size()]);
    }

    /**
     * @return the configured user token policies.
     */
    @Override
    public List<UserTokenPolicy> getUserTokenPolicies() {
        return config.getUserTokenPolicies();
    }

    /**
     * @return the string form of the URI of each endpoint added to this server.
     */
    public List<String> getEndpointUrls() {
        return endpoints.stream()
            .map(e -> e.getEndpointUri().toString())
            .collect(toList());
    }

    /**
     * @return the live set of discovery URLs recorded during {@link #startup()} (not a copy).
     */
    public Set<String> getDiscoveryUrls() {
        return discoveryUrls;
    }

    @Override
    public CertificateManager getCertificateManager() {
        return config.getCertificateManager();
    }

    @Override
    public CertificateValidator getCertificateValidator() {
        return config.getCertificateValidator();
    }

    @Override
    public ExecutorService getExecutorService() {
        return config.getExecutor();
    }

    @Override
    public ChannelConfig getChannelConfig() {
        return config.getChannelConfig();
    }

    /**
     * @return the next secure channel id; the first id handed out is 1.
     */
    private long nextChannelId() {
        return channelIds.incrementAndGet();
    }

    /**
     * @return the next token id; the first id handed out is 1.
     */
    public long nextTokenId() {
        return tokenIds.incrementAndGet();
    }

    /**
     * Creates a new secure channel with a fresh channel id and registers it with this server.
     *
     * @return the newly created secure channel.
     */
    @Override
    public ServerSecureChannel openSecureChannel() {
        ServerSecureChannel channel = new ServerSecureChannel();
        channel.setChannelId(nextChannelId());

        long channelId = channel.getChannelId();
        secureChannels.put(channelId, channel);

        return channel;
    }

    /**
     * Unregisters a secure channel, releases the bookkeeping kept for it and closes the
     * {@link Channel} bound to it, if any.
     *
     * @param secureChannel the secure channel to close.
     */
    @Override
    public void closeSecureChannel(ServerSecureChannel secureChannel) {
        long channelId = secureChannel.getChannelId();

        if (secureChannels.remove(channelId) != null) {
            logger.debug("Removed secure channel id={}", channelId);
        }

        // Drop the lifetime timeout so it does not stay in the map (and keep the secure channel
        // reachable through its task) after the channel is gone. Cancelling a timeout that has
        // already expired, e.g. when this method is invoked by the timeout itself, has no effect.
        Timeout timeout = timeouts.remove(channelId);
        if (timeout != null) {
            timeout.cancel();
        }

        // Responses are only queued and drained for registered channels, so anything still
        // queued for this id can never be delivered.
        responseQueues.removeAll(channelId);

        Channel channel = secureChannel.attr(BoundChannelKey).get();

        if (channel != null) {
            logger.debug("Closing secure channel id={}, bound channel: {}", channelId, channel);
            channel.close();
        }
    }

    /**
     * Restarts the lifetime timeout of a secure channel and flushes any responses that were queued
     * while the secure channel had no bound {@link Channel}.
     * <p>
     * If an existing timeout could not be cancelled nothing is scheduled and nothing is flushed.
     *
     * @param secureChannel  the secure channel that was issued or renewed.
     * @param lifetimeMillis the delay, in milliseconds, after which the secure channel is closed
     *                       unless this method is called again first.
     */
    public void secureChannelIssuedOrRenewed(ServerSecureChannel secureChannel, long lifetimeMillis) {
        long channelId = secureChannel.getChannelId();

        /*
         * Cancel any existing timeouts and start a new one.
         */
        Timeout timeout = timeouts.remove(channelId);
        boolean cancelled = (timeout == null || timeout.cancel());

        if (cancelled) {
            timeout = wheelTimer.newTimeout(
                t -> closeSecureChannel(secureChannel),
                lifetimeMillis, TimeUnit.MILLISECONDS);

            timeouts.put(channelId, timeout);

            /*
             * If this is a reconnect there might be responses queued, so drain those.
             */
            Channel channel = secureChannel.attr(BoundChannelKey).get();

            if (channel != null) {
                List<ServiceResponse> responses = responseQueues.removeAll(channelId);

                responses.forEach(channel::write);
                channel.flush();
            }
        }
    }

    /**
     * @param channelId the id of the secure channel to look up.
     * @return the registered secure channel with that id, or {@code null} if there is none.
     */
    public ServerSecureChannel getSecureChannel(long channelId) {
        return secureChannels.get(channelId);
    }

    /**
     * Registers the handler for a request class, replacing any handler previously registered for
     * the same class.
     *
     * @param requestClass   the class of request the handler serves.
     * @param requestHandler the handler to invoke for requests of that class.
     */
    @SuppressWarnings("unchecked")
    public <T extends UaRequestMessage, U extends UaResponseMessage> void addRequestHandler(
        Class<T> requestClass, ServiceRequestHandler<T, U> requestHandler) {

        ServiceRequestHandler<UaRequestMessage, UaResponseMessage> handler =
            (ServiceRequestHandler<UaRequestMessage, UaResponseMessage>) requestHandler;

        handlers.put(requestClass, handler);
    }

    /**
     * Adds an endpoint to this server.
     * <p>
     * The endpoint is ignored, with a warning logged, when the security configuration is
     * inconsistent (message security {@code Invalid}, or exactly one of security policy and
     * message security is {@code None}) or when {@code endpointUri} is not a valid URI.
     *
     * @return this server, for chaining.
     */
    @Override
    public UaTcpStackServer addEndpoint(String endpointUri,
                                        String bindAddress,
                                        X509Certificate certificate,
                                        SecurityPolicy securityPolicy,
                                        MessageSecurityMode messageSecurity) {

        boolean invalidConfiguration = messageSecurity == MessageSecurityMode.Invalid ||
            (securityPolicy == SecurityPolicy.None && messageSecurity != MessageSecurityMode.None) ||
            (securityPolicy != SecurityPolicy.None && messageSecurity == MessageSecurityMode.None);

        if (invalidConfiguration) {
            logger.warn("Invalid configuration, ignoring: {} + {}", securityPolicy, messageSecurity);
        } else {
            try {
                URI uri = new URI(endpointUri);

                endpoints.add(new Endpoint(uri, bindAddress, certificate, securityPolicy, messageSecurity));
            } catch (URISyntaxException e) {
                // Pass the exception along so the log shows where the URI is malformed.
                logger.warn("Invalid endpoint URI, ignoring: {}", endpointUri, e);
            }
        }

        return this;
    }

    /**
     * Converts an {@link Endpoint} into the {@link EndpointDescription} advertised to clients.
     *
     * @param endpoint the endpoint to describe.
     * @return the description of the endpoint.
     */
    private EndpointDescription mapEndpoint(Endpoint endpoint) {
        List<UserTokenPolicy> userTokenPolicies = config.getUserTokenPolicies();

        return new EndpointDescription(
            endpoint.getEndpointUri().toString(),
            getApplicationDescription(),
            certificateByteString(endpoint.getCertificate()),
            endpoint.getMessageSecurity(),
            endpoint.getSecurityPolicy().getSecurityPolicyUri(),
            userTokenPolicies.toArray(new UserTokenPolicy[userTokenPolicies.size()]),
            Stack.UA_TCP_BINARY_TRANSPORT_URI,
            ubyte(endpoint.getSecurityLevel()));
    }

    /**
     * @param certificate the certificate to encode, if any.
     * @return the encoded form of the certificate, or {@link ByteString#NULL_VALUE} when no
     * certificate is present or it cannot be encoded.
     */
    private ByteString certificateByteString(Optional<X509Certificate> certificate) {
        if (certificate.isPresent()) {
            try {
                return ByteString.of(certificate.get().getEncoded());
            } catch (CertificateEncodingException e) {
                logger.error("Error encoding certificate.", e);
                return ByteString.NULL_VALUE;
            }
        } else {
            return ByteString.NULL_VALUE;
        }
    }

    /**
     * Extracts the host component of a URL.
     *
     * @param url the URL to parse; may be {@code null}.
     * @return the host, or {@code null} when {@code url} is {@code null}, cannot be parsed (a
     * warning is logged in that case) or has no host component.
     */
    private String hostOf(String url) {
        if (url == null) {
            return null;
        }

        try {
            return URI.create(url).getHost();
        } catch (IllegalArgumentException e) {
            logger.warn("Unable to create URI.", e);
            return null;
        }
    }

    /**
     * @param requestedHost the host a client asked for; may be {@code null}.
     * @param candidateUrl  the URL whose host is compared against {@code requestedHost}.
     * @return {@code true} if both hosts are known and equal ignoring case.
     */
    private boolean hostMatches(String requestedHost, String candidateUrl) {
        // String.equalsIgnoreCase(null) is false, so an unknown candidate host never matches.
        return requestedHost != null && requestedHost.equalsIgnoreCase(hostOf(candidateUrl));
    }

    /**
     * The discovery services (GetEndpoints and FindServers) answered from this server's own
     * endpoints and discovery URLs.
     */
    private class DefaultDiscoveryServiceSet implements DiscoveryServiceSet {

        /**
         * Responds with the endpoints whose transport profile is among the requested profile URIs
         * (all endpoints when none are requested). If any of those share the host of the requested
         * endpoint URL only they are returned; otherwise all of them are.
         */
        @Override
        public void onGetEndpoints(ServiceRequest<GetEndpointsRequest, GetEndpointsResponse> serviceRequest) {
            GetEndpointsRequest request = serviceRequest.getRequest();

            List<String> profileUris = request.getProfileUris() != null ?
                newArrayList(request.getProfileUris()) :
                new ArrayList<>();

            List<EndpointDescription> allEndpoints = endpoints.stream()
                .map(UaTcpStackServer.this::mapEndpoint)
                .filter(ed -> filterProfileUris(ed, profileUris))
                .collect(toList());

            // Parse the requested URL once rather than once per endpoint.
            String requestedHost = hostOf(request.getEndpointUrl());

            List<EndpointDescription> matchingEndpoints = allEndpoints.stream()
                .filter(ed -> hostMatches(requestedHost, ed.getEndpointUrl()))
                .collect(toList());

            GetEndpointsResponse response = new GetEndpointsResponse(
                serviceRequest.createResponseHeader(),
                matchingEndpoints.isEmpty() ?
                    a(allEndpoints, EndpointDescription.class) :
                    a(matchingEndpoints, EndpointDescription.class));

            serviceRequest.setResponse(response);
        }

        /**
         * @return {@code true} if no profile URIs were requested or the endpoint's transport
         * profile URI is one of them.
         */
        private boolean filterProfileUris(EndpointDescription endpoint, List<String> profileUris) {
            return profileUris.size() == 0 || profileUris.contains(endpoint.getTransportProfileUri());
        }

        /**
         * Responds with this server's application description, unless server URIs were requested
         * and this server's application URI is not among them, in which case the list is empty.
         */
        @Override
        public void onFindServers(ServiceRequest<FindServersRequest, FindServersResponse> serviceRequest) {
            FindServersRequest request = serviceRequest.getRequest();

            List<String> serverUris = request.getServerUris() != null ?
                newArrayList(request.getServerUris()) :
                new ArrayList<>();

            List<ApplicationDescription> applicationDescriptions =
                newArrayList(getApplicationDescription(request.getEndpointUrl()));

            applicationDescriptions = applicationDescriptions.stream()
                .filter(ad -> filterServerUris(ad, serverUris))
                .collect(toList());

            FindServersResponse response = new FindServersResponse(
                serviceRequest.createResponseHeader(),
                a(applicationDescriptions, ApplicationDescription.class));

            serviceRequest.setResponse(response);
        }

        /**
         * Builds this server's application description for a client that connected through the
         * given endpoint URL. The discovery URLs sharing that URL's host are listed; if there are
         * none, every discovery URL is listed.
         *
         * @param endpointUrl the endpoint URL from the client's request; may be {@code null}.
         */
        private ApplicationDescription getApplicationDescription(String endpointUrl) {
            List<String> allDiscoveryUrls = newArrayList(discoveryUrls);

            // Parse the requested URL once rather than once per discovery URL.
            String requestedHost = hostOf(endpointUrl);

            List<String> matchingDiscoveryUrls = allDiscoveryUrls.stream()
                .filter(discoveryUrl -> hostMatches(requestedHost, discoveryUrl))
                .collect(toList());

            return new ApplicationDescription(
                config.getApplicationUri(),
                config.getProductUri(),
                config.getApplicationName(),
                ApplicationType.Server,
                null,
                null,
                matchingDiscoveryUrls.isEmpty() ?
                    a(allDiscoveryUrls, String.class) :
                    a(matchingDiscoveryUrls, String.class));
        }

        /**
         * @return {@code true} if no server URIs were requested or the description's application
         * URI is one of them.
         */
        private boolean filterServerUris(ApplicationDescription ad, List<String> serverUris) {
            return serverUris.size() == 0 || serverUris.contains(ad.getApplicationUri());
        }

    }

}