001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package org.apache.reef.examples.group.bgd.parameters; 020 021import org.apache.reef.examples.group.bgd.loss.LogisticLossFunction; 022import org.apache.reef.examples.group.bgd.loss.LossFunction; 023import org.apache.reef.examples.group.bgd.loss.SquaredErrorLossFunction; 024import org.apache.reef.examples.group.bgd.loss.WeightedLogisticLossFunction; 025import org.apache.reef.tang.annotations.Parameter; 026 027import javax.inject.Inject; 028import java.util.HashMap; 029import java.util.Map; 030 031/** 032 * Type of loss function used in example. 033 */ 034public class BGDLossType { 035 036 private static final Map<String, Class<? extends LossFunction>> LOSS_FUNCTIONS = 037 new HashMap<String, Class<? extends LossFunction>>() {{ 038 put("logLoss", LogisticLossFunction.class); 039 put("weightedLogLoss", WeightedLogisticLossFunction.class); 040 put("squaredError", SquaredErrorLossFunction.class); 041 }}; 042 043 private final Class<? extends LossFunction> lossFunction; 044 045 private final String lossFunctionStr; 046 047 @Inject 048 public BGDLossType(@Parameter(LossFunctionType.class) final String lossFunctionStr) { 049 this.lossFunctionStr = lossFunctionStr; 050 this.lossFunction = LOSS_FUNCTIONS.get(lossFunctionStr); 051 if (this.lossFunction == null) { 052 throw new RuntimeException("Specified loss function type: " + lossFunctionStr + 053 " is not implemented. Supported types are logLoss|weightedLogLoss|squaredError"); 054 } 055 } 056 057 public Class<? extends LossFunction> getLossFunction() { 058 return this.lossFunction; 059 } 060 061 public String lossFunctionString() { 062 return lossFunctionStr; 063 } 064}