This project has retired. For details please refer to its Attic page.
Source code
001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.examples.group.bgd;
020
021import org.apache.reef.examples.group.utils.math.DenseVector;
022import org.apache.reef.examples.group.utils.math.Vector;
023import org.apache.reef.io.network.group.api.operators.Reduce.ReduceFunction;
024import org.apache.reef.io.network.util.Pair;
025
026import javax.inject.Inject;
027
028/**
029 * Loss and gradient reduce function.
030 */
031public class LossAndGradientReduceFunction
032    implements ReduceFunction<Pair<Pair<Double, Integer>, Vector>> {
033
034  @Inject
035  public LossAndGradientReduceFunction() {
036  }
037
038  @Override
039  public Pair<Pair<Double, Integer>, Vector> apply(
040      final Iterable<Pair<Pair<Double, Integer>, Vector>> lags) {
041
042    double lossSum = 0.0;
043    int numEx = 0;
044    Vector combinedGradient = null;
045
046    for (final Pair<Pair<Double, Integer>, Vector> lag : lags) {
047      if (combinedGradient == null) {
048        combinedGradient = new DenseVector(lag.getSecond());
049      } else {
050        combinedGradient.add(lag.getSecond());
051      }
052      lossSum += lag.getFirst().getFirst();
053      numEx += lag.getFirst().getSecond();
054    }
055
056    return new Pair<>(new Pair<>(lossSum, numEx), combinedGradient);
057  }
058}