001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package org.apache.reef.io.network.group.impl.task; 020 021import org.apache.reef.driver.task.TaskConfigurationOptions; 022import org.apache.reef.exception.evaluator.NetworkException; 023import org.apache.reef.io.network.group.api.operators.*; 024import org.apache.reef.io.network.group.impl.config.parameters.DriverIdentifierGroupComm; 025import org.apache.reef.io.network.group.impl.driver.TopologySimpleNode; 026import org.apache.reef.io.network.group.impl.driver.TopologySerializer; 027import org.apache.reef.io.network.impl.NetworkService; 028import org.apache.reef.io.network.group.api.GroupChanges; 029import org.apache.reef.io.network.group.api.task.CommGroupNetworkHandler; 030import org.apache.reef.io.network.group.api.task.CommunicationGroupServiceClient; 031import org.apache.reef.io.network.group.api.task.GroupCommNetworkHandler; 032import org.apache.reef.io.network.group.impl.GroupChangesCodec; 033import org.apache.reef.io.network.group.impl.GroupChangesImpl; 034import org.apache.reef.io.network.group.impl.GroupCommunicationMessage; 035import org.apache.reef.io.network.group.impl.config.parameters.CommunicationGroupName; 036import org.apache.reef.io.network.group.impl.config.parameters.OperatorName; 037import org.apache.reef.io.network.group.impl.config.parameters.SerializedOperConfigs; 038import org.apache.reef.io.network.group.impl.operators.Sender; 039import org.apache.reef.io.network.group.impl.utils.Utils; 040import org.apache.reef.io.network.proto.ReefNetworkGroupCommProtos; 041import org.apache.reef.io.network.util.Pair; 042import org.apache.reef.io.serialization.Codec; 043import org.apache.reef.tang.Configuration; 044import org.apache.reef.tang.Injector; 045import org.apache.reef.tang.Tang; 046import org.apache.reef.tang.annotations.Name; 047import org.apache.reef.tang.annotations.Parameter; 048import org.apache.reef.tang.exceptions.InjectionException; 049import org.apache.reef.tang.formats.ConfigurationSerializer; 050import org.apache.reef.wake.EStage; 051import org.apache.reef.wake.Identifier; 052import org.apache.reef.wake.IdentifierFactory; 053import org.apache.reef.wake.impl.ThreadPoolStage; 054 055import javax.inject.Inject; 056import java.io.IOException; 057import java.util.*; 058import java.util.concurrent.CountDownLatch; 059import java.util.concurrent.atomic.AtomicBoolean; 060import java.util.logging.Logger; 061 062public class CommunicationGroupClientImpl implements CommunicationGroupServiceClient { 063 private static final Logger LOG = Logger.getLogger(CommunicationGroupClientImpl.class.getName()); 064 065 private final GroupCommNetworkHandler groupCommNetworkHandler; 066 private final Class<? extends Name<String>> groupName; 067 private final Map<Class<? extends Name<String>>, GroupCommOperator> operators; 068 private final Sender sender; 069 070 private final String taskId; 071 private final boolean isScatterSender; 072 private final IdentifierFactory identifierFactory; 073 private List<Identifier> activeSlaveTasks; 074 private TopologySimpleNode topologySimpleNodeRoot; 075 076 private final String driverId; 077 078 private final CommGroupNetworkHandler commGroupNetworkHandler; 079 080 private final AtomicBoolean init = new AtomicBoolean(false); 081 082 /** 083 * @deprecated in 0.14. 084 * Use the private constructor that receives an {@code injector} as a parameter instead. 085 */ 086 @Deprecated 087 @Inject 088 public CommunicationGroupClientImpl(@Parameter(CommunicationGroupName.class) final String groupName, 089 @Parameter(TaskConfigurationOptions.Identifier.class) final String taskId, 090 @Parameter(DriverIdentifierGroupComm.class) final String driverId, 091 final GroupCommNetworkHandler groupCommNetworkHandler, 092 @Parameter(SerializedOperConfigs.class) final Set<String> operatorConfigs, 093 final ConfigurationSerializer configSerializer, 094 final NetworkService<GroupCommunicationMessage> netService) { 095 this.taskId = taskId; 096 this.driverId = driverId; 097 LOG.finest(groupName + " has GroupCommHandler-" + groupCommNetworkHandler.toString()); 098 this.identifierFactory = netService.getIdentifierFactory(); 099 this.groupName = Utils.getClass(groupName); 100 this.groupCommNetworkHandler = groupCommNetworkHandler; 101 this.sender = new Sender(netService); 102 this.operators = new TreeMap<>(new Comparator<Class<? extends Name<String>>>() { 103 104 @Override 105 public int compare(final Class<? extends Name<String>> o1, final Class<? extends Name<String>> o2) { 106 final String s1 = o1.getSimpleName(); 107 final String s2 = o2.getSimpleName(); 108 return s1.compareTo(s2); 109 } 110 }); 111 try { 112 this.commGroupNetworkHandler = Tang.Factory.getTang().newInjector().getInstance(CommGroupNetworkHandler.class); 113 this.groupCommNetworkHandler.register(this.groupName, commGroupNetworkHandler); 114 115 boolean operatorIsScatterSender = false; 116 for (final String operatorConfigStr : operatorConfigs) { 117 118 final Configuration operatorConfig = configSerializer.fromString(operatorConfigStr); 119 final Injector injector = Tang.Factory.getTang().newInjector(operatorConfig); 120 121 injector.bindVolatileParameter(TaskConfigurationOptions.Identifier.class, taskId); 122 injector.bindVolatileParameter(CommunicationGroupName.class, groupName); 123 injector.bindVolatileInstance(CommGroupNetworkHandler.class, commGroupNetworkHandler); 124 injector.bindVolatileInstance(NetworkService.class, netService); 125 injector.bindVolatileInstance(CommunicationGroupServiceClient.class, this); 126 127 final GroupCommOperator operator = injector.getInstance(GroupCommOperator.class); 128 final String operName = injector.getNamedInstance(OperatorName.class); 129 this.operators.put(Utils.getClass(operName), operator); 130 LOG.finest(operName + " has CommGroupHandler-" + commGroupNetworkHandler.toString()); 131 132 if (!operatorIsScatterSender && operator instanceof Scatter.Sender) { 133 LOG.fine(operName + " is a scatter sender. Will keep track of active slave tasks."); 134 operatorIsScatterSender = true; 135 } 136 } 137 this.isScatterSender = operatorIsScatterSender; 138 } catch (final InjectionException | IOException e) { 139 throw new RuntimeException("Unable to deserialize operator config", e); 140 } 141 } 142 143 @Inject 144 private CommunicationGroupClientImpl(@Parameter(CommunicationGroupName.class) final String groupName, 145 @Parameter(TaskConfigurationOptions.Identifier.class) final String taskId, 146 @Parameter(DriverIdentifierGroupComm.class) final String driverId, 147 final GroupCommNetworkHandler groupCommNetworkHandler, 148 @Parameter(SerializedOperConfigs.class) final Set<String> operatorConfigs, 149 final ConfigurationSerializer configSerializer, 150 final NetworkService<GroupCommunicationMessage> netService, 151 final CommGroupNetworkHandler commGroupNetworkHandler, 152 final Injector injector) { 153 this.taskId = taskId; 154 this.driverId = driverId; 155 LOG.finest(groupName + " has GroupCommHandler-" + groupCommNetworkHandler.toString()); 156 this.identifierFactory = netService.getIdentifierFactory(); 157 this.groupName = Utils.getClass(groupName); 158 this.groupCommNetworkHandler = groupCommNetworkHandler; 159 this.commGroupNetworkHandler = commGroupNetworkHandler; 160 this.sender = new Sender(netService); 161 this.operators = new TreeMap<>(new Comparator<Class<? extends Name<String>>>() { 162 163 @Override 164 public int compare(final Class<? extends Name<String>> o1, final Class<? extends Name<String>> o2) { 165 final String s1 = o1.getSimpleName(); 166 final String s2 = o2.getSimpleName(); 167 return s1.compareTo(s2); 168 } 169 }); 170 try { 171 this.groupCommNetworkHandler.register(this.groupName, commGroupNetworkHandler); 172 173 boolean operatorIsScatterSender = false; 174 for (final String operatorConfigStr : operatorConfigs) { 175 176 final Configuration operatorConfig = configSerializer.fromString(operatorConfigStr); 177 final Injector forkedInjector = injector.forkInjector(operatorConfig); 178 179 forkedInjector.bindVolatileInstance(CommunicationGroupServiceClient.class, this); 180 181 final GroupCommOperator operator = forkedInjector.getInstance(GroupCommOperator.class); 182 final String operName = forkedInjector.getNamedInstance(OperatorName.class); 183 this.operators.put(Utils.getClass(operName), operator); 184 LOG.finest(operName + " has CommGroupHandler-" + commGroupNetworkHandler.toString()); 185 186 if (!operatorIsScatterSender && operator instanceof Scatter.Sender) { 187 LOG.fine(operName + " is a scatter sender. Will keep track of active slave tasks."); 188 operatorIsScatterSender = true; 189 } 190 } 191 this.isScatterSender = operatorIsScatterSender; 192 } catch (final InjectionException | IOException e) { 193 throw new RuntimeException("Unable to deserialize operator config", e); 194 } 195 } 196 197 @Override 198 public Broadcast.Sender getBroadcastSender(final Class<? extends Name<String>> operatorName) { 199 LOG.entering("CommunicationGroupClientImpl", "getBroadcastSender", new Object[]{getQualifiedName(), 200 Utils.simpleName(operatorName)}); 201 final GroupCommOperator op = operators.get(operatorName); 202 if (!(op instanceof Broadcast.Sender)) { 203 throw new RuntimeException("Configured operator is not a broadcast sender"); 204 } 205 commGroupNetworkHandler.addTopologyElement(operatorName); 206 LOG.exiting("CommunicationGroupClientImpl", "getBroadcastSender", getQualifiedName() + op); 207 return (Broadcast.Sender) op; 208 } 209 210 @Override 211 public Reduce.Receiver getReduceReceiver(final Class<? extends Name<String>> operatorName) { 212 LOG.entering("CommunicationGroupClientImpl", "getReduceReceiver", new Object[]{getQualifiedName(), 213 Utils.simpleName(operatorName)}); 214 final GroupCommOperator op = operators.get(operatorName); 215 if (!(op instanceof Reduce.Receiver)) { 216 throw new RuntimeException("Configured operator is not a reduce receiver"); 217 } 218 commGroupNetworkHandler.addTopologyElement(operatorName); 219 LOG.exiting("CommunicationGroupClientImpl", "getReduceReceiver", getQualifiedName() + op); 220 return (Reduce.Receiver) op; 221 } 222 223 @Override 224 public Scatter.Sender getScatterSender(final Class<? extends Name<String>> operatorName) { 225 LOG.entering("CommunicationGroupClientImpl", "getScatterSender", new Object[]{getQualifiedName(), 226 Utils.simpleName(operatorName)}); 227 final GroupCommOperator op = operators.get(operatorName); 228 if (!(op instanceof Scatter.Sender)) { 229 throw new RuntimeException("Configured operator is not a scatter sender"); 230 } 231 commGroupNetworkHandler.addTopologyElement(operatorName); 232 LOG.exiting("CommunicationGroupClientImpl", "getScatterSender", getQualifiedName() + op); 233 return (Scatter.Sender) op; 234 } 235 236 @Override 237 public Gather.Receiver getGatherReceiver(final Class<? extends Name<String>> operatorName) { 238 LOG.entering("CommunicationGroupClientImpl", "getGatherReceiver", new Object[]{getQualifiedName(), 239 Utils.simpleName(operatorName)}); 240 final GroupCommOperator op = operators.get(operatorName); 241 if (!(op instanceof Gather.Receiver)) { 242 throw new RuntimeException("Configured operator is not a gather receiver"); 243 } 244 commGroupNetworkHandler.addTopologyElement(operatorName); 245 LOG.exiting("CommunicationGroupClientImpl", "getGatherReceiver", getQualifiedName() + op); 246 return (Gather.Receiver) op; 247 } 248 249 250 @Override 251 public Broadcast.Receiver getBroadcastReceiver(final Class<? extends Name<String>> operatorName) { 252 LOG.entering("CommunicationGroupClientImpl", "getBroadcastReceiver", new Object[]{getQualifiedName(), 253 Utils.simpleName(operatorName)}); 254 final GroupCommOperator op = operators.get(operatorName); 255 if (!(op instanceof Broadcast.Receiver)) { 256 throw new RuntimeException("Configured operator is not a broadcast receiver"); 257 } 258 commGroupNetworkHandler.addTopologyElement(operatorName); 259 LOG.exiting("CommunicationGroupClientImpl", "getBroadcastReceiver", getQualifiedName() + op); 260 return (Broadcast.Receiver) op; 261 } 262 263 @Override 264 public Reduce.Sender getReduceSender(final Class<? extends Name<String>> operatorName) { 265 LOG.entering("CommunicationGroupClientImpl", "getReduceSender", new Object[]{getQualifiedName(), 266 Utils.simpleName(operatorName)}); 267 final GroupCommOperator op = operators.get(operatorName); 268 if (!(op instanceof Reduce.Sender)) { 269 throw new RuntimeException("Configured operator is not a reduce sender"); 270 } 271 commGroupNetworkHandler.addTopologyElement(operatorName); 272 LOG.exiting("CommunicationGroupClientImpl", "getReduceSender", getQualifiedName() + op); 273 return (Reduce.Sender) op; 274 } 275 276 @Override 277 public Scatter.Receiver getScatterReceiver(final Class<? extends Name<String>> operatorName) { 278 LOG.entering("CommunicationGroupClientImpl", "getScatterReceiver", new Object[]{getQualifiedName(), 279 Utils.simpleName(operatorName)}); 280 final GroupCommOperator op = operators.get(operatorName); 281 if (!(op instanceof Scatter.Receiver)) { 282 throw new RuntimeException("Configured operator is not a scatter receiver"); 283 } 284 commGroupNetworkHandler.addTopologyElement(operatorName); 285 LOG.exiting("CommunicationGroupClientImpl", "getScatterReceiver", getQualifiedName() + op); 286 return (Scatter.Receiver) op; 287 } 288 289 @Override 290 public Gather.Sender getGatherSender(final Class<? extends Name<String>> operatorName) { 291 LOG.entering("CommunicationGroupClientImpl", "getGatherSender", new Object[]{getQualifiedName(), 292 Utils.simpleName(operatorName)}); 293 final GroupCommOperator op = operators.get(operatorName); 294 if (!(op instanceof Gather.Sender)) { 295 throw new RuntimeException("Configured operator is not a gather sender"); 296 } 297 commGroupNetworkHandler.addTopologyElement(operatorName); 298 LOG.exiting("CommunicationGroupClientImpl", "getGatherSender", getQualifiedName() + op); 299 return (Gather.Sender) op; 300 } 301 302 @Override 303 public void initialize() { 304 LOG.entering("CommunicationGroupClientImpl", "initialize", getQualifiedName()); 305 if (init.compareAndSet(false, true)) { 306 LOG.finest("CommGroup-" + groupName + " is initializing"); 307 final CountDownLatch initLatch = new CountDownLatch(operators.size()); 308 309 final InitHandler initHandler = new InitHandler(initLatch); 310 final EStage<GroupCommOperator> initStage = new ThreadPoolStage<>(initHandler, operators.size()); 311 for (final GroupCommOperator op : operators.values()) { 312 initStage.onNext(op); 313 } 314 315 try { 316 initLatch.await(); 317 } catch (final InterruptedException e) { 318 throw new RuntimeException("InterruptedException while waiting for initialization", e); 319 } 320 321 if (isScatterSender) { 322 updateTopology(); 323 } 324 325 if (initHandler.getException() != null) { 326 throw new RuntimeException(getQualifiedName() + "Parent dead. Current behavior is for the child to die too."); 327 } 328 } 329 LOG.exiting("CommunicationGroupClientImpl", "initialize", getQualifiedName()); 330 } 331 332 @Override 333 public GroupChanges getTopologyChanges() { 334 LOG.entering("CommunicationGroupClientImpl", "getTopologyChanges", getQualifiedName()); 335 for (final GroupCommOperator op : operators.values()) { 336 final Class<? extends Name<String>> operName = op.getOperName(); 337 LOG.finest("Sending TopologyChanges msg to driver"); 338 try { 339 sender.send(Utils.bldVersionedGCM(groupName, operName, 340 ReefNetworkGroupCommProtos.GroupCommMessage.Type.TopologyChanges, taskId, op.getVersion(), driverId, 341 0, Utils.EMPTY_BYTE_ARR)); 342 } catch (final NetworkException e) { 343 throw new RuntimeException("NetworkException while sending GetTopologyChanges", e); 344 } 345 } 346 final Codec<GroupChanges> changesCodec = new GroupChangesCodec(); 347 final Map<Class<? extends Name<String>>, GroupChanges> perOpChanges = new HashMap<>(); 348 for (final GroupCommOperator op : operators.values()) { 349 final Class<? extends Name<String>> operName = op.getOperName(); 350 final byte[] changes = commGroupNetworkHandler.waitForTopologyChanges(operName); 351 perOpChanges.put(operName, changesCodec.decode(changes)); 352 } 353 final GroupChanges retVal = mergeGroupChanges(perOpChanges); 354 LOG.exiting("CommunicationGroupClientImpl", "getTopologyChanges", getQualifiedName() + retVal); 355 return retVal; 356 } 357 358 /** 359 * @param perOpChanges 360 * @return 361 */ 362 private GroupChanges mergeGroupChanges(final Map<Class<? extends Name<String>>, GroupChanges> perOpChanges) { 363 LOG.entering("CommunicationGroupClientImpl", "mergeGroupChanges", new Object[]{getQualifiedName(), perOpChanges}); 364 boolean doChangesExist = false; 365 for (final GroupChanges change : perOpChanges.values()) { 366 if (change.exist()) { 367 doChangesExist = true; 368 break; 369 } 370 } 371 final GroupChanges changes = new GroupChangesImpl(doChangesExist); 372 LOG.exiting("CommunicationGroupClientImpl", "mergeGroupChanges", getQualifiedName() + changes); 373 return changes; 374 } 375 376 @Override 377 public void updateTopology() { 378 LOG.entering("CommunicationGroupClientImpl", "updateTopology", getQualifiedName()); 379 for (final GroupCommOperator op : operators.values()) { 380 final Class<? extends Name<String>> operName = op.getOperName(); 381 try { 382 sender.send(Utils.bldVersionedGCM(groupName, operName, 383 ReefNetworkGroupCommProtos.GroupCommMessage.Type.UpdateTopology, taskId, op.getVersion(), driverId, 384 0, Utils.EMPTY_BYTE_ARR)); 385 } catch (final NetworkException e) { 386 throw new RuntimeException("NetworkException while sending UpdateTopology", e); 387 } 388 } 389 for (final GroupCommOperator op : operators.values()) { 390 final Class<? extends Name<String>> operName = op.getOperName(); 391 GroupCommunicationMessage msg; 392 do { 393 msg = commGroupNetworkHandler.waitForTopologyUpdate(operName); 394 } while (!isMsgVersionOk(msg)); 395 396 if (isScatterSender) { 397 updateActiveTasks(msg); 398 } 399 } 400 LOG.exiting("CommunicationGroupClientImpl", "updateTopology", getQualifiedName()); 401 } 402 403 private void updateActiveTasks(final GroupCommunicationMessage msg) { 404 LOG.entering("CommunicationGroupClientImpl", "updateActiveTasks", new Object[]{getQualifiedName(), msg}); 405 406 final Pair<TopologySimpleNode, List<Identifier>> pair = 407 TopologySerializer.decode(msg.getData()[0], identifierFactory); 408 409 topologySimpleNodeRoot = pair.getFirst(); 410 411 activeSlaveTasks = pair.getSecond(); 412 // remove myself 413 activeSlaveTasks.remove(identifierFactory.getNewInstance(taskId)); 414 // sort the tasks in lexicographical order on task ids 415 Collections.sort(activeSlaveTasks, new Comparator<Identifier>() { 416 @Override 417 public int compare(final Identifier o1, final Identifier o2) { 418 return o1.toString().compareTo(o2.toString()); 419 } 420 }); 421 422 LOG.exiting("CommunicationGroupClientImpl", "updateActiveTasks", new Object[]{getQualifiedName(), msg}); 423 } 424 425 private boolean isMsgVersionOk(final GroupCommunicationMessage msg) { 426 LOG.entering("CommunicationGroupClientImpl", "isMsgVersionOk", new Object[]{getQualifiedName(), msg}); 427 if (msg.hasVersion()) { 428 final int msgVersion = msg.getVersion(); 429 final GroupCommOperator operator = operators.get(Utils.getClass(msg.getOperatorname())); 430 final int nodeVersion = operator.getVersion(); 431 final boolean retVal; 432 if (msgVersion < nodeVersion) { 433 LOG.warning(getQualifiedName() + "Received a ver-" + msgVersion + " msg while expecting ver-" + nodeVersion 434 + ". Discarding msg"); 435 retVal = false; 436 } else { 437 retVal = true; 438 } 439 LOG.exiting("CommunicationGroupClientImpl", "isMsgVersionOk", 440 Arrays.toString(new Object[]{retVal, getQualifiedName(), msg})); 441 return retVal; 442 } else { 443 throw new RuntimeException(getQualifiedName() + "can only deal with versioned msgs"); 444 } 445 } 446 447 @Override 448 public List<Identifier> getActiveSlaveTasks() { 449 return this.activeSlaveTasks; 450 } 451 452 @Override 453 public TopologySimpleNode getTopologySimpleNodeRoot() { 454 return this.topologySimpleNodeRoot; 455 } 456 457 private String getQualifiedName() { 458 return Utils.simpleName(groupName) + " "; 459 } 460 461 @Override 462 public Class<? extends Name<String>> getName() { 463 return groupName; 464 } 465 466}