001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package org.apache.reef.io.network.group.impl.driver; 020 021import org.apache.reef.driver.parameters.DriverIdentifier; 022import org.apache.reef.io.network.group.api.operators.GroupCommOperator; 023import org.apache.reef.io.network.group.api.GroupChanges; 024import org.apache.reef.io.network.group.api.config.OperatorSpec; 025import org.apache.reef.io.network.group.api.driver.TaskNode; 026import org.apache.reef.io.network.group.api.driver.Topology; 027import org.apache.reef.io.network.group.impl.GroupChangesCodec; 028import org.apache.reef.io.network.group.impl.GroupChangesImpl; 029import org.apache.reef.io.network.group.impl.GroupCommunicationMessage; 030import org.apache.reef.io.network.group.impl.config.BroadcastOperatorSpec; 031import org.apache.reef.io.network.group.impl.config.GatherOperatorSpec; 032import org.apache.reef.io.network.group.impl.config.ReduceOperatorSpec; 033import org.apache.reef.io.network.group.impl.config.ScatterOperatorSpec; 034import org.apache.reef.io.network.group.impl.config.parameters.*; 035import org.apache.reef.io.network.group.impl.operators.*; 036import org.apache.reef.io.network.group.impl.utils.Utils; 037import org.apache.reef.io.network.proto.ReefNetworkGroupCommProtos; 038import org.apache.reef.io.serialization.Codec; 039import org.apache.reef.tang.Configuration; 040import org.apache.reef.tang.JavaConfigurationBuilder; 041import org.apache.reef.tang.Tang; 042import org.apache.reef.tang.annotations.Name; 043import org.apache.reef.tang.annotations.Parameter; 044import org.apache.reef.tang.formats.AvroConfigurationSerializer; 045import org.apache.reef.tang.formats.ConfigurationSerializer; 046import org.apache.reef.wake.EStage; 047import org.apache.reef.wake.EventHandler; 048import org.apache.reef.wake.impl.SingleThreadStage; 049 050import javax.inject.Inject; 051import java.util.ArrayList; 052import java.util.List; 053import java.util.Map; 054import java.util.concurrent.ConcurrentMap; 055import java.util.concurrent.ConcurrentSkipListMap; 056import java.util.logging.Logger; 057 058/** 059 * Implements a tree topology with the specified Fan Out. 060 */ 061public final class TreeTopology implements Topology { 062 063 private static final Logger LOG = Logger.getLogger(TreeTopology.class.getName()); 064 065 private final EStage<GroupCommunicationMessage> senderStage; 066 private final Class<? extends Name<String>> groupName; 067 private final Class<? extends Name<String>> operName; 068 private final String driverId; 069 private String rootId; 070 private OperatorSpec operatorSpec; 071 072 private TaskNode root; 073 private TaskNode logicalRoot; 074 private TaskNode prev; 075 private final int fanOut; 076 077 private final ConcurrentMap<String, TaskNode> nodes = new ConcurrentSkipListMap<>(); 078 private final ConfigurationSerializer confSer = new AvroConfigurationSerializer(); 079 080 @Inject 081 private TreeTopology(@Parameter(GroupCommSenderStage.class) final EStage<GroupCommunicationMessage> senderStage, 082 @Parameter(CommGroupNameClass.class) final Class<? extends Name<String>> groupName, 083 @Parameter(OperatorNameClass.class) final Class<? extends Name<String>> operatorName, 084 @Parameter(DriverIdentifier.class) final String driverId, 085 @Parameter(TreeTopologyFanOut.class) final int fanOut) { 086 this.senderStage = senderStage; 087 this.groupName = groupName; 088 this.operName = operatorName; 089 this.driverId = driverId; 090 this.fanOut = fanOut; 091 LOG.config(getQualifiedName() + "Tree Topology running with a fan-out of " + fanOut); 092 } 093 094 @Override 095 @SuppressWarnings("checkstyle:hiddenfield") 096 public void setRootTask(final String rootId) { 097 LOG.entering("TreeTopology", "setRootTask", new Object[]{getQualifiedName(), rootId}); 098 this.rootId = rootId; 099 LOG.exiting("TreeTopology", "setRootTask", getQualifiedName() + rootId); 100 } 101 102 @Override 103 public String getRootId() { 104 LOG.entering("TreeTopology", "getRootId", getQualifiedName()); 105 LOG.exiting("TreeTopology", "getRootId", getQualifiedName() + rootId); 106 return rootId; 107 } 108 109 @Override 110 public boolean isRootPresent() { 111 LOG.entering("TreeTopology", "isRootPresent", getQualifiedName()); 112 final boolean retVal = root != null; 113 LOG.exiting("TreeTopology", "isRootPresent", String.format("%s%s", getQualifiedName(), retVal)); 114 return retVal; 115 } 116 117 @Override 118 public void setOperatorSpecification(final OperatorSpec spec) { 119 LOG.entering("TreeTopology", "setOperSpec", new Object[]{getQualifiedName(), spec}); 120 this.operatorSpec = spec; 121 LOG.exiting("TreeTopology", "setOperSpec", getQualifiedName() + spec); 122 } 123 124 @Override 125 public Configuration getTaskConfiguration(final String taskId) { 126 LOG.entering("TreeTopology", "getTaskConfig", new Object[]{getQualifiedName(), taskId}); 127 final TaskNode taskNode = nodes.get(taskId); 128 if (taskNode == null) { 129 throw new RuntimeException(getQualifiedName() + taskId + " does not exist"); 130 } 131 132 final int version = getNodeVersion(taskId); 133 final JavaConfigurationBuilder jcb = Tang.Factory.getTang().newConfigurationBuilder(); 134 jcb.bindNamedParameter(DataCodec.class, operatorSpec.getDataCodecClass()); 135 jcb.bindNamedParameter(TaskVersion.class, Integer.toString(version)); 136 if (operatorSpec instanceof BroadcastOperatorSpec) { 137 final BroadcastOperatorSpec broadcastOperatorSpec = (BroadcastOperatorSpec) operatorSpec; 138 if (taskId.equals(broadcastOperatorSpec.getSenderId())) { 139 jcb.bindImplementation(GroupCommOperator.class, BroadcastSender.class); 140 } else { 141 jcb.bindImplementation(GroupCommOperator.class, BroadcastReceiver.class); 142 } 143 } else if (operatorSpec instanceof ReduceOperatorSpec) { 144 final ReduceOperatorSpec reduceOperatorSpec = (ReduceOperatorSpec) operatorSpec; 145 jcb.bindNamedParameter(ReduceFunctionParam.class, reduceOperatorSpec.getRedFuncClass()); 146 if (taskId.equals(reduceOperatorSpec.getReceiverId())) { 147 jcb.bindImplementation(GroupCommOperator.class, ReduceReceiver.class); 148 } else { 149 jcb.bindImplementation(GroupCommOperator.class, ReduceSender.class); 150 } 151 } else if (operatorSpec instanceof ScatterOperatorSpec) { 152 final ScatterOperatorSpec scatterOperatorSpec = (ScatterOperatorSpec) operatorSpec; 153 if (taskId.equals(scatterOperatorSpec.getSenderId())) { 154 jcb.bindImplementation(GroupCommOperator.class, ScatterSender.class); 155 } else { 156 jcb.bindImplementation(GroupCommOperator.class, ScatterReceiver.class); 157 } 158 } else if (operatorSpec instanceof GatherOperatorSpec) { 159 final GatherOperatorSpec gatherOperatorSpec = (GatherOperatorSpec) operatorSpec; 160 if (taskId.equals(gatherOperatorSpec.getReceiverId())) { 161 jcb.bindImplementation(GroupCommOperator.class, GatherReceiver.class); 162 } else { 163 jcb.bindImplementation(GroupCommOperator.class, GatherSender.class); 164 } 165 } 166 final Configuration retConf = jcb.build(); 167 LOG.exiting("TreeTopology", "getTaskConfig", getQualifiedName() + confSer.toString(retConf)); 168 return retConf; 169 } 170 171 @Override 172 public int getNodeVersion(final String taskId) { 173 LOG.entering("TreeTopology", "getNodeVersion", new Object[]{getQualifiedName(), taskId}); 174 final TaskNode node = nodes.get(taskId); 175 if (node == null) { 176 throw new RuntimeException(getQualifiedName() + taskId + " is not available on the nodes map"); 177 } 178 final int version = node.getVersion(); 179 LOG.exiting("TreeTopology", "getNodeVersion", getQualifiedName() + " " + taskId + " " + version); 180 return version; 181 } 182 183 @Override 184 public void removeTask(final String taskId) { 185 LOG.entering("TreeTopology", "removeTask", new Object[]{getQualifiedName(), taskId}); 186 if (!nodes.containsKey(taskId)) { 187 LOG.fine("Trying to remove a non-existent node in the task graph"); 188 LOG.exiting("TreeTopology", "removeTask", getQualifiedName()); 189 return; 190 } 191 if (taskId.equals(rootId)) { 192 unsetRootNode(taskId); 193 } else { 194 removeChild(taskId); 195 } 196 LOG.exiting("TreeTopology", "removeTask", getQualifiedName() + taskId); 197 } 198 199 @Override 200 public void addTask(final String taskId) { 201 LOG.entering("TreeTopology", "addTask", new Object[]{getQualifiedName(), taskId}); 202 if (nodes.containsKey(taskId)) { 203 LOG.fine("Got a request to add a task that is already in the graph. " + 204 "We need to block this request till the delete finishes. ***CAUTION***"); 205 } 206 207 if (taskId.equals(rootId)) { 208 setRootNode(taskId); 209 } else { 210 addChild(taskId); 211 } 212 LOG.exiting("TreeTopology", "addTask", getQualifiedName() + taskId); 213 } 214 215 private void addChild(final String taskId) { 216 LOG.entering("TreeTopology", "addChild", new Object[]{getQualifiedName(), taskId}); 217 LOG.finest(getQualifiedName() + "Adding leaf " + taskId); 218 final TaskNode node = new TaskNodeImpl(senderStage, groupName, operName, taskId, driverId, false); 219 if (logicalRoot != null) { 220 addTaskNode(node); 221 prev = node; 222 } 223 nodes.put(taskId, node); 224 LOG.exiting("TreeTopology", "addChild", getQualifiedName() + taskId); 225 } 226 227 private void addTaskNode(final TaskNode node) { 228 LOG.entering("TreeTopology", "addTaskNode", new Object[]{getQualifiedName(), node}); 229 if (logicalRoot.getNumberOfChildren() >= this.fanOut) { 230 logicalRoot = logicalRoot.successor(); 231 } 232 node.setParent(logicalRoot); 233 logicalRoot.addChild(node); 234 prev.setSibling(node); 235 LOG.exiting("TreeTopology", "addTaskNode", getQualifiedName() + node); 236 } 237 238 private void removeChild(final String taskId) { 239 LOG.entering("TreeTopology", "removeChild", new Object[]{getQualifiedName(), taskId}); 240 if (root != null) { 241 root.removeChild(nodes.get(taskId)); 242 } 243 nodes.remove(taskId); 244 LOG.exiting("TreeTopology", "removeChild", getQualifiedName() + taskId); 245 } 246 247 private void setRootNode(final String newRootId) { 248 LOG.entering("TreeTopology", "setRootNode", new Object[]{getQualifiedName(), newRootId}); 249 this.root = new TaskNodeImpl(senderStage, groupName, operName, newRootId, driverId, true); 250 this.logicalRoot = this.root; 251 this.prev = this.root; 252 253 for (final Map.Entry<String, TaskNode> nodeEntry : nodes.entrySet()) { 254 final TaskNode leaf = nodeEntry.getValue(); 255 addTaskNode(leaf); 256 this.prev = leaf; 257 } 258 nodes.put(newRootId, root); 259 LOG.exiting("TreeTopology", "setRootNode", getQualifiedName() + newRootId); 260 } 261 262 private void unsetRootNode(final String taskId) { 263 LOG.entering("TreeTopology", "unsetRootNode", new Object[]{getQualifiedName(), taskId}); 264 nodes.remove(rootId); 265 root = null; 266 267 for (final Map.Entry<String, TaskNode> nodeEntry : nodes.entrySet()) { 268 final TaskNode leaf = nodeEntry.getValue(); 269 leaf.setParent(null); 270 } 271 LOG.exiting("TreeTopology", "unsetRootNode", getQualifiedName() + taskId); 272 } 273 274 @Override 275 public void onFailedTask(final String taskId) { 276 LOG.entering("TreeTopology", "onFailedTask", new Object[]{getQualifiedName(), taskId}); 277 final TaskNode taskNode = nodes.get(taskId); 278 if (taskNode == null) { 279 throw new RuntimeException(getQualifiedName() + taskId + " does not exist"); 280 } 281 taskNode.onFailedTask(); 282 LOG.exiting("TreeTopology", "onFailedTask", getQualifiedName() + taskId); 283 } 284 285 @Override 286 public void onRunningTask(final String taskId) { 287 LOG.entering("TreeTopology", "onRunningTask", new Object[]{getQualifiedName(), taskId}); 288 final TaskNode taskNode = nodes.get(taskId); 289 if (taskNode == null) { 290 throw new RuntimeException(getQualifiedName() + taskId + " does not exist"); 291 } 292 taskNode.onRunningTask(); 293 LOG.exiting("TreeTopology", "onRunningTask", getQualifiedName() + taskId); 294 } 295 296 @Override 297 public void onReceiptOfMessage(final GroupCommunicationMessage msg) { 298 LOG.entering("TreeTopology", "onReceiptOfMessage", new Object[]{getQualifiedName(), msg}); 299 switch (msg.getType()) { 300 case TopologyChanges: 301 onTopologyChanges(msg); 302 break; 303 case UpdateTopology: 304 onUpdateTopology(msg); 305 break; 306 307 default: 308 nodes.get(msg.getSrcid()).onReceiptOfAcknowledgement(msg); 309 break; 310 } 311 LOG.exiting("TreeTopology", "onReceiptOfMessage", getQualifiedName() + msg); 312 } 313 314 private void onUpdateTopology(final GroupCommunicationMessage msg) { 315 LOG.entering("TreeTopology", "onUpdateTopology", new Object[]{getQualifiedName(), msg}); 316 LOG.fine(getQualifiedName() + "Update affected parts of Topology"); 317 final String dstId = msg.getSrcid(); 318 final int version = getNodeVersion(dstId); 319 320 LOG.finest(getQualifiedName() + "Creating NodeTopologyUpdateWaitStage to wait on nodes to be updated"); 321 final EventHandler<List<TaskNode>> topoUpdateWaitHandler = new TopologyUpdateWaitHandler(senderStage, groupName, 322 operName, driverId, 0, 323 dstId, version, 324 getQualifiedName(), TopologySerializer.encode(root)); 325 final EStage<List<TaskNode>> nodeTopologyUpdateWaitStage = new SingleThreadStage<>("NodeTopologyUpdateWaitStage", 326 topoUpdateWaitHandler, 327 nodes.size()); 328 329 final List<TaskNode> toBeUpdatedNodes = new ArrayList<>(nodes.size()); 330 LOG.finest(getQualifiedName() + "Checking which nodes need to be updated"); 331 for (final TaskNode node : nodes.values()) { 332 if (node.isRunning() && node.hasChanges() && node.resetTopologySetupSent()) { 333 toBeUpdatedNodes.add(node); 334 } 335 } 336 for (final TaskNode node : toBeUpdatedNodes) { 337 node.updatingTopology(); 338 LOG.fine(getQualifiedName() + "Asking " + node + " to UpdateTopology"); 339 senderStage.onNext(Utils.bldVersionedGCM(groupName, operName, 340 ReefNetworkGroupCommProtos.GroupCommMessage.Type.UpdateTopology, driverId, 0, node.getTaskId(), 341 node.getVersion(), Utils.EMPTY_BYTE_ARR)); 342 } 343 nodeTopologyUpdateWaitStage.onNext(toBeUpdatedNodes); 344 LOG.exiting("TreeTopology", "onUpdateTopology", getQualifiedName() + msg); 345 } 346 347 private void onTopologyChanges(final GroupCommunicationMessage msg) { 348 LOG.entering("TreeTopology", "onTopologyChanges", new Object[]{getQualifiedName(), msg}); 349 LOG.fine(getQualifiedName() + "Check TopologyChanges"); 350 final String dstId = msg.getSrcid(); 351 boolean hasTopologyChanged = false; 352 LOG.finest(getQualifiedName() + "Checking which nodes need to be updated"); 353 for (final TaskNode node : nodes.values()) { 354 if (!node.isRunning() || node.hasChanges()) { 355 hasTopologyChanged = true; 356 break; 357 } 358 } 359 final GroupChanges changes = new GroupChangesImpl(hasTopologyChanged); 360 final Codec<GroupChanges> changesCodec = new GroupChangesCodec(); 361 LOG.fine(getQualifiedName() + "TopologyChanges: " + changes); 362 senderStage.onNext(Utils.bldVersionedGCM(groupName, operName, 363 ReefNetworkGroupCommProtos.GroupCommMessage.Type.TopologyChanges, driverId, 0, dstId, getNodeVersion(dstId), 364 changesCodec.encode(changes))); 365 LOG.exiting("TreeTopology", "onTopologyChanges", getQualifiedName() + msg); 366 } 367 368 private String getQualifiedName() { 369 return Utils.simpleName(groupName) + ":" + Utils.simpleName(operName) + " - "; 370 } 371}