Type-Level Programming in Scala

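The Shapeless framework, which I have been studying recently, requires us to take a fresh look at programming with Scala's type system. Scala macros are another part we need to pay attention to.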

Motivation: heterogeneous collection types (HList, HArray).

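import shapeless._

val l1 = 42 :: "foo" :: Some(1.0) :: "bar" :: HNil  // type: Int :: String :: Some[Double] :: String :: HNil
val i: Int = l1.head                               // statically an Int, no cast needed
val s: String = l1.tail.head                       // statically a String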
  • Unlike Scala tuples
    • no size limit (Scala 2 tuples stop at Tuple22)
    • no restrictions on abstracting over the element types
  • Requirement: No runtime overhead
    • no implicits
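For intuition, here is a minimal hand-rolled HList. It is a simplified sketch in the spirit of Shapeless (assuming Scala 2.12/2.13), not its actual source; the names `HListSketch`, `::` and `HNil` are our own. Each cons cell records the static types of its head and tail, so `head` and `tail.head` come back precisely typed, with no casts and no implicits, while at runtime the structure is just a chain of plain objects.

```scala
object HListSketch {
  sealed trait HList

  // A cons cell that remembers the static type of its head (H) and tail (T).
  final case class ::[+H, +T <: HList](head: H, tail: T) extends HList {
    // Operators ending in ':' are right-associative, so `a :: list` calls list.::(a).
    def ::[A](a: A): A :: H :: T = new ::(a, this)
  }

  // The empty list: a trait for the type, a case object for the single value.
  sealed trait HNil extends HList {
    def ::[A](a: A): A :: HNil = new ::(a, this)
  }
  case object HNil extends HNil

  def main(args: Array[String]): Unit = {
    val l1 = 42 :: "foo" :: Some(1.0) :: "bar" :: HNil
    val i: Int = l1.head         // typed as Int by the compiler
    val s: String = l1.tail.head // typed as String by the compiler
    println(s"$i $s")
  }
}
```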

Overall, we need to address the following points:

  • ADTs (Algebraic Data Types)
  • Different basic values and classes become types
  • Purely-functional design
  • Polymorphic dispatch (on receiver)
    • No match, if…else, etc.
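As a hedged illustration of these points at the value level, booleans can be modeled as an ADT whose operations are dispatched on the receiver: each case object overrides the methods, so no `match` or `if…else` is needed. The names `ValueBool`, `True`, `False`, `&&`, `||` and `not` are illustrative choices, not taken from the original.

```scala
object ValueBool {
  // ADT: exactly two immutable values.
  sealed trait Bool {
    def &&(b: Bool): Bool
    def ||(b: Bool): Bool
    def not: Bool
  }

  // Each operation is resolved by virtual dispatch on the receiver.
  case object True extends Bool {
    def &&(b: Bool): Bool = b     // true && b  == b
    def ||(b: Bool): Bool = True  // true || b  == true
    def not: Bool = False
  }

  case object False extends Bool {
    def &&(b: Bool): Bool = False // false && b == false
    def ||(b: Bool): Bool = b     // false || b == b
    def not: Bool = True
  }

  def main(args: Array[String]): Unit =
    println(True && False.not) // prints True
}
```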

The corresponding rules for translating value-level code into type-level code:

ADT values:        `val`             -> `object`
Members:           `def x` / `val x` -> `type X`
Functions:         `def f(x)`        -> `type F[X]`
Member access:     `a.b`             -> `A#B`
Type annotations:  `x: T`            -> `X <: T`
Construction:      `new A(b)`        -> `A[B]`
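Applied to a simple boolean ADT (`True` and `False` with "and", "or" and "not"), these rules give the following type-level sketch for Scala 2. Each value becomes an `object`, each method becomes an abstract type member refined in every case, and calling a method becomes a type projection `A#B`. The names `TypeBool`, `And`, `Or` and `Not` are hypothetical (alphanumeric names avoid lexing pitfalls such as `#&&`). Each `implicitly[X =:= Y]` only compiles if the two types are equal, so these lines act as assertions checked by the compiler.

```scala
object TypeBool {
  // Value-level `def and(b: Bool): Bool` becomes `type And[B <: Bool] <: Bool`.
  sealed trait Bool {
    type And[B <: Bool] <: Bool
    type Or[B <: Bool] <: Bool
    type Not <: Bool
  }

  // Values become objects; the singleton type True.type plays the role of the value.
  object True extends Bool {
    type And[B <: Bool] = B
    type Or[B <: Bool] = True.type
    type Not = False.type
  }

  object False extends Bool {
    type And[B <: Bool] = False.type
    type Or[B <: Bool] = B
    type Not = True.type
  }

  def main(args: Array[String]): Unit = {
    // `a.and(b)` at the value level becomes `A#And[B]` at the type level.
    implicitly[True.type#And[False.type] =:= False.type]
    implicitly[False.type#Or[True.type] =:= True.type]
    implicitly[True.type#Not =:= False.type]
    println("all type-level checks compiled")
  }
}
```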

Scala macro precompute

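Scala macros work in a somewhat unusual way: a macro implementation is ordinary Scala code that the compiler loads and runs while it compiles each call site. The implementation receives the syntax trees of the call's arguments and returns a new tree that replaces the call. Before getting to type-level programming, it is worth looking at how a Scala macro is implemented.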

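import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.language.experimental.macros
import scala.reflect.macros.blackbox

// Must be compiled in a separate module from the code that calls printf:
// the compiler runs printf_impl while compiling each call site.
object PrintfMacro {

  // Macro definitions need an explicit result type.
  def printf(format: String, params: Any*): Unit = macro printf_impl

  def printf_impl(c: blackbox.Context)(format: c.Expr[String], params: c.Expr[Any]*): c.Expr[Unit] = {
    import c.universe._

    // The format must be a string literal; its value is extracted at compile time.
    val q"${s_format: String}" = format.tree

    // Generated `val` definitions, one per consumed argument
    val evals = ListBuffer[Tree]()

    // Bind an argument to a fresh val of the expected type,
    // so a mismatched argument becomes a compile error at the call site.
    def precompute(value: Tree, tpe: Type): Tree = {
      val freshName = TermName(c.freshName("eval$"))
      val valdef = q"val $freshName: $tpe = $value"
      evals += valdef
      q"$freshName"
    }

    // The call's arguments, consumed left to right
    val paramsStack = mutable.Stack[Tree](params.map(_.tree): _*)

    // Split around %d, %s and %%, keeping each specifier as its own part
    val refs = s_format.split("(?<=%[\\w%])|(?=%[\\w%])") map {
      case "%d" => precompute(paramsStack.pop(), typeOf[Int])
      case "%s" => precompute(paramsStack.pop(), typeOf[String])
      case "%%" => Literal(Constant("%"))
      case part => Literal(Constant(part))
    }

    // Emit the vals first, then one fully qualified print call per part
    val stats: ListBuffer[Tree] = evals ++ refs.map(ref => q"_root_.scala.Predef.print($ref)")

    c.Expr[Unit](q"..$stats")
  }
}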
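// Compiled in a module that depends on the one defining PrintfMacro.
// The explicit import shadows Predef.printf, so the call below uses the macro.
import PrintfMacro.printf

object PrintfTest {
  def main(args: Array[String]): Unit = {
    printf("simple test age = %d name = %s Hello", 41, "wangzaixiang")
  }
}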
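To make the expansion concrete, the following hand-written program is roughly what the macro generates for the call in `PrintfTest`. It is a sketch only: the real val names come from `c.freshName` and differ from the illustrative `eval1`/`eval2`. Each argument is first bound to a `val` of the type its specifier demands (so passing `"41"` for `%d` would fail to compile), and the format string has already been split at compile time, leaving no runtime parsing.

```scala
object PrintfExpansion {
  def main(args: Array[String]): Unit = {
    // precompute: each argument evaluated once into a typed val
    val eval1: Int = 41
    val eval2: String = "wangzaixiang"
    // one print per literal fragment or argument reference
    print("simple test age = ")
    print(eval1)
    print(" name = ")
    print(eval2)
    print(" Hello")
  }
}
```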

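In practice, macros are often combined with implicits, for example through an implicit class that adds macro methods to an existing type such as a logger: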

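import org.slf4j.Logger
import scala.language.experimental.macros
import scala.reflect.macros.blackbox.Context

object Macros {
  // Adds DEBUG methods to any slf4j Logger; each call is expanded by LogMacros.
  implicit class LoggerEx(val logger: Logger) {
    def DEBUG(msg: String): Unit = macro LogMacros.DEBUG1
    def DEBUG(msg: String, exception: Exception): Unit = macro LogMacros.DEBUG2
  }

  object LogMacros {
    def DEBUG1(c: Context)(msg: c.Tree): c.Tree = {
      import c.universe._
      val pre = c.prefix.tree                  // the LoggerEx(...) wrapper at the call site
      val x = TermName(c.freshName("logger"))  // fresh name, so it cannot capture an `x` used in msg
      q"""
        val $x = $pre.logger
        println("=" * 60)
        if ($x.isDebugEnabled) $x.debug($msg)
        println("=" * 60)
      """
    }

    def DEBUG2(c: Context)(msg: c.Tree, exception: c.Tree): c.Tree = {
      import c.universe._
      val pre = c.prefix.tree
      val x = TermName(c.freshName("logger"))
      q"""
        val $x = $pre.logger
        if ($x.isDebugEnabled) $x.debug($msg, $exception)
      """
    }
  }
}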

Adding a test (compiled separately from `Macros`):

import Macros._ // brings the LoggerEx implicit class into scope
import org.slf4j.LoggerFactory

object DebugTest {
  def main(args: Array[String]): Unit = {
    val log = LoggerFactory.getLogger(this.getClass)
    log.DEBUG("Hello World")
  }
}
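A key benefit of expanding `DEBUG` as a macro is that the message tree is pasted inside the `if`, so the message is only built when debug logging is enabled; an ordinary method would evaluate its argument before the call. The sketch below hand-writes the equivalent of the `if` part of that expansion. It uses a hypothetical `SimpleLogger` trait in place of `org.slf4j.Logger` and a hypothetical `expensiveMessage()` so it runs without dependencies.

```scala
// Hypothetical stand-in for org.slf4j.Logger.
trait SimpleLogger {
  def isDebugEnabled: Boolean
  def debug(msg: String): Unit
}

object DebugExpansion {
  var evaluations = 0 // how many times the message was actually built

  // Hypothetical expensive message that counts its own evaluations.
  def expensiveMessage(): String = {
    evaluations += 1
    "state = " + (1 to 3).mkString(",")
  }

  // Hand-written equivalent of the expansion of log.DEBUG(expensiveMessage()):
  // the argument expression sits inside the `if`.
  def expandedCall(log: SimpleLogger): Unit = {
    val logger = log
    if (logger.isDebugEnabled) logger.debug(expensiveMessage())
  }

  def main(args: Array[String]): Unit = {
    val quiet = new SimpleLogger {
      def isDebugEnabled = false
      def debug(msg: String): Unit = println(msg)
    }
    expandedCall(quiet)
    println(evaluations) // prints 0: the message was never built
  }
}
```

A by-name parameter (`msg: => String`) would also delay evaluation, but typically at the cost of allocating a closure on every call; the macro avoids that by inlining the condition at each call site.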

Thinking Recursively: Addition