def multiply(factor: Int)(numberToBeMultiplied: Int): Int = factor * numberToBeMultiplied
val multiplyBy3 = multiply(3)_ // resulting function signature Int => Int
val multiplyBy10 = multiply(10)_ // resulting function signature Int => Int
val sixFromCurriedCall = multiplyBy3(2) //6
val sixFromFullCall = multiply(3)(2) //6
val fortyFromCurriedCall = multiplyBy10(4) //40
val fortyFromFullCall = multiply(10)(4) //40
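Because multiply has two parameter lists, supplying only the first one is enough wherever a function type is expected: the compiler eta-expands the partial call, so the trailing `_` can be dropped. This is a minimal sketch that assumes Scala 2.13 or Scala 3. The object name and the input list are illustrative only.

```scala
object MultiplyEtaExpansion {
  def multiply(factor: Int)(numberToBeMultiplied: Int): Int = factor * numberToBeMultiplied

  // The explicit function type lets the compiler eta-expand multiply(3) without a trailing `_`
  val multiplyBy3: Int => Int = multiply(3)

  // map expects an Int => B, so multiply(10) is eta-expanded in place
  val tens: List[Int] = List(1, 2, 3).map(multiply(10))

  def main(args: Array[String]): Unit = {
    println(multiplyBy3(2)) // 6
    println(tens)           // List(10, 20, 30)
  }
}
```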
def numberOrCharacterSwitch(toggleNumber: Boolean)(number: Int)(character: Char): String =
  if (toggleNumber) number.toString else character.toString
// the type of the placeholder parameter (the one left open) must be given explicitly
// resulting function signature Boolean => String
val switchBetween3AndE = numberOrCharacterSwitch(_: Boolean)(3)('E')
switchBetween3AndE(true) // "3"
switchBetween3AndE(false) // "E"
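The placeholder form is shorthand for an anonymous function over the missing argument. This sketch, assuming Scala 2.13 or Scala 3, puts the placeholder version next to the equivalent explicit lambda. The object and the second val name are illustrative.

```scala
object PlaceholderExpansion {
  def numberOrCharacterSwitch(toggleNumber: Boolean)(number: Int)(character: Char): String =
    if (toggleNumber) number.toString else character.toString

  // Placeholder form, as in the text
  val switchBetween3AndE: Boolean => String = numberOrCharacterSwitch(_: Boolean)(3)('E')

  // What the placeholder expands to: a lambda taking the one missing argument
  val switchBetween3AndEExplicit: Boolean => String =
    (toggle: Boolean) => numberOrCharacterSwitch(toggle)(3)('E')

  def main(args: Array[String]): Unit = {
    println(switchBetween3AndE(true))          // 3
    println(switchBetween3AndEExplicit(false)) // E
  }
}
```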
def minus(left: Int, right: Int) = left - right
val numberMinus5 = minus(_: Int, 5)
val fiveMinusNumber = minus(5, _: Int)
numberMinus5(7) // 2
fiveMinusNumber(7) // -2
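When the partially applied function is passed straight to a method that expects a function, such as map, the placeholder's type can be inferred from that expected type, so the `: Int` annotation is no longer needed. A small sketch, assuming Scala 2.13 or Scala 3. The object name and input values are illustrative.

```scala
object PlaceholderInline {
  def minus(left: Int, right: Int): Int = left - right

  // map expects Int => B, so `_` is inferred to be an Int
  val numbersMinus5: List[Int] = List(7, 10).map(minus(_, 5))
  val fiveMinusNumbers: List[Int] = List(7, 10).map(minus(5, _))

  def main(args: Array[String]): Unit = {
    println(numbersMinus5)    // List(2, 5)
    println(fiveMinusNumbers) // List(-2, -5)
  }
}
```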
Let's define a function of 2 arguments:
def add: (Int, Int) => Int = (x, y) => x + y
val three = add(1, 2)
Currying add transforms it into a function that takes one Int and returns a function (from one Int to an Int):
val addCurried: (Int) => (Int => Int) = add.curried
// (Int) means it takes *one* Int; (Int => Int) is the *function* from Int to Int it returns
val add1: Int => Int = addCurried(1)
val three: Int = add1(2)
val allInOneGo: Int = addCurried(1)(2)
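`.curried` is a method available on every two-argument function value. The same curried shape can also be written by hand as nested lambdas. A sketch, assuming Scala 2.13 or Scala 3, with illustrative names.

```scala
object CurriedByHand {
  val add: (Int, Int) => Int = (x, y) => x + y

  // Standard library: Function2 provides `curried`
  val addCurried: Int => Int => Int = add.curried

  // The same shape by hand: take x, return a function that waits for y
  val addCurriedByHand: Int => Int => Int = x => y => add(x, y)

  val add1: Int => Int = addCurriedByHand(1)

  def main(args: Array[String]): Unit = {
    println(addCurried(1)(2))       // 3
    println(addCurriedByHand(1)(2)) // 3
    println(add1(41))               // 42
  }
}
```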
You can apply this concept to any function that takes multiple arguments. Currying such a function transforms it into a series of applications of functions that each take one argument:
def add3: (Int, Int, Int) => Int = (a, b, c) => a + b + c
def add3Curr: Int => (Int => (Int => Int)) = add3.curried
val x = add3Curr(1)(2)(42)
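With three arguments, each intermediate application still returns a function, and Function.uncurried has an overload that turns a three-level curried function back into a three-argument one. A sketch, assuming Scala 2.13 or Scala 3. The object and val names are illustrative.

```scala
object ThreeArgCurrying {
  val add3: (Int, Int, Int) => Int = (a, b, c) => a + b + c
  val add3Curr: Int => Int => Int => Int = add3.curried

  // Fixing the first two arguments leaves a function of the third
  val plus3: Int => Int = add3Curr(1)(2)

  // Function.uncurried converts T1 => T2 => T3 => R back to (T1, T2, T3) => R
  val add3Again: (Int, Int, Int) => Int = Function.uncurried(add3Curr)

  def main(args: Array[String]): Unit = {
    println(plus3(42))          // 45
    println(add3Again(1, 2, 3)) // 6
  }
}
```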
Currying, according to Wikipedia, is the technique of translating the evaluation of a function that takes multiple arguments into evaluating a sequence of functions, each with a single argument.
Concretely, in terms of Scala types, for a function that takes two arguments (that is, has arity 2), it is the conversion of
val f: (A, B) => C // a function that takes two arguments of type `A` and `B` respectively
// and returns a value of type `C`
to
val curriedF: A => B => C // a function that takes an argument of type `A`
// and returns *a function*
// that takes an argument of type `B` and returns a `C`
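The function arrow associates to the right, so `A => B => C` is the same type as `A => (B => C)`. The sketch below instantiates A = String, B = Int and C = String. It assumes Scala 2.13 or Scala 3, and the repeat-a-string function is an arbitrary illustration.

```scala
object ArrowAssociativity {
  // A concrete (A, B) => C: repeat a string n times
  val f: (String, Int) => String = (s, n) => s * n

  // A => B => C ...
  val curriedF: String => Int => String = f.curried
  // ... is exactly A => (B => C), so this assignment type-checks
  val sameType: String => (Int => String) = curriedF

  // Supplying the A yields a B => C; supplying the B yields the C
  val repeatAb: Int => String = curriedF("ab")

  def main(args: Array[String]): Unit = {
    println(f("ab", 3))   // ababab
    println(repeatAb(3))  // ababab
  }
}
```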
So for arity-2 functions we can write the curry function as:
def curry[A, B, C](f: (A, B) => C): A => B => C = {
  (a: A) => (b: B) => f(a, b)
}
Usage:
val f: (String, Int) => Double = {(_, _) => 1.0}
val curriedF: String => Int => Double = curry(f)
f("a", 1) // => 1.0
curriedF("a")(1) // => 1.0
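The inverse direction can be written in the same style. In this sketch, `uncurry` is a hypothetical helper (not from the text) that mirrors `curry`, and a non-commutative function (integer division) is used so the round trip visibly keeps the argument order. It assumes Scala 2.13 or Scala 3.

```scala
object CurryRoundTrip {
  def curry[A, B, C](f: (A, B) => C): A => B => C =
    (a: A) => (b: B) => f(a, b)

  // Hypothetical inverse of curry: apply the curried function one argument at a time
  def uncurry[A, B, C](f: A => B => C): (A, B) => C =
    (a: A, b: B) => f(a)(b)

  val divide: (Int, Int) => Int = (x, y) => x / y
  val curriedDivide: Int => Int => Int = curry(divide)
  val divideAgain: (Int, Int) => Int = uncurry(curriedDivide)

  def main(args: Array[String]): Unit = {
    println(curriedDivide(10)(2)) // 5
    println(divideAgain(10, 2))   // 5
  }
}
```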
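Scala gives us a few language features that help with this:
It is possible to write curried functions as methods with multiple parameter lists, so curriedF can be written as:
def curriedFAsAMethod(str: String)(int: Int): Double = 1.0
val curriedF = curriedFAsAMethod _
It is also possible to un-curry a function (that is, go from A => B => C to (A, B) => C) using a standard library method, Function.uncurried: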
val f: (String, Int) => Double = Function.uncurried(curriedF)
f("a", 1) // => 1.0
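The two features combine: a method with two parameter lists eta-expands into a curried function value, Function.uncurried flattens it, and `.curried` restores the curried form. This is a sketch assuming Scala 2.13 or Scala 3. The method body here (string length times the Int) is an illustrative replacement for the constant 1.0, chosen so the results are distinguishable.

```scala
object MethodToUncurried {
  // Illustrative body, not the constant 1.0 used in the text
  def curriedFAsAMethod(str: String)(int: Int): Double = str.length.toDouble * int

  // With a function type expected, both parameter lists are eta-expanded
  val curriedF: String => Int => Double = curriedFAsAMethod

  // Standard library: String => Int => Double becomes (String, Int) => Double
  val f: (String, Int) => Double = Function.uncurried(curriedF)

  // And back again with Function2.curried
  val curriedAgain: String => Int => Double = f.curried

  def main(args: Array[String]): Unit = {
    println(f("abc", 2))            // 6.0
    println(curriedAgain("abc")(2)) // 6.0
  }
}
```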
Currying is the technique of translating the evaluation of a function that takes multiple arguments into evaluating a sequence of functions, each with a single argument.
It is useful, for example, when the arguments of a function become available at different times or in different places:
Example 1
Let's assume that the total yearly income is a function composed of the income and a bonus:
val totalYearlyIncome:(Int,Int) => Int = (income, bonus) => income + bonus
The curried version of the above arity-2 function is:
val totalYearlyIncomeCurried: Int => Int => Int = totalYearlyIncome.curried
Note that the type in the above definition can also be read, or written, as:
Int => (Int => Int)
Let's assume that the yearly income portion is known in advance:
val partialTotalYearlyIncome: Int => Int = totalYearlyIncomeCurried(10000)
And at some point down the line the bonus is known:
partialTotalYearlyIncome(100)
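Put together, the partially applied function can be stored and reused for several candidate bonuses. A sketch assuming Scala 2.13 or Scala 3. The list of candidate bonuses is illustrative.

```scala
object YearlyIncome {
  val totalYearlyIncome: (Int, Int) => Int = (income, bonus) => income + bonus
  val totalYearlyIncomeCurried: Int => Int => Int = totalYearlyIncome.curried

  // The income is fixed once, in advance
  val partialTotalYearlyIncome: Int => Int = totalYearlyIncomeCurried(10000)

  // Illustrative: the same partial function applied to several possible bonuses
  val candidates: List[Int] = List(0, 100, 2500).map(partialTotalYearlyIncome)

  def main(args: Array[String]): Unit = {
    println(partialTotalYearlyIncome(100)) // 10100
    println(candidates)                    // List(10000, 10100, 12500)
  }
}
```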
Example 2
Let's assume that manufacturing a car involves applying the wheels and the car body:
val carManufacturing:(String,String) => String = (wheels, body) => wheels + body
These parts are applied by different factories:
class CarWheelsFactory {
  def applyCarWheels(carManufacturing: (String, String) => String): String => String =
    carManufacturing.curried("applied wheels..")
}
class CarBodyFactory {
  def applyCarBody(partialCarWithWheels: String => String): String =
    partialCarWithWheels("applied car body..")
}
Notice that the CarWheelsFactory above curries the car manufacturing function and applies only the wheels.
The car manufacturing process then takes the following form:
val carWheelsFactory = new CarWheelsFactory()
val carBodyFactory = new CarBodyFactory()
val carManufacturing:(String,String) => String = (wheels, body) => wheels + body
val partialCarWheelsApplied: String => String = carWheelsFactory.applyCarWheels(carManufacturing)
val carCompleted = carBodyFactory.applyCarBody(partialCarWheelsApplied)
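Assembled into one object, the process produces a single string with the wheels part first, because the wheels factory supplies the first argument. A sketch assuming Scala 2.13 or Scala 3. The wrapping object name is illustrative.

```scala
object CarManufacturingDemo {
  class CarWheelsFactory {
    // Curries the two-argument process and supplies only the first (wheels) argument
    def applyCarWheels(carManufacturing: (String, String) => String): String => String =
      carManufacturing.curried("applied wheels..")
  }

  class CarBodyFactory {
    // Receives the half-finished process and supplies the remaining (body) argument
    def applyCarBody(partialCarWithWheels: String => String): String =
      partialCarWithWheels("applied car body..")
  }

  val carManufacturing: (String, String) => String = (wheels, body) => wheels + body
  val partialCarWheelsApplied: String => String =
    new CarWheelsFactory().applyCarWheels(carManufacturing)
  val carCompleted: String = new CarBodyFactory().applyCarBody(partialCarWheelsApplied)

  def main(args: Array[String]): Unit =
    println(carCompleted) // applied wheels..applied car body..
}
```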
We have a list of credit cards and we'd like to calculate the premiums that the credit card company has to pay out for all of them. The premiums themselves depend on the total number of credit cards, so that the company can adjust them accordingly.
We already have a function that calculates the premium for a single credit card and takes into account the total cards the company has issued:
case class CreditCard(creditInfo: CreditCardInfo, issuer: Person, account: Account)
object CreditCard {
  def getPremium(totalCards: Int, creditCard: CreditCard): Double = { ... }
}
Now a reasonable approach to this problem would be to map each credit card to a premium and reduce it to a sum. Something like this:
val creditCards: List[CreditCard] = getCreditCards()
val allPremiums = creditCards.map(CreditCard.getPremium).sum //type mismatch; found : (Int, CreditCard) ⇒ Double required: CreditCard ⇒ ?
However, the compiler isn't going to like this, because CreditCard.getPremium requires two parameters.
Partial application to the rescue! We can partially apply the total number of credit cards and use that function to map the credit cards to their premiums.
All we need to do is curry the getPremium function by changing it to use multiple parameter lists, and we're good to go.
The result should look something like this:
object CreditCard {
  def getPremium(totalCards: Int)(creditCard: CreditCard): Double = { ... }
}
val creditCards: List[CreditCard] = getCreditCards()
val getPremiumWithTotal = CreditCard.getPremium(creditCards.length)_
val allPremiums = creditCards.map(getPremiumWithTotal).sum
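The sketch below makes the example runnable under loud assumptions: CreditCard is reduced to a hypothetical single `limit` field, and the premium formula (limit divided by the total number of cards) is invented purely so the code executes; the real getPremium body is not given in the text. It also shows that, with the original single-parameter-list signature, a placeholder achieves the same partial application. It assumes Scala 2.13 or Scala 3.

```scala
object CreditCardPremiums {
  // Hypothetical stand-in: the real CreditCard has creditInfo, issuer and account fields
  final case class CreditCard(limit: Double)

  object CreditCard {
    // Hypothetical premium formula, invented only so the example runs
    def getPremium(totalCards: Int)(creditCard: CreditCard): Double =
      creditCard.limit / totalCards
  }

  val creditCards: List[CreditCard] = List(CreditCard(1000.0), CreditCard(3000.0))

  // Fixing totalCards leaves a CreditCard => Double, which map accepts
  val getPremiumWithTotal: CreditCard => Double = CreditCard.getPremium(creditCards.length)
  val allPremiums: Double = creditCards.map(getPremiumWithTotal).sum

  // Alternative without changing the signature: a placeholder for the second argument
  def getPremiumTwoArgs(totalCards: Int, creditCard: CreditCard): Double =
    creditCard.limit / totalCards
  val allPremiumsViaPlaceholder: Double =
    creditCards.map(getPremiumTwoArgs(creditCards.length, _)).sum

  def main(args: Array[String]): Unit = {
    println(allPremiums)               // 2000.0
    println(allPremiumsViaPlaceholder) // 2000.0
  }
}
```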