001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package org.apache.reef.examples.group.bgd; 020 021import org.apache.reef.examples.group.utils.math.DenseVector; 022import org.apache.reef.examples.group.utils.math.Vector; 023import org.apache.reef.io.network.group.api.operators.Reduce.ReduceFunction; 024import org.apache.reef.io.network.util.Pair; 025 026import javax.inject.Inject; 027 028/** 029 * Loss and gradient reduce function. 030 */ 031public class LossAndGradientReduceFunction 032 implements ReduceFunction<Pair<Pair<Double, Integer>, Vector>> { 033 034 @Inject 035 public LossAndGradientReduceFunction() { 036 } 037 038 @Override 039 public Pair<Pair<Double, Integer>, Vector> apply( 040 final Iterable<Pair<Pair<Double, Integer>, Vector>> lags) { 041 042 double lossSum = 0.0; 043 int numEx = 0; 044 Vector combinedGradient = null; 045 046 for (final Pair<Pair<Double, Integer>, Vector> lag : lags) { 047 if (combinedGradient == null) { 048 combinedGradient = new DenseVector(lag.getSecond()); 049 } else { 050 combinedGradient.add(lag.getSecond()); 051 } 052 lossSum += lag.getFirst().getFirst(); 053 numEx += lag.getFirst().getSecond(); 054 } 055 056 return new Pair<>(new Pair<>(lossSum, numEx), combinedGradient); 057 } 058}