/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.reef.examples.group.broadcast;

import org.apache.reef.annotations.audience.DriverSide;
import org.apache.reef.driver.context.ActiveContext;
import org.apache.reef.driver.context.ClosedContext;
import org.apache.reef.driver.context.ContextConfiguration;
import org.apache.reef.driver.evaluator.AllocatedEvaluator;
import org.apache.reef.driver.evaluator.EvaluatorRequest;
import org.apache.reef.driver.evaluator.EvaluatorRequestor;
import org.apache.reef.driver.task.FailedTask;
import org.apache.reef.driver.task.TaskConfiguration;
import org.apache.reef.evaluator.context.parameters.ContextIdentifier;
import org.apache.reef.examples.group.bgd.operatornames.ControlMessageBroadcaster;
import org.apache.reef.examples.group.bgd.parameters.AllCommunicationGroup;
import org.apache.reef.examples.group.bgd.parameters.ModelDimensions;
import org.apache.reef.examples.group.broadcast.parameters.ModelBroadcaster;
import org.apache.reef.examples.group.broadcast.parameters.ModelReceiveAckReducer;
import org.apache.reef.examples.group.broadcast.parameters.NumberOfReceivers;
037import org.apache.reef.io.network.group.api.driver.CommunicationGroupDriver; 038import org.apache.reef.io.network.group.api.driver.GroupCommDriver; 039import org.apache.reef.io.network.group.impl.config.BroadcastOperatorSpec; 040import org.apache.reef.io.network.group.impl.config.ReduceOperatorSpec; 041import org.apache.reef.io.serialization.SerializableCodec; 042import org.apache.reef.poison.PoisonedConfiguration; 043import org.apache.reef.tang.Configuration; 044import org.apache.reef.tang.Injector; 045import org.apache.reef.tang.Tang; 046import org.apache.reef.tang.annotations.Parameter; 047import org.apache.reef.tang.annotations.Unit; 048import org.apache.reef.tang.exceptions.InjectionException; 049import org.apache.reef.tang.formats.ConfigurationSerializer; 050import org.apache.reef.wake.EventHandler; 051import org.apache.reef.wake.time.event.StartTime; 052 053import javax.inject.Inject; 054import java.util.concurrent.atomic.AtomicBoolean; 055import java.util.concurrent.atomic.AtomicInteger; 056import java.util.logging.Level; 057import java.util.logging.Logger; 058 059/** 060 * Driver for broadcast example. 
061 */ 062@DriverSide 063@Unit 064public class BroadcastDriver { 065 066 private static final Logger LOG = Logger.getLogger(BroadcastDriver.class.getName()); 067 068 private final AtomicBoolean masterSubmitted = new AtomicBoolean(false); 069 private final AtomicInteger slaveIds = new AtomicInteger(0); 070 private final AtomicInteger failureSet = new AtomicInteger(0); 071 072 private final GroupCommDriver groupCommDriver; 073 private final CommunicationGroupDriver allCommGroup; 074 private final ConfigurationSerializer confSerializer; 075 private final int dimensions; 076 private final EvaluatorRequestor requestor; 077 private final int numberOfReceivers; 078 private final AtomicInteger numberOfAllocatedEvaluators; 079 080 private String groupCommConfiguredMasterId; 081 082 @Inject 083 public BroadcastDriver( 084 final EvaluatorRequestor requestor, 085 final GroupCommDriver groupCommDriver, 086 final ConfigurationSerializer confSerializer, 087 @Parameter(ModelDimensions.class) final int dimensions, 088 @Parameter(NumberOfReceivers.class) final int numberOfReceivers) { 089 090 this.requestor = requestor; 091 this.groupCommDriver = groupCommDriver; 092 this.confSerializer = confSerializer; 093 this.dimensions = dimensions; 094 this.numberOfReceivers = numberOfReceivers; 095 this.numberOfAllocatedEvaluators = new AtomicInteger(numberOfReceivers + 1); 096 097 this.allCommGroup = this.groupCommDriver.newCommunicationGroup( 098 AllCommunicationGroup.class, numberOfReceivers + 1); 099 100 LOG.info("Obtained all communication group"); 101 102 this.allCommGroup 103 .addBroadcast(ControlMessageBroadcaster.class, 104 BroadcastOperatorSpec.newBuilder() 105 .setSenderId(MasterTask.TASK_ID) 106 .setDataCodecClass(SerializableCodec.class) 107 .build()) 108 .addBroadcast(ModelBroadcaster.class, 109 BroadcastOperatorSpec.newBuilder() 110 .setSenderId(MasterTask.TASK_ID) 111 .setDataCodecClass(SerializableCodec.class) 112 .build()) 113 .addReduce(ModelReceiveAckReducer.class, 114 
ReduceOperatorSpec.newBuilder() 115 .setReceiverId(MasterTask.TASK_ID) 116 .setDataCodecClass(SerializableCodec.class) 117 .setReduceFunctionClass(ModelReceiveAckReduceFunction.class) 118 .build()) 119 .finalise(); 120 121 LOG.info("Added operators to allCommGroup"); 122 } 123 124 /** 125 * Handles the StartTime event: Request numOfReceivers Evaluators. 126 */ 127 final class StartHandler implements EventHandler<StartTime> { 128 @Override 129 public void onNext(final StartTime startTime) { 130 final int numEvals = BroadcastDriver.this.numberOfReceivers + 1; 131 LOG.log(Level.FINE, "Requesting {0} evaluators", numEvals); 132 BroadcastDriver.this.requestor.submit(EvaluatorRequest.newBuilder() 133 .setNumber(numEvals) 134 .setMemory(2048) 135 .build()); 136 } 137 } 138 139 /** 140 * Handles AllocatedEvaluator: Submits a context with an id. 141 */ 142 final class EvaluatorAllocatedHandler implements EventHandler<AllocatedEvaluator> { 143 @Override 144 public void onNext(final AllocatedEvaluator allocatedEvaluator) { 145 LOG.log(Level.INFO, "Submitting an id context to AllocatedEvaluator: {0}", allocatedEvaluator); 146 final Configuration contextConfiguration = ContextConfiguration.CONF 147 .set(ContextConfiguration.IDENTIFIER, "BroadcastContext-" + 148 BroadcastDriver.this.numberOfAllocatedEvaluators.getAndDecrement()) 149 .build(); 150 allocatedEvaluator.submitContext(contextConfiguration); 151 } 152 } 153 154 /** 155 * FailedTask handler. 
156 */ 157 public class FailedTaskHandler implements EventHandler<FailedTask> { 158 159 @Override 160 public void onNext(final FailedTask failedTask) { 161 162 LOG.log(Level.FINE, "Got failed Task: {0}", failedTask.getId()); 163 164 final ActiveContext activeContext = failedTask.getActiveContext().get(); 165 final Configuration partialTaskConf = Tang.Factory.getTang() 166 .newConfigurationBuilder( 167 TaskConfiguration.CONF 168 .set(TaskConfiguration.IDENTIFIER, failedTask.getId()) 169 .set(TaskConfiguration.TASK, SlaveTask.class) 170 .build(), 171 PoisonedConfiguration.TASK_CONF 172 .set(PoisonedConfiguration.CRASH_PROBABILITY, "0") 173 .set(PoisonedConfiguration.CRASH_TIMEOUT, "1") 174 .build()) 175 .bindNamedParameter(ModelDimensions.class, "" + dimensions) 176 .build(); 177 178 // Do not add the task back: 179 // allCommGroup.addTask(partialTaskConf); 180 181 final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf); 182 LOG.log(Level.FINER, "Submit SlaveTask conf: {0}", confSerializer.toString(taskConf)); 183 184 activeContext.submitTask(taskConf); 185 } 186 } 187 188 /** 189 * ActiveContext handler. 190 */ 191 public class ContextActiveHandler implements EventHandler<ActiveContext> { 192 193 private final AtomicBoolean storeMasterId = new AtomicBoolean(false); 194 195 @Override 196 public void onNext(final ActiveContext activeContext) { 197 198 LOG.log(Level.FINE, "Got active context: {0}", activeContext.getId()); 199 200 /** 201 * The active context can be either from data loading service or after network 202 * service has loaded contexts. So check if the GroupCommDriver knows if it was 203 * configured by one of the communication groups. 
204 */ 205 if (groupCommDriver.isConfigured(activeContext)) { 206 207 if (activeContext.getId().equals(groupCommConfiguredMasterId) && !masterTaskSubmitted()) { 208 209 final Configuration partialTaskConf = Tang.Factory.getTang() 210 .newConfigurationBuilder( 211 TaskConfiguration.CONF 212 .set(TaskConfiguration.IDENTIFIER, MasterTask.TASK_ID) 213 .set(TaskConfiguration.TASK, MasterTask.class) 214 .build()) 215 .bindNamedParameter(ModelDimensions.class, Integer.toString(dimensions)) 216 .build(); 217 218 allCommGroup.addTask(partialTaskConf); 219 220 final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf); 221 LOG.log(Level.FINER, "Submit MasterTask conf: {0}", confSerializer.toString(taskConf)); 222 223 activeContext.submitTask(taskConf); 224 225 } else { 226 227 final Configuration partialTaskConf = Tang.Factory.getTang() 228 .newConfigurationBuilder( 229 TaskConfiguration.CONF 230 .set(TaskConfiguration.IDENTIFIER, getSlaveId(activeContext)) 231 .set(TaskConfiguration.TASK, SlaveTask.class) 232 .build(), 233 PoisonedConfiguration.TASK_CONF 234 .set(PoisonedConfiguration.CRASH_PROBABILITY, "0.4") 235 .set(PoisonedConfiguration.CRASH_TIMEOUT, "1") 236 .build()) 237 .bindNamedParameter(ModelDimensions.class, Integer.toString(dimensions)) 238 .build(); 239 240 allCommGroup.addTask(partialTaskConf); 241 242 final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf); 243 LOG.log(Level.FINER, "Submit SlaveTask conf: {0}", confSerializer.toString(taskConf)); 244 245 activeContext.submitTask(taskConf); 246 } 247 } else { 248 249 final Configuration contextConf = groupCommDriver.getContextConfiguration(); 250 final String contextId = contextId(contextConf); 251 252 if (storeMasterId.compareAndSet(false, true)) { 253 groupCommConfiguredMasterId = contextId; 254 } 255 256 final Configuration serviceConf = groupCommDriver.getServiceConfiguration(); 257 LOG.log(Level.FINER, "Submit GCContext conf: {0}", 
confSerializer.toString(contextConf)); 258 LOG.log(Level.FINER, "Submit Service conf: {0}", confSerializer.toString(serviceConf)); 259 260 activeContext.submitContextAndService(contextConf, serviceConf); 261 } 262 } 263 264 private String contextId(final Configuration contextConf) { 265 try { 266 final Injector injector = Tang.Factory.getTang().newInjector(contextConf); 267 return injector.getNamedInstance(ContextIdentifier.class); 268 } catch (final InjectionException e) { 269 throw new RuntimeException("Unable to inject context identifier from context conf", e); 270 } 271 } 272 273 private String getSlaveId(final ActiveContext activeContext) { 274 return "SlaveTask-" + slaveIds.getAndIncrement(); 275 } 276 277 private boolean masterTaskSubmitted() { 278 return !masterSubmitted.compareAndSet(false, true); 279 } 280 } 281 282 /** 283 * ClosedContext handler. 284 */ 285 public class ContextCloseHandler implements EventHandler<ClosedContext> { 286 287 @Override 288 public void onNext(final ClosedContext closedContext) { 289 LOG.log(Level.FINE, "Got closed context: {0}", closedContext.getId()); 290 final ActiveContext parentContext = closedContext.getParentContext(); 291 if (parentContext != null) { 292 LOG.log(Level.FINE, "Closing parent context: {0}", parentContext.getId()); 293 parentContext.close(); 294 } 295 } 296 } 297}