001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.examples.group.bgd;
020
021import org.apache.reef.annotations.audience.DriverSide;
022import org.apache.reef.driver.context.ActiveContext;
023import org.apache.reef.driver.context.ServiceConfiguration;
024import org.apache.reef.driver.task.CompletedTask;
025import org.apache.reef.driver.task.FailedTask;
026import org.apache.reef.driver.task.RunningTask;
027import org.apache.reef.driver.task.TaskConfiguration;
028import org.apache.reef.evaluator.context.parameters.ContextIdentifier;
029import org.apache.reef.examples.group.bgd.data.parser.Parser;
030import org.apache.reef.examples.group.bgd.data.parser.SVMLightParser;
031import org.apache.reef.examples.group.bgd.loss.LossFunction;
032import org.apache.reef.examples.group.bgd.operatornames.*;
033import org.apache.reef.examples.group.bgd.parameters.AllCommunicationGroup;
034import org.apache.reef.examples.group.bgd.parameters.BGDControlParameters;
035import org.apache.reef.examples.group.bgd.parameters.ModelDimensions;
036import org.apache.reef.examples.group.bgd.parameters.ProbabilityOfFailure;
037import org.apache.reef.io.data.loading.api.DataLoadingService;
038import org.apache.reef.io.network.group.api.driver.CommunicationGroupDriver;
039import org.apache.reef.io.network.group.api.driver.GroupCommDriver;
040import org.apache.reef.io.network.group.impl.config.BroadcastOperatorSpec;
041import org.apache.reef.io.network.group.impl.config.ReduceOperatorSpec;
042import org.apache.reef.io.serialization.Codec;
043import org.apache.reef.io.serialization.SerializableCodec;
044import org.apache.reef.poison.PoisonedConfiguration;
045import org.apache.reef.tang.Configuration;
046import org.apache.reef.tang.Configurations;
047import org.apache.reef.tang.Tang;
048import org.apache.reef.tang.annotations.Unit;
049import org.apache.reef.tang.exceptions.InjectionException;
050import org.apache.reef.tang.formats.ConfigurationSerializer;
051import org.apache.reef.wake.EventHandler;
052
053import javax.inject.Inject;
054import java.util.ArrayList;
055import java.util.HashMap;
056import java.util.List;
057import java.util.Map;
058import java.util.concurrent.atomic.AtomicBoolean;
059import java.util.concurrent.atomic.AtomicInteger;
060import java.util.logging.Level;
061import java.util.logging.Logger;
062
/**
 * Driver for the BGD (Batch Gradient Descent) example.
 */
066@DriverSide
067@Unit
068public class BGDDriver {
069
070  private static final Logger LOG = Logger.getLogger(BGDDriver.class.getName());
071
072  private static final Tang TANG = Tang.Factory.getTang();
073
074  private static final double STARTUP_FAILURE_PROB = 0.01;
075
076  private final DataLoadingService dataLoadingService;
077  private final GroupCommDriver groupCommDriver;
078  private final ConfigurationSerializer confSerializer;
079  private final CommunicationGroupDriver communicationsGroup;
080  private final AtomicBoolean masterSubmitted = new AtomicBoolean(false);
081  private final AtomicInteger slaveIds = new AtomicInteger(0);
082  private final Map<String, RunningTask> runningTasks = new HashMap<>();
083  private final AtomicBoolean jobComplete = new AtomicBoolean(false);
084  private final Codec<ArrayList<Double>> lossCodec = new SerializableCodec<>();
085  private final BGDControlParameters bgdControlParameters;
086
087  private String communicationsGroupMasterContextId;
088
089  @Inject
090  public BGDDriver(final DataLoadingService dataLoadingService,
091                   final GroupCommDriver groupCommDriver,
092                   final ConfigurationSerializer confSerializer,
093                   final BGDControlParameters bgdControlParameters) {
094    this.dataLoadingService = dataLoadingService;
095    this.groupCommDriver = groupCommDriver;
096    this.confSerializer = confSerializer;
097    this.bgdControlParameters = bgdControlParameters;
098
099    final int minNumOfPartitions =
100        bgdControlParameters.isRampup()
101            ? bgdControlParameters.getMinParts()
102            : dataLoadingService.getNumberOfPartitions();
103
104    final int numParticipants = minNumOfPartitions + 1;
105
106    this.communicationsGroup = this.groupCommDriver.newCommunicationGroup(
107        AllCommunicationGroup.class, // NAME
108        numParticipants);            // Number of participants
109
110    LOG.log(Level.INFO,
111        "Obtained entire communication group: start with {0} partitions", numParticipants);
112
113    this.communicationsGroup
114        .addBroadcast(ControlMessageBroadcaster.class,
115            BroadcastOperatorSpec.newBuilder()
116                .setSenderId(MasterTask.TASK_ID)
117                .setDataCodecClass(SerializableCodec.class)
118                .build())
119        .addBroadcast(ModelBroadcaster.class,
120            BroadcastOperatorSpec.newBuilder()
121                .setSenderId(MasterTask.TASK_ID)
122                .setDataCodecClass(SerializableCodec.class)
123                .build())
124        .addReduce(LossAndGradientReducer.class,
125            ReduceOperatorSpec.newBuilder()
126                .setReceiverId(MasterTask.TASK_ID)
127                .setDataCodecClass(SerializableCodec.class)
128                .setReduceFunctionClass(LossAndGradientReduceFunction.class)
129                .build())
130        .addBroadcast(ModelAndDescentDirectionBroadcaster.class,
131            BroadcastOperatorSpec.newBuilder()
132                .setSenderId(MasterTask.TASK_ID)
133                .setDataCodecClass(SerializableCodec.class)
134                .build())
135        .addBroadcast(DescentDirectionBroadcaster.class,
136            BroadcastOperatorSpec.newBuilder()
137                .setSenderId(MasterTask.TASK_ID)
138                .setDataCodecClass(SerializableCodec.class)
139                .build())
140        .addReduce(LineSearchEvaluationsReducer.class,
141            ReduceOperatorSpec.newBuilder()
142                .setReceiverId(MasterTask.TASK_ID)
143                .setDataCodecClass(SerializableCodec.class)
144                .setReduceFunctionClass(LineSearchReduceFunction.class)
145                .build())
146        .addBroadcast(MinEtaBroadcaster.class,
147            BroadcastOperatorSpec.newBuilder()
148                .setSenderId(MasterTask.TASK_ID)
149                .setDataCodecClass(SerializableCodec.class)
150                .build())
151        .finalise();
152
153    LOG.log(Level.INFO, "Added operators to communicationsGroup");
154  }
155
156  final class ContextActiveHandler implements EventHandler<ActiveContext> {
157
158    @Override
159    public void onNext(final ActiveContext activeContext) {
160      LOG.log(Level.INFO, "Got active context: {0}", activeContext.getId());
161      if (jobRunning(activeContext)) {
162        if (!groupCommDriver.isConfigured(activeContext)) {
163          // The Context is not configured with the group communications service let's do that.
164          submitGroupCommunicationsService(activeContext);
165        } else {
166          // The group communications service is already active on this context. We can submit the task.
167          submitTask(activeContext);
168        }
169      }
170    }
171
172    /**
173     * @param activeContext a context to be configured with group communications.
174     */
175    private void submitGroupCommunicationsService(final ActiveContext activeContext) {
176      final Configuration contextConf = groupCommDriver.getContextConfiguration();
177      final String contextId = getContextId(contextConf);
178      final Configuration serviceConf;
179      if (!dataLoadingService.isDataLoadedContext(activeContext)) {
180        communicationsGroupMasterContextId = contextId;
181        serviceConf = groupCommDriver.getServiceConfiguration();
182      } else {
183        final Configuration parsedDataServiceConf = ServiceConfiguration.CONF
184            .set(ServiceConfiguration.SERVICES, ExampleList.class)
185            .build();
186        serviceConf = Tang.Factory.getTang()
187            .newConfigurationBuilder(groupCommDriver.getServiceConfiguration(), parsedDataServiceConf)
188            .bindImplementation(Parser.class, SVMLightParser.class)
189            .build();
190      }
191
192      LOG.log(Level.FINEST, "Submit GCContext conf: {0} and Service conf: {1}", new Object[]{
193          confSerializer.toString(contextConf), confSerializer.toString(serviceConf)});
194
195      activeContext.submitContextAndService(contextConf, serviceConf);
196    }
197
198    private void submitTask(final ActiveContext activeContext) {
199
200      assert groupCommDriver.isConfigured(activeContext);
201
202      final Configuration partialTaskConfiguration;
203      if (activeContext.getId().equals(communicationsGroupMasterContextId) && !masterTaskSubmitted()) {
204        partialTaskConfiguration = getMasterTaskConfiguration();
205        LOG.info("Submitting MasterTask conf");
206      } else {
207        partialTaskConfiguration = getSlaveTaskConfiguration(getSlaveId(activeContext));
208        // partialTaskConfiguration = Configurations.merge(
209        //     getSlaveTaskConfiguration(getSlaveId(activeContext)),
210        //     getTaskPoisonConfiguration());
211        LOG.info("Submitting SlaveTask conf");
212      }
213      communicationsGroup.addTask(partialTaskConfiguration);
214      final Configuration taskConfiguration = groupCommDriver.getTaskConfiguration(partialTaskConfiguration);
215      LOG.log(Level.FINEST, "{0}", confSerializer.toString(taskConfiguration));
216      activeContext.submitTask(taskConfiguration);
217    }
218
219    private boolean jobRunning(final ActiveContext activeContext) {
220      synchronized (runningTasks) {
221        if (!jobComplete.get()) {
222          return true;
223        } else {
224          LOG.log(Level.INFO, "Job complete. Not submitting any task. Closing context {0}", activeContext);
225          activeContext.close();
226          return false;
227        }
228      }
229    }
230  }
231
232  final class TaskRunningHandler implements EventHandler<RunningTask> {
233
234    @Override
235    public void onNext(final RunningTask runningTask) {
236      synchronized (runningTasks) {
237        if (!jobComplete.get()) {
238          LOG.log(Level.INFO, "Job has not completed yet. Adding to runningTasks: {0}", runningTask);
239          runningTasks.put(runningTask.getId(), runningTask);
240        } else {
241          LOG.log(Level.INFO, "Job complete. Closing context: {0}", runningTask.getActiveContext().getId());
242          runningTask.getActiveContext().close();
243        }
244      }
245    }
246  }
247
248  final class TaskFailedHandler implements EventHandler<FailedTask> {
249
250    @Override
251    public void onNext(final FailedTask failedTask) {
252
253      final String failedTaskId = failedTask.getId();
254
255      LOG.log(Level.WARNING, "Got failed Task: " + failedTaskId);
256
257      if (jobRunning(failedTaskId)) {
258
259        final ActiveContext activeContext = failedTask.getActiveContext().get();
260        final Configuration partialTaskConf = getSlaveTaskConfiguration(failedTaskId);
261
262        // Do not add the task back:
263        // allCommGroup.addTask(partialTaskConf);
264
265        final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf);
266        LOG.log(Level.FINEST, "Submit SlaveTask conf: {0}", confSerializer.toString(taskConf));
267
268        activeContext.submitTask(taskConf);
269      }
270    }
271
272    private boolean jobRunning(final String failedTaskId) {
273      synchronized (runningTasks) {
274        if (!jobComplete.get()) {
275          return true;
276        } else {
277          final RunningTask rTask = runningTasks.remove(failedTaskId);
278          LOG.log(Level.INFO, "Job has completed. Not resubmitting");
279          if (rTask != null) {
280            LOG.log(Level.INFO, "Closing activecontext");
281            rTask.getActiveContext().close();
282          } else {
283            LOG.log(Level.INFO, "Master must have closed my context");
284          }
285          return false;
286        }
287      }
288    }
289  }
290
291  final class TaskCompletedHandler implements EventHandler<CompletedTask> {
292
293    @Override
294    public void onNext(final CompletedTask task) {
295      LOG.log(Level.INFO, "Got CompletedTask: {0}", task.getId());
296      final byte[] retVal = task.get();
297      if (retVal != null) {
298        final List<Double> losses = BGDDriver.this.lossCodec.decode(retVal);
299        for (final Double loss : losses) {
300          LOG.log(Level.INFO, "OUT: LOSS = {0}", loss);
301        }
302      }
303      synchronized (runningTasks) {
304        LOG.log(Level.INFO, "Acquired lock on runningTasks. Removing {0}", task.getId());
305        final RunningTask rTask = runningTasks.remove(task.getId());
306        if (rTask != null) {
307          LOG.log(Level.INFO, "Closing active context: {0}", task.getActiveContext().getId());
308          task.getActiveContext().close();
309        } else {
310          LOG.log(Level.INFO, "Master must have closed active context already for task {0}", task.getId());
311        }
312
313        if (MasterTask.TASK_ID.equals(task.getId())) {
314          jobComplete.set(true);
315          LOG.log(Level.INFO, "Master(=>Job) complete. Closing other running tasks: {0}", runningTasks.values());
316          for (final RunningTask runTask : runningTasks.values()) {
317            runTask.getActiveContext().close();
318          }
319          LOG.finest("Clearing runningTasks");
320          runningTasks.clear();
321        }
322      }
323    }
324  }
325
326  /**
327   * @return Configuration for the MasterTask
328   */
329  public Configuration getMasterTaskConfiguration() {
330    return Configurations.merge(
331        TaskConfiguration.CONF
332            .set(TaskConfiguration.IDENTIFIER, MasterTask.TASK_ID)
333            .set(TaskConfiguration.TASK, MasterTask.class)
334            .build(),
335        bgdControlParameters.getConfiguration());
336  }
337
338  /**
339   * @return Configuration for the SlaveTask
340   */
341  private Configuration getSlaveTaskConfiguration(final String taskId) {
342    final double pSuccess = bgdControlParameters.getProbOfSuccessfulIteration();
343    final int numberOfPartitions = dataLoadingService.getNumberOfPartitions();
344    final double pFailure = 1 - Math.pow(pSuccess, 1.0 / numberOfPartitions);
345    return Tang.Factory.getTang()
346        .newConfigurationBuilder(
347            TaskConfiguration.CONF
348                .set(TaskConfiguration.IDENTIFIER, taskId)
349                .set(TaskConfiguration.TASK, SlaveTask.class)
350                .build())
351        .bindNamedParameter(ModelDimensions.class, "" + bgdControlParameters.getDimensions())
352        .bindImplementation(LossFunction.class, bgdControlParameters.getLossFunction())
353        .bindNamedParameter(ProbabilityOfFailure.class, Double.toString(pFailure))
354        .build();
355  }
356
357  private Configuration getTaskPoisonConfiguration() {
358    return PoisonedConfiguration.TASK_CONF
359        .set(PoisonedConfiguration.CRASH_PROBABILITY, STARTUP_FAILURE_PROB)
360        .set(PoisonedConfiguration.CRASH_TIMEOUT, 1)
361        .build();
362  }
363
364  private String getContextId(final Configuration contextConf) {
365    try {
366      return TANG.newInjector(contextConf).getNamedInstance(ContextIdentifier.class);
367    } catch (final InjectionException e) {
368      throw new RuntimeException("Unable to inject context identifier from context conf", e);
369    }
370  }
371
372  private String getSlaveId(final ActiveContext activeContext) {
373    return "SlaveTask-" + slaveIds.getAndIncrement();
374  }
375
376  private boolean masterTaskSubmitted() {
377    return !masterSubmitted.compareAndSet(false, true);
378  }
379}