001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package org.apache.reef.examples.group.broadcast; 020 021import org.apache.reef.annotations.audience.DriverSide; 022import org.apache.reef.driver.context.ActiveContext; 023import org.apache.reef.driver.context.ClosedContext; 024import org.apache.reef.driver.context.ContextConfiguration; 025import org.apache.reef.driver.evaluator.AllocatedEvaluator; 026import org.apache.reef.driver.evaluator.EvaluatorRequestor; 027import org.apache.reef.driver.task.FailedTask; 028import org.apache.reef.driver.task.TaskConfiguration; 029import org.apache.reef.evaluator.context.parameters.ContextIdentifier; 030import org.apache.reef.examples.group.bgd.operatornames.ControlMessageBroadcaster; 031import org.apache.reef.examples.group.bgd.parameters.AllCommunicationGroup; 032import org.apache.reef.examples.group.bgd.parameters.ModelDimensions; 033import org.apache.reef.examples.group.broadcast.parameters.ModelBroadcaster; 034import org.apache.reef.examples.group.broadcast.parameters.ModelReceiveAckReducer; 035import org.apache.reef.examples.group.broadcast.parameters.NumberOfReceivers; 036import org.apache.reef.io.network.group.api.driver.CommunicationGroupDriver; 037import org.apache.reef.io.network.group.api.driver.GroupCommDriver; 038import org.apache.reef.io.network.group.impl.config.BroadcastOperatorSpec; 039import org.apache.reef.io.network.group.impl.config.ReduceOperatorSpec; 040import org.apache.reef.io.serialization.SerializableCodec; 041import org.apache.reef.poison.PoisonedConfiguration; 042import org.apache.reef.tang.Configuration; 043import org.apache.reef.tang.Injector; 044import org.apache.reef.tang.Tang; 045import org.apache.reef.tang.annotations.Parameter; 046import org.apache.reef.tang.annotations.Unit; 047import org.apache.reef.tang.exceptions.InjectionException; 048import org.apache.reef.tang.formats.ConfigurationSerializer; 049import org.apache.reef.wake.EventHandler; 050import org.apache.reef.wake.time.event.StartTime; 051 052import javax.inject.Inject; 053import java.util.concurrent.atomic.AtomicBoolean; 054import java.util.concurrent.atomic.AtomicInteger; 055import java.util.logging.Level; 056import java.util.logging.Logger; 057 058/** 059 * Driver for broadcast example. 060 */ 061@DriverSide 062@Unit 063public class BroadcastDriver { 064 065 private static final Logger LOG = Logger.getLogger(BroadcastDriver.class.getName()); 066 067 private final AtomicBoolean masterSubmitted = new AtomicBoolean(false); 068 private final AtomicInteger slaveIds = new AtomicInteger(0); 069 private final AtomicInteger failureSet = new AtomicInteger(0); 070 071 private final GroupCommDriver groupCommDriver; 072 private final CommunicationGroupDriver allCommGroup; 073 private final ConfigurationSerializer confSerializer; 074 private final int dimensions; 075 private final EvaluatorRequestor requestor; 076 private final int numberOfReceivers; 077 private final AtomicInteger numberOfAllocatedEvaluators; 078 079 private String groupCommConfiguredMasterId; 080 081 @Inject 082 public BroadcastDriver( 083 final EvaluatorRequestor requestor, 084 final GroupCommDriver groupCommDriver, 085 final ConfigurationSerializer confSerializer, 086 @Parameter(ModelDimensions.class) final int dimensions, 087 @Parameter(NumberOfReceivers.class) final int numberOfReceivers) { 088 089 this.requestor = requestor; 090 this.groupCommDriver = groupCommDriver; 091 this.confSerializer = confSerializer; 092 this.dimensions = dimensions; 093 this.numberOfReceivers = numberOfReceivers; 094 this.numberOfAllocatedEvaluators = new AtomicInteger(numberOfReceivers + 1); 095 096 this.allCommGroup = this.groupCommDriver.newCommunicationGroup( 097 AllCommunicationGroup.class, numberOfReceivers + 1); 098 099 LOG.info("Obtained all communication group"); 100 101 this.allCommGroup 102 .addBroadcast(ControlMessageBroadcaster.class, 103 BroadcastOperatorSpec.newBuilder() 104 .setSenderId(MasterTask.TASK_ID) 105 .setDataCodecClass(SerializableCodec.class) 106 .build()) 107 .addBroadcast(ModelBroadcaster.class, 108 BroadcastOperatorSpec.newBuilder() 109 .setSenderId(MasterTask.TASK_ID) 110 .setDataCodecClass(SerializableCodec.class) 111 .build()) 112 .addReduce(ModelReceiveAckReducer.class, 113 ReduceOperatorSpec.newBuilder() 114 .setReceiverId(MasterTask.TASK_ID) 115 .setDataCodecClass(SerializableCodec.class) 116 .setReduceFunctionClass(ModelReceiveAckReduceFunction.class) 117 .build()) 118 .finalise(); 119 120 LOG.info("Added operators to allCommGroup"); 121 } 122 123 /** 124 * Handles the StartTime event: Request numOfReceivers Evaluators. 125 */ 126 final class StartHandler implements EventHandler<StartTime> { 127 @Override 128 public void onNext(final StartTime startTime) { 129 final int numEvals = BroadcastDriver.this.numberOfReceivers + 1; 130 LOG.log(Level.FINE, "Requesting {0} evaluators", numEvals); 131 BroadcastDriver.this.requestor.newRequest() 132 .setNumber(numEvals) 133 .setMemory(2048) 134 .submit(); 135 } 136 } 137 138 /** 139 * Handles AllocatedEvaluator: Submits a context with an id. 140 */ 141 final class EvaluatorAllocatedHandler implements EventHandler<AllocatedEvaluator> { 142 @Override 143 public void onNext(final AllocatedEvaluator allocatedEvaluator) { 144 LOG.log(Level.INFO, "Submitting an id context to AllocatedEvaluator: {0}", allocatedEvaluator); 145 final Configuration contextConfiguration = ContextConfiguration.CONF 146 .set(ContextConfiguration.IDENTIFIER, "BroadcastContext-" + 147 BroadcastDriver.this.numberOfAllocatedEvaluators.getAndDecrement()) 148 .build(); 149 allocatedEvaluator.submitContext(contextConfiguration); 150 } 151 } 152 153 /** 154 * FailedTask handler. 155 */ 156 public class FailedTaskHandler implements EventHandler<FailedTask> { 157 158 @Override 159 public void onNext(final FailedTask failedTask) { 160 161 LOG.log(Level.FINE, "Got failed Task: {0}", failedTask.getId()); 162 163 final ActiveContext activeContext = failedTask.getActiveContext().get(); 164 final Configuration partialTaskConf = Tang.Factory.getTang() 165 .newConfigurationBuilder( 166 TaskConfiguration.CONF 167 .set(TaskConfiguration.IDENTIFIER, failedTask.getId()) 168 .set(TaskConfiguration.TASK, SlaveTask.class) 169 .build(), 170 PoisonedConfiguration.TASK_CONF 171 .set(PoisonedConfiguration.CRASH_PROBABILITY, "0") 172 .set(PoisonedConfiguration.CRASH_TIMEOUT, "1") 173 .build()) 174 .bindNamedParameter(ModelDimensions.class, "" + dimensions) 175 .build(); 176 177 // Do not add the task back: 178 // allCommGroup.addTask(partialTaskConf); 179 180 final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf); 181 LOG.log(Level.FINER, "Submit SlaveTask conf: {0}", confSerializer.toString(taskConf)); 182 183 activeContext.submitTask(taskConf); 184 } 185 } 186 187 /** 188 * ActiveContext handler. 189 */ 190 public class ContextActiveHandler implements EventHandler<ActiveContext> { 191 192 private final AtomicBoolean storeMasterId = new AtomicBoolean(false); 193 194 @Override 195 public void onNext(final ActiveContext activeContext) { 196 197 LOG.log(Level.FINE, "Got active context: {0}", activeContext.getId()); 198 199 /** 200 * The active context can be either from data loading service or after network 201 * service has loaded contexts. So check if the GroupCommDriver knows if it was 202 * configured by one of the communication groups. 203 */ 204 if (groupCommDriver.isConfigured(activeContext)) { 205 206 if (activeContext.getId().equals(groupCommConfiguredMasterId) && !masterTaskSubmitted()) { 207 208 final Configuration partialTaskConf = Tang.Factory.getTang() 209 .newConfigurationBuilder( 210 TaskConfiguration.CONF 211 .set(TaskConfiguration.IDENTIFIER, MasterTask.TASK_ID) 212 .set(TaskConfiguration.TASK, MasterTask.class) 213 .build()) 214 .bindNamedParameter(ModelDimensions.class, Integer.toString(dimensions)) 215 .build(); 216 217 allCommGroup.addTask(partialTaskConf); 218 219 final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf); 220 LOG.log(Level.FINER, "Submit MasterTask conf: {0}", confSerializer.toString(taskConf)); 221 222 activeContext.submitTask(taskConf); 223 224 } else { 225 226 final Configuration partialTaskConf = Tang.Factory.getTang() 227 .newConfigurationBuilder( 228 TaskConfiguration.CONF 229 .set(TaskConfiguration.IDENTIFIER, getSlaveId(activeContext)) 230 .set(TaskConfiguration.TASK, SlaveTask.class) 231 .build(), 232 PoisonedConfiguration.TASK_CONF 233 .set(PoisonedConfiguration.CRASH_PROBABILITY, "0.4") 234 .set(PoisonedConfiguration.CRASH_TIMEOUT, "1") 235 .build()) 236 .bindNamedParameter(ModelDimensions.class, Integer.toString(dimensions)) 237 .build(); 238 239 allCommGroup.addTask(partialTaskConf); 240 241 final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf); 242 LOG.log(Level.FINER, "Submit SlaveTask conf: {0}", confSerializer.toString(taskConf)); 243 244 activeContext.submitTask(taskConf); 245 } 246 } else { 247 248 final Configuration contextConf = groupCommDriver.getContextConfiguration(); 249 final String contextId = contextId(contextConf); 250 251 if (storeMasterId.compareAndSet(false, true)) { 252 groupCommConfiguredMasterId = contextId; 253 } 254 255 final Configuration serviceConf = groupCommDriver.getServiceConfiguration(); 256 LOG.log(Level.FINER, "Submit GCContext conf: {0}", confSerializer.toString(contextConf)); 257 LOG.log(Level.FINER, "Submit Service conf: {0}", confSerializer.toString(serviceConf)); 258 259 activeContext.submitContextAndService(contextConf, serviceConf); 260 } 261 } 262 263 private String contextId(final Configuration contextConf) { 264 try { 265 final Injector injector = Tang.Factory.getTang().newInjector(contextConf); 266 return injector.getNamedInstance(ContextIdentifier.class); 267 } catch (final InjectionException e) { 268 throw new RuntimeException("Unable to inject context identifier from context conf", e); 269 } 270 } 271 272 private String getSlaveId(final ActiveContext activeContext) { 273 return "SlaveTask-" + slaveIds.getAndIncrement(); 274 } 275 276 private boolean masterTaskSubmitted() { 277 return !masterSubmitted.compareAndSet(false, true); 278 } 279 } 280 281 /** 282 * ClosedContext handler. 283 */ 284 public class ContextCloseHandler implements EventHandler<ClosedContext> { 285 286 @Override 287 public void onNext(final ClosedContext closedContext) { 288 LOG.log(Level.FINE, "Got closed context: {0}", closedContext.getId()); 289 final ActiveContext parentContext = closedContext.getParentContext(); 290 if (parentContext != null) { 291 LOG.log(Level.FINE, "Closing parent context: {0}", parentContext.getId()); 292 parentContext.close(); 293 } 294 } 295 } 296}