This project has retired. For details please refer to its Attic page.
Source code
001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.examples.group.broadcast;
020
021import org.apache.reef.annotations.audience.DriverSide;
022import org.apache.reef.driver.context.ActiveContext;
023import org.apache.reef.driver.context.ClosedContext;
024import org.apache.reef.driver.context.ContextConfiguration;
025import org.apache.reef.driver.evaluator.AllocatedEvaluator;
026import org.apache.reef.driver.evaluator.EvaluatorRequest;
027import org.apache.reef.driver.evaluator.EvaluatorRequestor;
028import org.apache.reef.driver.task.FailedTask;
029import org.apache.reef.driver.task.TaskConfiguration;
030import org.apache.reef.evaluator.context.parameters.ContextIdentifier;
031import org.apache.reef.examples.group.bgd.operatornames.ControlMessageBroadcaster;
032import org.apache.reef.examples.group.bgd.parameters.AllCommunicationGroup;
033import org.apache.reef.examples.group.bgd.parameters.ModelDimensions;
034import org.apache.reef.examples.group.broadcast.parameters.ModelBroadcaster;
035import org.apache.reef.examples.group.broadcast.parameters.ModelReceiveAckReducer;
036import org.apache.reef.examples.group.broadcast.parameters.NumberOfReceivers;
037import org.apache.reef.io.network.group.api.driver.CommunicationGroupDriver;
038import org.apache.reef.io.network.group.api.driver.GroupCommDriver;
039import org.apache.reef.io.network.group.impl.config.BroadcastOperatorSpec;
040import org.apache.reef.io.network.group.impl.config.ReduceOperatorSpec;
041import org.apache.reef.io.serialization.SerializableCodec;
042import org.apache.reef.poison.PoisonedConfiguration;
043import org.apache.reef.tang.Configuration;
044import org.apache.reef.tang.Injector;
045import org.apache.reef.tang.Tang;
046import org.apache.reef.tang.annotations.Parameter;
047import org.apache.reef.tang.annotations.Unit;
048import org.apache.reef.tang.exceptions.InjectionException;
049import org.apache.reef.tang.formats.ConfigurationSerializer;
050import org.apache.reef.wake.EventHandler;
051import org.apache.reef.wake.time.event.StartTime;
052
053import javax.inject.Inject;
054import java.util.concurrent.atomic.AtomicBoolean;
055import java.util.concurrent.atomic.AtomicInteger;
056import java.util.logging.Level;
057import java.util.logging.Logger;
058
059/**
060 * Driver for broadcast example.
061 */
062@DriverSide
063@Unit
064public class BroadcastDriver {
065
066  private static final Logger LOG = Logger.getLogger(BroadcastDriver.class.getName());
067
068  private final AtomicBoolean masterSubmitted = new AtomicBoolean(false);
069  private final AtomicInteger slaveIds = new AtomicInteger(0);
070  private final AtomicInteger failureSet = new AtomicInteger(0);
071
072  private final GroupCommDriver groupCommDriver;
073  private final CommunicationGroupDriver allCommGroup;
074  private final ConfigurationSerializer confSerializer;
075  private final int dimensions;
076  private final EvaluatorRequestor requestor;
077  private final int numberOfReceivers;
078  private final AtomicInteger numberOfAllocatedEvaluators;
079
080  private String groupCommConfiguredMasterId;
081
082  @Inject
083  public BroadcastDriver(
084      final EvaluatorRequestor requestor,
085      final GroupCommDriver groupCommDriver,
086      final ConfigurationSerializer confSerializer,
087      @Parameter(ModelDimensions.class) final int dimensions,
088      @Parameter(NumberOfReceivers.class) final int numberOfReceivers) {
089
090    this.requestor = requestor;
091    this.groupCommDriver = groupCommDriver;
092    this.confSerializer = confSerializer;
093    this.dimensions = dimensions;
094    this.numberOfReceivers = numberOfReceivers;
095    this.numberOfAllocatedEvaluators = new AtomicInteger(numberOfReceivers + 1);
096
097    this.allCommGroup = this.groupCommDriver.newCommunicationGroup(
098        AllCommunicationGroup.class, numberOfReceivers + 1);
099
100    LOG.info("Obtained all communication group");
101
102    this.allCommGroup
103        .addBroadcast(ControlMessageBroadcaster.class,
104            BroadcastOperatorSpec.newBuilder()
105                .setSenderId(MasterTask.TASK_ID)
106                .setDataCodecClass(SerializableCodec.class)
107                .build())
108        .addBroadcast(ModelBroadcaster.class,
109            BroadcastOperatorSpec.newBuilder()
110                .setSenderId(MasterTask.TASK_ID)
111                .setDataCodecClass(SerializableCodec.class)
112                .build())
113        .addReduce(ModelReceiveAckReducer.class,
114            ReduceOperatorSpec.newBuilder()
115                .setReceiverId(MasterTask.TASK_ID)
116                .setDataCodecClass(SerializableCodec.class)
117                .setReduceFunctionClass(ModelReceiveAckReduceFunction.class)
118                .build())
119        .finalise();
120
121    LOG.info("Added operators to allCommGroup");
122  }
123
124  /**
125   * Handles the StartTime event: Request numOfReceivers Evaluators.
126   */
127  final class StartHandler implements EventHandler<StartTime> {
128    @Override
129    public void onNext(final StartTime startTime) {
130      final int numEvals = BroadcastDriver.this.numberOfReceivers + 1;
131      LOG.log(Level.FINE, "Requesting {0} evaluators", numEvals);
132      BroadcastDriver.this.requestor.submit(EvaluatorRequest.newBuilder()
133          .setNumber(numEvals)
134          .setMemory(2048)
135          .build());
136    }
137  }
138
139  /**
140   * Handles AllocatedEvaluator: Submits a context with an id.
141   */
142  final class EvaluatorAllocatedHandler implements EventHandler<AllocatedEvaluator> {
143    @Override
144    public void onNext(final AllocatedEvaluator allocatedEvaluator) {
145      LOG.log(Level.INFO, "Submitting an id context to AllocatedEvaluator: {0}", allocatedEvaluator);
146      final Configuration contextConfiguration = ContextConfiguration.CONF
147          .set(ContextConfiguration.IDENTIFIER, "BroadcastContext-" +
148              BroadcastDriver.this.numberOfAllocatedEvaluators.getAndDecrement())
149          .build();
150      allocatedEvaluator.submitContext(contextConfiguration);
151    }
152  }
153
154  /**
155   * FailedTask handler.
156   */
157  public class FailedTaskHandler implements EventHandler<FailedTask> {
158
159    @Override
160    public void onNext(final FailedTask failedTask) {
161
162      LOG.log(Level.FINE, "Got failed Task: {0}", failedTask.getId());
163
164      final ActiveContext activeContext = failedTask.getActiveContext().get();
165      final Configuration partialTaskConf = Tang.Factory.getTang()
166          .newConfigurationBuilder(
167              TaskConfiguration.CONF
168                  .set(TaskConfiguration.IDENTIFIER, failedTask.getId())
169                  .set(TaskConfiguration.TASK, SlaveTask.class)
170                  .build(),
171              PoisonedConfiguration.TASK_CONF
172                  .set(PoisonedConfiguration.CRASH_PROBABILITY, "0")
173                  .set(PoisonedConfiguration.CRASH_TIMEOUT, "1")
174                  .build())
175          .bindNamedParameter(ModelDimensions.class, "" + dimensions)
176          .build();
177
178      // Do not add the task back:
179      // allCommGroup.addTask(partialTaskConf);
180
181      final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf);
182      LOG.log(Level.FINER, "Submit SlaveTask conf: {0}", confSerializer.toString(taskConf));
183
184      activeContext.submitTask(taskConf);
185    }
186  }
187
188  /**
189   * ActiveContext handler.
190   */
191  public class ContextActiveHandler implements EventHandler<ActiveContext> {
192
193    private final AtomicBoolean storeMasterId = new AtomicBoolean(false);
194
195    @Override
196    public void onNext(final ActiveContext activeContext) {
197
198      LOG.log(Level.FINE, "Got active context: {0}", activeContext.getId());
199
200      /**
201       * The active context can be either from data loading service or after network
202       * service has loaded contexts. So check if the GroupCommDriver knows if it was
203       * configured by one of the communication groups.
204       */
205      if (groupCommDriver.isConfigured(activeContext)) {
206
207        if (activeContext.getId().equals(groupCommConfiguredMasterId) && !masterTaskSubmitted()) {
208
209          final Configuration partialTaskConf = Tang.Factory.getTang()
210              .newConfigurationBuilder(
211                  TaskConfiguration.CONF
212                      .set(TaskConfiguration.IDENTIFIER, MasterTask.TASK_ID)
213                      .set(TaskConfiguration.TASK, MasterTask.class)
214                      .build())
215              .bindNamedParameter(ModelDimensions.class, Integer.toString(dimensions))
216              .build();
217
218          allCommGroup.addTask(partialTaskConf);
219
220          final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf);
221          LOG.log(Level.FINER, "Submit MasterTask conf: {0}", confSerializer.toString(taskConf));
222
223          activeContext.submitTask(taskConf);
224
225        } else {
226
227          final Configuration partialTaskConf = Tang.Factory.getTang()
228              .newConfigurationBuilder(
229                  TaskConfiguration.CONF
230                      .set(TaskConfiguration.IDENTIFIER, getSlaveId(activeContext))
231                      .set(TaskConfiguration.TASK, SlaveTask.class)
232                      .build(),
233                  PoisonedConfiguration.TASK_CONF
234                      .set(PoisonedConfiguration.CRASH_PROBABILITY, "0.4")
235                      .set(PoisonedConfiguration.CRASH_TIMEOUT, "1")
236                      .build())
237              .bindNamedParameter(ModelDimensions.class, Integer.toString(dimensions))
238              .build();
239
240          allCommGroup.addTask(partialTaskConf);
241
242          final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf);
243          LOG.log(Level.FINER, "Submit SlaveTask conf: {0}", confSerializer.toString(taskConf));
244
245          activeContext.submitTask(taskConf);
246        }
247      } else {
248
249        final Configuration contextConf = groupCommDriver.getContextConfiguration();
250        final String contextId = contextId(contextConf);
251
252        if (storeMasterId.compareAndSet(false, true)) {
253          groupCommConfiguredMasterId = contextId;
254        }
255
256        final Configuration serviceConf = groupCommDriver.getServiceConfiguration();
257        LOG.log(Level.FINER, "Submit GCContext conf: {0}", confSerializer.toString(contextConf));
258        LOG.log(Level.FINER, "Submit Service conf: {0}", confSerializer.toString(serviceConf));
259
260        activeContext.submitContextAndService(contextConf, serviceConf);
261      }
262    }
263
264    private String contextId(final Configuration contextConf) {
265      try {
266        final Injector injector = Tang.Factory.getTang().newInjector(contextConf);
267        return injector.getNamedInstance(ContextIdentifier.class);
268      } catch (final InjectionException e) {
269        throw new RuntimeException("Unable to inject context identifier from context conf", e);
270      }
271    }
272
273    private String getSlaveId(final ActiveContext activeContext) {
274      return "SlaveTask-" + slaveIds.getAndIncrement();
275    }
276
277    private boolean masterTaskSubmitted() {
278      return !masterSubmitted.compareAndSet(false, true);
279    }
280  }
281
282  /**
283   * ClosedContext handler.
284   */
285  public class ContextCloseHandler implements EventHandler<ClosedContext> {
286
287    @Override
288    public void onNext(final ClosedContext closedContext) {
289      LOG.log(Level.FINE, "Got closed context: {0}", closedContext.getId());
290      final ActiveContext parentContext = closedContext.getParentContext();
291      if (parentContext != null) {
292        LOG.log(Level.FINE, "Closing parent context: {0}", parentContext.getId());
293        parentContext.close();
294      }
295    }
296  }
297}