001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.examples.group.broadcast;
020
021import org.apache.reef.annotations.audience.DriverSide;
022import org.apache.reef.driver.context.ActiveContext;
023import org.apache.reef.driver.context.ClosedContext;
024import org.apache.reef.driver.context.ContextConfiguration;
025import org.apache.reef.driver.evaluator.AllocatedEvaluator;
026import org.apache.reef.driver.evaluator.EvaluatorRequestor;
027import org.apache.reef.driver.task.FailedTask;
028import org.apache.reef.driver.task.TaskConfiguration;
029import org.apache.reef.evaluator.context.parameters.ContextIdentifier;
030import org.apache.reef.examples.group.bgd.operatornames.ControlMessageBroadcaster;
031import org.apache.reef.examples.group.bgd.parameters.AllCommunicationGroup;
032import org.apache.reef.examples.group.bgd.parameters.ModelDimensions;
033import org.apache.reef.examples.group.broadcast.parameters.ModelBroadcaster;
034import org.apache.reef.examples.group.broadcast.parameters.ModelReceiveAckReducer;
035import org.apache.reef.examples.group.broadcast.parameters.NumberOfReceivers;
036import org.apache.reef.io.network.group.api.driver.CommunicationGroupDriver;
037import org.apache.reef.io.network.group.api.driver.GroupCommDriver;
038import org.apache.reef.io.network.group.impl.config.BroadcastOperatorSpec;
039import org.apache.reef.io.network.group.impl.config.ReduceOperatorSpec;
040import org.apache.reef.io.serialization.SerializableCodec;
041import org.apache.reef.poison.PoisonedConfiguration;
042import org.apache.reef.tang.Configuration;
043import org.apache.reef.tang.Injector;
044import org.apache.reef.tang.Tang;
045import org.apache.reef.tang.annotations.Parameter;
046import org.apache.reef.tang.annotations.Unit;
047import org.apache.reef.tang.exceptions.InjectionException;
048import org.apache.reef.tang.formats.ConfigurationSerializer;
049import org.apache.reef.wake.EventHandler;
050import org.apache.reef.wake.time.event.StartTime;
051
052import javax.inject.Inject;
053import java.util.concurrent.atomic.AtomicBoolean;
054import java.util.concurrent.atomic.AtomicInteger;
055import java.util.logging.Level;
056import java.util.logging.Logger;
057
058/**
059 * Driver for broadcast example.
060 */
061@DriverSide
062@Unit
063public class BroadcastDriver {
064
065  private static final Logger LOG = Logger.getLogger(BroadcastDriver.class.getName());
066
067  private final AtomicBoolean masterSubmitted = new AtomicBoolean(false);
068  private final AtomicInteger slaveIds = new AtomicInteger(0);
069  private final AtomicInteger failureSet = new AtomicInteger(0);
070
071  private final GroupCommDriver groupCommDriver;
072  private final CommunicationGroupDriver allCommGroup;
073  private final ConfigurationSerializer confSerializer;
074  private final int dimensions;
075  private final EvaluatorRequestor requestor;
076  private final int numberOfReceivers;
077  private final AtomicInteger numberOfAllocatedEvaluators;
078
079  private String groupCommConfiguredMasterId;
080
081  @Inject
082  public BroadcastDriver(
083      final EvaluatorRequestor requestor,
084      final GroupCommDriver groupCommDriver,
085      final ConfigurationSerializer confSerializer,
086      @Parameter(ModelDimensions.class) final int dimensions,
087      @Parameter(NumberOfReceivers.class) final int numberOfReceivers) {
088
089    this.requestor = requestor;
090    this.groupCommDriver = groupCommDriver;
091    this.confSerializer = confSerializer;
092    this.dimensions = dimensions;
093    this.numberOfReceivers = numberOfReceivers;
094    this.numberOfAllocatedEvaluators = new AtomicInteger(numberOfReceivers + 1);
095
096    this.allCommGroup = this.groupCommDriver.newCommunicationGroup(
097        AllCommunicationGroup.class, numberOfReceivers + 1);
098
099    LOG.info("Obtained all communication group");
100
101    this.allCommGroup
102        .addBroadcast(ControlMessageBroadcaster.class,
103            BroadcastOperatorSpec.newBuilder()
104                .setSenderId(MasterTask.TASK_ID)
105                .setDataCodecClass(SerializableCodec.class)
106                .build())
107        .addBroadcast(ModelBroadcaster.class,
108            BroadcastOperatorSpec.newBuilder()
109                .setSenderId(MasterTask.TASK_ID)
110                .setDataCodecClass(SerializableCodec.class)
111                .build())
112        .addReduce(ModelReceiveAckReducer.class,
113            ReduceOperatorSpec.newBuilder()
114                .setReceiverId(MasterTask.TASK_ID)
115                .setDataCodecClass(SerializableCodec.class)
116                .setReduceFunctionClass(ModelReceiveAckReduceFunction.class)
117                .build())
118        .finalise();
119
120    LOG.info("Added operators to allCommGroup");
121  }
122
123  /**
124   * Handles the StartTime event: Request numOfReceivers Evaluators.
125   */
126  final class StartHandler implements EventHandler<StartTime> {
127    @Override
128    public void onNext(final StartTime startTime) {
129      final int numEvals = BroadcastDriver.this.numberOfReceivers + 1;
130      LOG.log(Level.FINE, "Requesting {0} evaluators", numEvals);
131      BroadcastDriver.this.requestor.newRequest()
132          .setNumber(numEvals)
133          .setMemory(2048)
134          .submit();
135    }
136  }
137
138  /**
139   * Handles AllocatedEvaluator: Submits a context with an id.
140   */
141  final class EvaluatorAllocatedHandler implements EventHandler<AllocatedEvaluator> {
142    @Override
143    public void onNext(final AllocatedEvaluator allocatedEvaluator) {
144      LOG.log(Level.INFO, "Submitting an id context to AllocatedEvaluator: {0}", allocatedEvaluator);
145      final Configuration contextConfiguration = ContextConfiguration.CONF
146          .set(ContextConfiguration.IDENTIFIER, "BroadcastContext-" +
147              BroadcastDriver.this.numberOfAllocatedEvaluators.getAndDecrement())
148          .build();
149      allocatedEvaluator.submitContext(contextConfiguration);
150    }
151  }
152
153  /**
154   * FailedTask handler.
155   */
156  public class FailedTaskHandler implements EventHandler<FailedTask> {
157
158    @Override
159    public void onNext(final FailedTask failedTask) {
160
161      LOG.log(Level.FINE, "Got failed Task: {0}", failedTask.getId());
162
163      final ActiveContext activeContext = failedTask.getActiveContext().get();
164      final Configuration partialTaskConf = Tang.Factory.getTang()
165          .newConfigurationBuilder(
166              TaskConfiguration.CONF
167                  .set(TaskConfiguration.IDENTIFIER, failedTask.getId())
168                  .set(TaskConfiguration.TASK, SlaveTask.class)
169                  .build(),
170              PoisonedConfiguration.TASK_CONF
171                  .set(PoisonedConfiguration.CRASH_PROBABILITY, "0")
172                  .set(PoisonedConfiguration.CRASH_TIMEOUT, "1")
173                  .build())
174          .bindNamedParameter(ModelDimensions.class, "" + dimensions)
175          .build();
176
177      // Do not add the task back:
178      // allCommGroup.addTask(partialTaskConf);
179
180      final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf);
181      LOG.log(Level.FINER, "Submit SlaveTask conf: {0}", confSerializer.toString(taskConf));
182
183      activeContext.submitTask(taskConf);
184    }
185  }
186
187  /**
188   * ActiveContext handler.
189   */
190  public class ContextActiveHandler implements EventHandler<ActiveContext> {
191
192    private final AtomicBoolean storeMasterId = new AtomicBoolean(false);
193
194    @Override
195    public void onNext(final ActiveContext activeContext) {
196
197      LOG.log(Level.FINE, "Got active context: {0}", activeContext.getId());
198
199      /**
200       * The active context can be either from data loading service or after network
201       * service has loaded contexts. So check if the GroupCommDriver knows if it was
202       * configured by one of the communication groups.
203       */
204      if (groupCommDriver.isConfigured(activeContext)) {
205
206        if (activeContext.getId().equals(groupCommConfiguredMasterId) && !masterTaskSubmitted()) {
207
208          final Configuration partialTaskConf = Tang.Factory.getTang()
209              .newConfigurationBuilder(
210                  TaskConfiguration.CONF
211                      .set(TaskConfiguration.IDENTIFIER, MasterTask.TASK_ID)
212                      .set(TaskConfiguration.TASK, MasterTask.class)
213                      .build())
214              .bindNamedParameter(ModelDimensions.class, Integer.toString(dimensions))
215              .build();
216
217          allCommGroup.addTask(partialTaskConf);
218
219          final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf);
220          LOG.log(Level.FINER, "Submit MasterTask conf: {0}", confSerializer.toString(taskConf));
221
222          activeContext.submitTask(taskConf);
223
224        } else {
225
226          final Configuration partialTaskConf = Tang.Factory.getTang()
227              .newConfigurationBuilder(
228                  TaskConfiguration.CONF
229                      .set(TaskConfiguration.IDENTIFIER, getSlaveId(activeContext))
230                      .set(TaskConfiguration.TASK, SlaveTask.class)
231                      .build(),
232                  PoisonedConfiguration.TASK_CONF
233                      .set(PoisonedConfiguration.CRASH_PROBABILITY, "0.4")
234                      .set(PoisonedConfiguration.CRASH_TIMEOUT, "1")
235                      .build())
236              .bindNamedParameter(ModelDimensions.class, Integer.toString(dimensions))
237              .build();
238
239          allCommGroup.addTask(partialTaskConf);
240
241          final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf);
242          LOG.log(Level.FINER, "Submit SlaveTask conf: {0}", confSerializer.toString(taskConf));
243
244          activeContext.submitTask(taskConf);
245        }
246      } else {
247
248        final Configuration contextConf = groupCommDriver.getContextConfiguration();
249        final String contextId = contextId(contextConf);
250
251        if (storeMasterId.compareAndSet(false, true)) {
252          groupCommConfiguredMasterId = contextId;
253        }
254
255        final Configuration serviceConf = groupCommDriver.getServiceConfiguration();
256        LOG.log(Level.FINER, "Submit GCContext conf: {0}", confSerializer.toString(contextConf));
257        LOG.log(Level.FINER, "Submit Service conf: {0}", confSerializer.toString(serviceConf));
258
259        activeContext.submitContextAndService(contextConf, serviceConf);
260      }
261    }
262
263    private String contextId(final Configuration contextConf) {
264      try {
265        final Injector injector = Tang.Factory.getTang().newInjector(contextConf);
266        return injector.getNamedInstance(ContextIdentifier.class);
267      } catch (final InjectionException e) {
268        throw new RuntimeException("Unable to inject context identifier from context conf", e);
269      }
270    }
271
272    private String getSlaveId(final ActiveContext activeContext) {
273      return "SlaveTask-" + slaveIds.getAndIncrement();
274    }
275
276    private boolean masterTaskSubmitted() {
277      return !masterSubmitted.compareAndSet(false, true);
278    }
279  }
280
281  /**
282   * ClosedContext handler.
283   */
284  public class ContextCloseHandler implements EventHandler<ClosedContext> {
285
286    @Override
287    public void onNext(final ClosedContext closedContext) {
288      LOG.log(Level.FINE, "Got closed context: {0}", closedContext.getId());
289      final ActiveContext parentContext = closedContext.getParentContext();
290      if (parentContext != null) {
291        LOG.log(Level.FINE, "Closing parent context: {0}", parentContext.getId());
292        parentContext.close();
293      }
294    }
295  }
296}