001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package org.apache.reef.io.network.group.impl.task; 020 021import org.apache.reef.driver.task.TaskConfigurationOptions; 022import org.apache.reef.exception.evaluator.NetworkException; 023import org.apache.reef.io.network.group.api.operators.*; 024import org.apache.reef.io.network.group.impl.config.parameters.DriverIdentifierGroupComm; 025import org.apache.reef.io.network.group.impl.driver.TopologySimpleNode; 026import org.apache.reef.io.network.group.impl.driver.TopologySerializer; 027import org.apache.reef.io.network.impl.NetworkService; 028import org.apache.reef.io.network.group.api.GroupChanges; 029import org.apache.reef.io.network.group.api.task.CommGroupNetworkHandler; 030import org.apache.reef.io.network.group.api.task.CommunicationGroupServiceClient; 031import org.apache.reef.io.network.group.api.task.GroupCommNetworkHandler; 032import org.apache.reef.io.network.group.impl.GroupChangesCodec; 033import org.apache.reef.io.network.group.impl.GroupChangesImpl; 034import org.apache.reef.io.network.group.impl.GroupCommunicationMessage; 035import org.apache.reef.io.network.group.impl.config.parameters.CommunicationGroupName; 036import org.apache.reef.io.network.group.impl.config.parameters.OperatorName; 037import org.apache.reef.io.network.group.impl.config.parameters.SerializedOperConfigs; 038import org.apache.reef.io.network.group.impl.operators.Sender; 039import org.apache.reef.io.network.group.impl.utils.Utils; 040import org.apache.reef.io.network.proto.ReefNetworkGroupCommProtos; 041import org.apache.reef.io.network.util.Pair; 042import org.apache.reef.io.serialization.Codec; 043import org.apache.reef.tang.Configuration; 044import org.apache.reef.tang.Injector; 045import org.apache.reef.tang.annotations.Name; 046import org.apache.reef.tang.annotations.Parameter; 047import org.apache.reef.tang.exceptions.InjectionException; 048import org.apache.reef.tang.formats.ConfigurationSerializer; 049import org.apache.reef.wake.EStage; 050import org.apache.reef.wake.Identifier; 051import org.apache.reef.wake.IdentifierFactory; 052import org.apache.reef.wake.impl.ThreadPoolStage; 053 054import javax.inject.Inject; 055import java.io.IOException; 056import java.util.*; 057import java.util.concurrent.CountDownLatch; 058import java.util.concurrent.atomic.AtomicBoolean; 059import java.util.logging.Logger; 060 061public final class CommunicationGroupClientImpl implements CommunicationGroupServiceClient { 062 private static final Logger LOG = Logger.getLogger(CommunicationGroupClientImpl.class.getName()); 063 064 private final GroupCommNetworkHandler groupCommNetworkHandler; 065 private final Class<? extends Name<String>> groupName; 066 private final Map<Class<? extends Name<String>>, GroupCommOperator> operators; 067 private final Sender sender; 068 069 private final String taskId; 070 private final boolean isScatterSender; 071 private final IdentifierFactory identifierFactory; 072 private List<Identifier> activeSlaveTasks; 073 private TopologySimpleNode topologySimpleNodeRoot; 074 075 private final String driverId; 076 077 private final CommGroupNetworkHandler commGroupNetworkHandler; 078 079 private final AtomicBoolean init = new AtomicBoolean(false); 080 081 @Inject 082 private CommunicationGroupClientImpl(@Parameter(CommunicationGroupName.class) final String groupName, 083 @Parameter(TaskConfigurationOptions.Identifier.class) final String taskId, 084 @Parameter(DriverIdentifierGroupComm.class) final String driverId, 085 final GroupCommNetworkHandler groupCommNetworkHandler, 086 @Parameter(SerializedOperConfigs.class) final Set<String> operatorConfigs, 087 final ConfigurationSerializer configSerializer, 088 final NetworkService<GroupCommunicationMessage> netService, 089 final CommGroupNetworkHandler commGroupNetworkHandler, 090 final Injector injector) { 091 this.taskId = taskId; 092 this.driverId = driverId; 093 LOG.finest(groupName + " has GroupCommHandler-" + groupCommNetworkHandler.toString()); 094 this.identifierFactory = netService.getIdentifierFactory(); 095 this.groupName = Utils.getClass(groupName); 096 this.groupCommNetworkHandler = groupCommNetworkHandler; 097 this.commGroupNetworkHandler = commGroupNetworkHandler; 098 this.sender = new Sender(netService); 099 this.operators = new TreeMap<>(new Comparator<Class<? extends Name<String>>>() { 100 101 @Override 102 public int compare(final Class<? extends Name<String>> o1, final Class<? extends Name<String>> o2) { 103 final String s1 = o1.getSimpleName(); 104 final String s2 = o2.getSimpleName(); 105 return s1.compareTo(s2); 106 } 107 }); 108 try { 109 this.groupCommNetworkHandler.register(this.groupName, commGroupNetworkHandler); 110 111 boolean operatorIsScatterSender = false; 112 for (final String operatorConfigStr : operatorConfigs) { 113 114 final Configuration operatorConfig = configSerializer.fromString(operatorConfigStr); 115 final Injector forkedInjector = injector.forkInjector(operatorConfig); 116 117 forkedInjector.bindVolatileInstance(CommunicationGroupServiceClient.class, this); 118 119 final GroupCommOperator operator = forkedInjector.getInstance(GroupCommOperator.class); 120 final String operName = forkedInjector.getNamedInstance(OperatorName.class); 121 this.operators.put(Utils.getClass(operName), operator); 122 LOG.finest(operName + " has CommGroupHandler-" + commGroupNetworkHandler.toString()); 123 124 if (!operatorIsScatterSender && operator instanceof Scatter.Sender) { 125 LOG.fine(operName + " is a scatter sender. Will keep track of active slave tasks."); 126 operatorIsScatterSender = true; 127 } 128 } 129 this.isScatterSender = operatorIsScatterSender; 130 } catch (final InjectionException | IOException e) { 131 throw new RuntimeException("Unable to deserialize operator config", e); 132 } 133 } 134 135 @Override 136 public Broadcast.Sender getBroadcastSender(final Class<? extends Name<String>> operatorName) { 137 LOG.entering("CommunicationGroupClientImpl", "getBroadcastSender", new Object[]{getQualifiedName(), 138 Utils.simpleName(operatorName)}); 139 final GroupCommOperator op = operators.get(operatorName); 140 if (!(op instanceof Broadcast.Sender)) { 141 throw new RuntimeException("Configured operator is not a broadcast sender"); 142 } 143 commGroupNetworkHandler.addTopologyElement(operatorName); 144 LOG.exiting("CommunicationGroupClientImpl", "getBroadcastSender", getQualifiedName() + op); 145 return (Broadcast.Sender) op; 146 } 147 148 @Override 149 public Reduce.Receiver getReduceReceiver(final Class<? extends Name<String>> operatorName) { 150 LOG.entering("CommunicationGroupClientImpl", "getReduceReceiver", new Object[]{getQualifiedName(), 151 Utils.simpleName(operatorName)}); 152 final GroupCommOperator op = operators.get(operatorName); 153 if (!(op instanceof Reduce.Receiver)) { 154 throw new RuntimeException("Configured operator is not a reduce receiver"); 155 } 156 commGroupNetworkHandler.addTopologyElement(operatorName); 157 LOG.exiting("CommunicationGroupClientImpl", "getReduceReceiver", getQualifiedName() + op); 158 return (Reduce.Receiver) op; 159 } 160 161 @Override 162 public Scatter.Sender getScatterSender(final Class<? extends Name<String>> operatorName) { 163 LOG.entering("CommunicationGroupClientImpl", "getScatterSender", new Object[]{getQualifiedName(), 164 Utils.simpleName(operatorName)}); 165 final GroupCommOperator op = operators.get(operatorName); 166 if (!(op instanceof Scatter.Sender)) { 167 throw new RuntimeException("Configured operator is not a scatter sender"); 168 } 169 commGroupNetworkHandler.addTopologyElement(operatorName); 170 LOG.exiting("CommunicationGroupClientImpl", "getScatterSender", getQualifiedName() + op); 171 return (Scatter.Sender) op; 172 } 173 174 @Override 175 public Gather.Receiver getGatherReceiver(final Class<? extends Name<String>> operatorName) { 176 LOG.entering("CommunicationGroupClientImpl", "getGatherReceiver", new Object[]{getQualifiedName(), 177 Utils.simpleName(operatorName)}); 178 final GroupCommOperator op = operators.get(operatorName); 179 if (!(op instanceof Gather.Receiver)) { 180 throw new RuntimeException("Configured operator is not a gather receiver"); 181 } 182 commGroupNetworkHandler.addTopologyElement(operatorName); 183 LOG.exiting("CommunicationGroupClientImpl", "getGatherReceiver", getQualifiedName() + op); 184 return (Gather.Receiver) op; 185 } 186 187 188 @Override 189 public Broadcast.Receiver getBroadcastReceiver(final Class<? extends Name<String>> operatorName) { 190 LOG.entering("CommunicationGroupClientImpl", "getBroadcastReceiver", new Object[]{getQualifiedName(), 191 Utils.simpleName(operatorName)}); 192 final GroupCommOperator op = operators.get(operatorName); 193 if (!(op instanceof Broadcast.Receiver)) { 194 throw new RuntimeException("Configured operator is not a broadcast receiver"); 195 } 196 commGroupNetworkHandler.addTopologyElement(operatorName); 197 LOG.exiting("CommunicationGroupClientImpl", "getBroadcastReceiver", getQualifiedName() + op); 198 return (Broadcast.Receiver) op; 199 } 200 201 @Override 202 public Reduce.Sender getReduceSender(final Class<? extends Name<String>> operatorName) { 203 LOG.entering("CommunicationGroupClientImpl", "getReduceSender", new Object[]{getQualifiedName(), 204 Utils.simpleName(operatorName)}); 205 final GroupCommOperator op = operators.get(operatorName); 206 if (!(op instanceof Reduce.Sender)) { 207 throw new RuntimeException("Configured operator is not a reduce sender"); 208 } 209 commGroupNetworkHandler.addTopologyElement(operatorName); 210 LOG.exiting("CommunicationGroupClientImpl", "getReduceSender", getQualifiedName() + op); 211 return (Reduce.Sender) op; 212 } 213 214 @Override 215 public Scatter.Receiver getScatterReceiver(final Class<? extends Name<String>> operatorName) { 216 LOG.entering("CommunicationGroupClientImpl", "getScatterReceiver", new Object[]{getQualifiedName(), 217 Utils.simpleName(operatorName)}); 218 final GroupCommOperator op = operators.get(operatorName); 219 if (!(op instanceof Scatter.Receiver)) { 220 throw new RuntimeException("Configured operator is not a scatter receiver"); 221 } 222 commGroupNetworkHandler.addTopologyElement(operatorName); 223 LOG.exiting("CommunicationGroupClientImpl", "getScatterReceiver", getQualifiedName() + op); 224 return (Scatter.Receiver) op; 225 } 226 227 @Override 228 public Gather.Sender getGatherSender(final Class<? extends Name<String>> operatorName) { 229 LOG.entering("CommunicationGroupClientImpl", "getGatherSender", new Object[]{getQualifiedName(), 230 Utils.simpleName(operatorName)}); 231 final GroupCommOperator op = operators.get(operatorName); 232 if (!(op instanceof Gather.Sender)) { 233 throw new RuntimeException("Configured operator is not a gather sender"); 234 } 235 commGroupNetworkHandler.addTopologyElement(operatorName); 236 LOG.exiting("CommunicationGroupClientImpl", "getGatherSender", getQualifiedName() + op); 237 return (Gather.Sender) op; 238 } 239 240 @Override 241 public void initialize() { 242 LOG.entering("CommunicationGroupClientImpl", "initialize", getQualifiedName()); 243 if (init.compareAndSet(false, true)) { 244 LOG.finest("CommGroup-" + groupName + " is initializing"); 245 final CountDownLatch initLatch = new CountDownLatch(operators.size()); 246 247 final InitHandler initHandler = new InitHandler(initLatch); 248 final EStage<GroupCommOperator> initStage = new ThreadPoolStage<>(initHandler, operators.size()); 249 for (final GroupCommOperator op : operators.values()) { 250 initStage.onNext(op); 251 } 252 253 try { 254 initLatch.await(); 255 } catch (final InterruptedException e) { 256 throw new RuntimeException("InterruptedException while waiting for initialization", e); 257 } 258 259 if (isScatterSender) { 260 updateTopology(); 261 } 262 263 if (initHandler.getException() != null) { 264 throw new RuntimeException(getQualifiedName() + "Parent dead. Current behavior is for the child to die too."); 265 } 266 } 267 LOG.exiting("CommunicationGroupClientImpl", "initialize", getQualifiedName()); 268 } 269 270 @Override 271 public GroupChanges getTopologyChanges() { 272 LOG.entering("CommunicationGroupClientImpl", "getTopologyChanges", getQualifiedName()); 273 for (final GroupCommOperator op : operators.values()) { 274 final Class<? extends Name<String>> operName = op.getOperName(); 275 LOG.finest("Sending TopologyChanges msg to driver"); 276 try { 277 sender.send(Utils.bldVersionedGCM(groupName, operName, 278 ReefNetworkGroupCommProtos.GroupCommMessage.Type.TopologyChanges, taskId, op.getVersion(), driverId, 279 0, Utils.EMPTY_BYTE_ARR)); 280 } catch (final NetworkException e) { 281 throw new RuntimeException("NetworkException while sending GetTopologyChanges", e); 282 } 283 } 284 final Codec<GroupChanges> changesCodec = new GroupChangesCodec(); 285 final Map<Class<? extends Name<String>>, GroupChanges> perOpChanges = new HashMap<>(); 286 for (final GroupCommOperator op : operators.values()) { 287 final Class<? extends Name<String>> operName = op.getOperName(); 288 final byte[] changes = commGroupNetworkHandler.waitForTopologyChanges(operName); 289 perOpChanges.put(operName, changesCodec.decode(changes)); 290 } 291 final GroupChanges retVal = mergeGroupChanges(perOpChanges); 292 LOG.exiting("CommunicationGroupClientImpl", "getTopologyChanges", getQualifiedName() + retVal); 293 return retVal; 294 } 295 296 /** 297 * @param perOpChanges 298 * @return 299 */ 300 private GroupChanges mergeGroupChanges(final Map<Class<? extends Name<String>>, GroupChanges> perOpChanges) { 301 LOG.entering("CommunicationGroupClientImpl", "mergeGroupChanges", new Object[]{getQualifiedName(), perOpChanges}); 302 boolean doChangesExist = false; 303 for (final GroupChanges change : perOpChanges.values()) { 304 if (change.exist()) { 305 doChangesExist = true; 306 break; 307 } 308 } 309 final GroupChanges changes = new GroupChangesImpl(doChangesExist); 310 LOG.exiting("CommunicationGroupClientImpl", "mergeGroupChanges", getQualifiedName() + changes); 311 return changes; 312 } 313 314 @Override 315 public void updateTopology() { 316 LOG.entering("CommunicationGroupClientImpl", "updateTopology", getQualifiedName()); 317 for (final GroupCommOperator op : operators.values()) { 318 final Class<? extends Name<String>> operName = op.getOperName(); 319 try { 320 sender.send(Utils.bldVersionedGCM(groupName, operName, 321 ReefNetworkGroupCommProtos.GroupCommMessage.Type.UpdateTopology, taskId, op.getVersion(), driverId, 322 0, Utils.EMPTY_BYTE_ARR)); 323 } catch (final NetworkException e) { 324 throw new RuntimeException("NetworkException while sending UpdateTopology", e); 325 } 326 } 327 for (final GroupCommOperator op : operators.values()) { 328 final Class<? extends Name<String>> operName = op.getOperName(); 329 GroupCommunicationMessage msg; 330 do { 331 msg = commGroupNetworkHandler.waitForTopologyUpdate(operName); 332 } while (!isMsgVersionOk(msg)); 333 334 if (isScatterSender) { 335 updateActiveTasks(msg); 336 } 337 } 338 LOG.exiting("CommunicationGroupClientImpl", "updateTopology", getQualifiedName()); 339 } 340 341 private void updateActiveTasks(final GroupCommunicationMessage msg) { 342 LOG.entering("CommunicationGroupClientImpl", "updateActiveTasks", new Object[]{getQualifiedName(), msg}); 343 344 final Pair<TopologySimpleNode, List<Identifier>> pair = 345 TopologySerializer.decode(msg.getData()[0], identifierFactory); 346 347 topologySimpleNodeRoot = pair.getFirst(); 348 349 activeSlaveTasks = pair.getSecond(); 350 // remove myself 351 activeSlaveTasks.remove(identifierFactory.getNewInstance(taskId)); 352 // sort the tasks in lexicographical order on task ids 353 Collections.sort(activeSlaveTasks, new Comparator<Identifier>() { 354 @Override 355 public int compare(final Identifier o1, final Identifier o2) { 356 return o1.toString().compareTo(o2.toString()); 357 } 358 }); 359 360 LOG.exiting("CommunicationGroupClientImpl", "updateActiveTasks", new Object[]{getQualifiedName(), msg}); 361 } 362 363 private boolean isMsgVersionOk(final GroupCommunicationMessage msg) { 364 LOG.entering("CommunicationGroupClientImpl", "isMsgVersionOk", new Object[]{getQualifiedName(), msg}); 365 if (msg.hasVersion()) { 366 final int msgVersion = msg.getVersion(); 367 final GroupCommOperator operator = operators.get(Utils.getClass(msg.getOperatorname())); 368 final int nodeVersion = operator.getVersion(); 369 final boolean retVal; 370 if (msgVersion < nodeVersion) { 371 LOG.warning(getQualifiedName() + "Received a ver-" + msgVersion + " msg while expecting ver-" + nodeVersion 372 + ". Discarding msg"); 373 retVal = false; 374 } else { 375 retVal = true; 376 } 377 LOG.exiting("CommunicationGroupClientImpl", "isMsgVersionOk", 378 Arrays.toString(new Object[]{retVal, getQualifiedName(), msg})); 379 return retVal; 380 } else { 381 throw new RuntimeException(getQualifiedName() + "can only deal with versioned msgs"); 382 } 383 } 384 385 @Override 386 public List<Identifier> getActiveSlaveTasks() { 387 return this.activeSlaveTasks; 388 } 389 390 @Override 391 public TopologySimpleNode getTopologySimpleNodeRoot() { 392 return this.topologySimpleNodeRoot; 393 } 394 395 private String getQualifiedName() { 396 return Utils.simpleName(groupName) + " "; 397 } 398 399 @Override 400 public Class<? extends Name<String>> getName() { 401 return groupName; 402 } 403 404}