#lang racket

(require racket/control)
(require (for-syntax racket/control racket/match))

;; Eliminating complex values using delimited control
;; celim-val : Complex-Value ->(Value |- Term) Value
;;
;; Translates the complex value `v` into a simple value.  Value-level
;; eliminators (`split`, `case`, `let`) cannot stay inside a value, so each
;; one uses `control` to capture the context up to the nearest enclosing
;; prompt and rebuilds that context *inside* the corresponding term-level
;; eliminator.  Every branch that brings new variables into scope re-enters
;; the captured context under a fresh `%`.
(define (celim-val v)
  (match v
    ;; Variables are already simple values.
    [(? symbol? x) x]
    [(list 'cons hd tl) (list 'cons (celim-val hd) (celim-val tl))]
    ;; Pair elimination: hoist the split out to the enclosing prompt.
    [(list 'split discrim (list 'cons x1 x2) kv)
     (control k
       (list 'split (celim-val discrim)
             (list 'cons x1 x2)
             (% (k (celim-val kv)))))]
    [(list 'nil) (list 'nil)]
    ;; Unit elimination.
    [(list 'split discrim (list 'nil) kv)
     (control k
       (list 'split (celim-val discrim)
             (list 'nil)
             (% (k (celim-val kv)))))]
    [(list (? injection? in) v) (list in (celim-val v))]
    ;; Sum elimination: the captured context is duplicated into both branches.
    [`(case ,discrim
        [(inl ,xl) ,kvl]
        [(inr ,xr) ,kvr])
     (control k
       `(case ,(celim-val discrim)
          [(inl ,xl) ,(% (k (celim-val kvl)))]
          [(inr ,xr) ,(% (k (celim-val kvr)))]))]
    ;; Empty case: the surrounding context is discarded by aborting to the
    ;; nearest prompt with the case term itself.
    [`(case ,discrim) (abort `(case ,(celim-val discrim)))]
    ;; Fix: the original called (celim-term) with no argument, which is an
    ;; arity error; the thunk's body must be passed along.
    [(list 'thunk t) (list 'thunk (celim-term t))]
    ;; Value-level let becomes a term-level let binding a `ret`.
    [(list 'let (? symbol? x) v1 v2)
     (control k
       (list 'let
             x
             (list 'ret (celim-val v1))
             (% (k (celim-val v2)))))]))

;; celim-term : Complex-Term ->(Term |- Term) Term
;;
;; Walks a term, translating embedded values with celim-val and recursing
;; into sub-terms.  Every sub-term that sits under a binder (let body, split
;; body, case branch, λ body) is translated under its own `%`.
(define (celim-term t)
  (match t
    [`(ret ,val)
     `(ret ,(celim-val val))]
    [`(let ,x ,bound-tm ,body)
     `(let ,x
        ,(celim-term bound-tm)
        ,(% (celim-term body)))]
    [`(split ,discrim (cons ,x1 ,x2) ,body)
     `(split ,(celim-val discrim)
             (cons ,x1 ,x2)
             ,(% (celim-term body)))]
    [`(split ,discrim (nil) ,body)
     `(split ,(celim-val discrim)
             (nil)
             ,(% (celim-term body)))]
    [(list 'case discrim
           (list (list 'inl xl) left-tm)
           (list (list 'inr xr) right-tm))
     (list 'case (celim-val discrim)
           (list (list 'inl xl) (% (celim-term left-tm)))
           (list (list 'inr xr) (% (celim-term right-tm))))]
    [(list 'case discrim) (list 'case (celim-val discrim))]
    [(list 'force val) (list 'force (celim-val val))]
    [(list 'λ x body) (list 'λ x (% (celim-term body)))]
    [(list 'app fn arg) (list 'app (celim-term fn) (celim-val arg))]
    [(list 'copat (list 'fst t1) (list 'snd t2))
     (list 'copat
           (list 'fst (celim-term t1))
           (list 'snd (celim-term t2)))]
    [(list (? projection? p) inner) (list p (celim-term inner))]
    [(list 'copat) (list 'copat)]))

;; celim : Complex-Term -> Term
;; Entry point: translates `t` with celim-term under a fresh prompt, so any
;; control effect performed during the translation is delimited here.
(define celim
  (λ (t)
    (% (celim-term t))))

;; injection? : Any -> Boolean
;; True exactly for the symbols 'inl and 'inr.  Uses memq rather than
;; symbol=? so that a non-symbol argument (e.g. a list in head position of a
;; match candidate) yields #f instead of raising a contract error.
(define (injection? x) (and (memq x '(inl inr)) #t))
;; projection? : Any -> Boolean
;; True exactly for the symbols 'fst and 'snd.  Uses memq rather than
;; symbol=? so that a non-symbol argument yields #f instead of raising a
;; contract error.
(define (projection? x) (and (memq x '(fst snd)) #t))

;;
(module+ test
  (require rackunit)

  ;; An empty `case` buried inside a returned value: the expected output has
  ;; the case replacing the rest of the innermost let chain.
  (let ([input    '(let x (ret z) (ret (let y x (let q (case y) q))))]
        [expected '(let x (ret z) (let y (ret x) (case y)))])
    (check-equal? (celim input) expected))

  ;; A two-branch `case` buried inside a returned value: the expected output
  ;; has the surrounding context copied into both branches.
  (let ([input
         '(let x (ret z)
            (ret (let y x
                   (let q (case y [(inl xl) hl] [(inr xr) hr])
                     q))))]
        [expected
         '(let x (ret z)
            (let y (ret x)
              (case y
                ((inl xl) (let q (ret hl) (ret q)))
                ((inr xr) (let q (ret hr) (ret q))))))])
    (check-equal? (celim input) expected)))

;; Problem: this only lifts programs to the most recently bound variable; it could be better (optimization-wise) to lift them as high as possible.
;; For instance, if we have a (case contradiction) then we want to lift that as high as possible, because the whole code is dead!
;; On the other hand, we will get somewhat of a code explosion because of case statements, so in practice we would want a compromise between these two techniques.

;; To solve this we'll work in two passes, first we will annotate
;; every AST node with its free variables

;; Then in the second pass we use a more *cooperative* continuation protocol:

;; First, every prompt is at a binding site, so when we install it, it
;; is associated with the list of variables it binds.

;; Second, when we control, we only get up to the closest prompt, so
;; we need the prompt "handler" to "re-raise" the control if the value
;; that we are lifting would still be well typed.

;; If the value would not be well typed, we insert it *here* and the control is "handled"
;;
;; So when we control, we need to provide enough information, this means we return a list of
;;   Exn : FV -> Kont -> CPST
;; where FV are all of the free variables in the term that CPST will return,
;;   Kont is the captured continuation up to the prompt
;;   CPST is a function that takes a continuation for a term and returns a term (possibly re-raising later)

;; So when we write control, we also provide the free variables of the term
;; control-vars-fun : FV CPST -> (escapes to the nearest prompt)
;; Captures the context up to the nearest prompt and delivers to that prompt
;; a 4-element list: the tag 'raise, `vars`, the captured continuation, and
;; `cps`.
(define (control-vars-fun vars cps)
  (control captured
    `(raise ,vars ,captured ,cps)))
;; (control-vars vars k t)
;; Syntactic sugar over control-vars-fun: `t` is wrapped in a one-argument
;; function whose parameter is `k`, and that function is passed as the cps
;; argument together with `vars`.
(define-syntax control-vars
  (syntax-rules ()
    [(_ vars k t)
     (control-vars-fun vars (λ (k) t))]))

;; So now all of our work is in the "handler"
;;   First, we need to install a prompt for our term and insert an implicit "return"

;; minus : (Listof Any) (Listof Any) -> (Listof Any)
;; List difference: the elements of `xs` (in order, duplicates kept) that are
;; not `equal?` to any element of `ys`.
;; Fix: the original filtered with (member x ys) as the keep-predicate, which
;; computes the intersection rather than the difference.
(define (minus xs ys)
  (remove* ys xs))

;; intersect? : (Listof Any) (Listof Any) -> Boolean
;; #t when at least one element of `xs` is `equal?` to some element of `ys`,
;; #f otherwise (in particular when either list is empty).
;; Fix: the original used andmap, which tests whether *every* element of xs
;; occurs in ys (a subset test) and returns a list tail instead of a boolean.
(define (intersect? xs ys)
  (for/or ([x (in-list xs)])
    (and (member x ys) #t)))

;; %-vars-fun : (Listof Var) (-> Term) -> Term
;; Prompt "handler" for a binding site that binds `binding-vars`.
;; Runs `thnk` under a prompt, tagging a normal return with 'pure so it can be
;; told apart from a 'raise list (the shape control-vars-fun delivers).
;;   - 'pure result: the wrapped term is returned.
;;   - 'raise result: if (intersect? binding-vars t-vars) holds, the raise is
;;     handled here by running (t-wait small-k) under a fresh %-vars-fun with
;;     the same binding-vars; otherwise it is re-raised to the next enclosing
;;     prompt via control-vars-fun, with (minus t-vars binding-vars) as the
;;     variable list and the two continuations composed.
;; NOTE(review): small-k is captured inside the (list 'pure ...) wrapper
;; below, so (small-k x) appears to return a 'pure-tagged list rather than a
;; bare term -- confirm whether callers of the continuation should unwrap it.
(define (%-vars-fun binding-vars thnk)
  (define result (% (list 'pure (thnk))))
  (match result
    [(list 'pure tm) tm]
    [(list 'raise t-vars small-k t-wait)
     (if (intersect? binding-vars t-vars)
         ;; Handle here: rebuild the term at this binding site.
         (%-vars-fun binding-vars (λ () (t-wait small-k)))
         ;; Re-raise: big-k is the context between this prompt and the next
         ;; one out; the continuation handed to t-wait runs small-k first and
         ;; then big-k.
         (control-vars-fun (minus t-vars binding-vars)
                           (λ (big-k)
                             (t-wait (λ (x) (big-k (small-k x)))))))]))

;; (%-vars binding-vars t)
;; Syntactic sugar over %-vars-fun: delays `t` in a zero-argument function
;; and passes it along with `binding-vars`.
(define-syntax %-vars
  (syntax-rules ()
    [(_ binding-vars t)
     (%-vars-fun binding-vars (λ () t))]))

;; get-vars : Annotated-Node -> (Listof Var)
;; Returns the free-variable annotation, i.e. the first component of an
;; annotated node of the form (list vars node).
;; Fix: the original body was the bare identifier `first`, so it returned the
;; procedure `first` itself instead of applying it to `v`.
(define (get-vars v) (first v))
;; celim-val2 : Annotated-Complex-Value -> Value
;; Second-pass counterpart of celim-val, operating on nodes of the form
;; (list free-vars node) whose children are themselves annotated.
;; Eliminators raise through control-vars, passing the free variables of the
;; term being hoisted (read directly from the node's first component), and
;; every binding site re-enters the continuation under %-vars with the
;; variables it binds.
;;
;; Fixes relative to the original:
;;   - recursive calls went to celim-val, which does not understand annotated
;;     nodes; they now go to celim-val2;
;;   - the thunk clause called (celim-term) with no argument; it now
;;     translates the thunk body with celim-term2;
;;   - the pair-split clause used a plain `%`, which would embed a 'raise
;;     list into the output term; it now uses %-vars with the two variables
;;     the split binds.
;; The clauses disabled with #; are left exactly as they were.
;; NOTE(review): the continuation `k` is captured under the 'pure wrapper
;; installed by %-vars-fun / celim2, so (k ...) may return a 'pure-tagged
;; list -- confirm before relying on this pass.
(define (celim-val2 v)
  (match-define (list _ val) v)
  (match val
    [(? symbol? x) x]
    [(list 'cons car cdr) (list 'cons (celim-val2 car) (celim-val2 cdr))]
    [(list 'split discrim (list 'cons x1 x2) kv)
     (control-vars (first discrim) k
       (list 'split (celim-val2 discrim)
             (list 'cons x1 x2)
             (%-vars (list x1 x2) (k (celim-val2 kv)))))]
    [(list 'nil) (list 'nil)]
    #;
    [(list 'split discrim (list 'nil) kv)
     (control k
       (list 'split (celim-val discrim)
             (list 'nil)
             (% (k (celim-val kv)))))]
    
    [(list (? injection? in) v) (list in (celim-val2 v))]
    #;
    [`(case ,discrim
        [(inl ,xl) ,kvl]
        [(inr ,xr) ,kvr])
     (control k
       `(case ,(celim-val discrim)
          [(inl ,xl) ,(% (k (celim-val kvl)))]
          [(inr ,xr) ,(% (k (celim-val kvr)))]))]
    #;
    [`(case ,discrim) (abort `(case ,(celim-val discrim)))]
    [(list 'thunk t) (list 'thunk (celim-term2 t))]
    [(list 'let x v1 v2)
     (control-vars (first v1) k
       (list 'let
             x
             (list 'ret (celim-val2 v1))
             (%-vars (list x) (k (celim-val2 v2)))))]))

;; celim-term2 : Annotated-Complex-Term -> Term
;; Second-pass term translation over nodes of the form (list vars term).
;; Only `ret` and `let` are handled; the free-variable annotation of the node
;; itself is ignored.  The body of a `let` is translated under %-vars with
;; the variable the let binds.
(define (celim-term2 t)
  (match t
    [`(,_ (ret ,v))
     `(ret ,(celim-val2 v))]
    [`(,_ (let ,x ,t1 ,t2))
     `(let ,x
        ,(celim-term2 t1)
        ,(%-vars (list x) (celim-term2 t2)))]))

;; celim2 : Annotated-Complex-Term -> Term
;; Top-level driver for the second pass.  Nothing is bound at the top level,
;; so every 'raise that reaches this prompt is handled here: the raised cps
;; is applied to the captured continuation and the result is run again under
;; the same prompt, until a 'pure result comes back.
;;
;; Fix: the original wrapped only the *initial* thunk in (list 'pure ...);
;; after the first raise, (cps k) returned a bare term that matched neither
;; clause.  The 'pure tag is now applied inside the loop, so every iteration
;; is tagged the same way.
(define (celim2 t)
  (let loop ([thk (λ () (celim-term2 t))])
    (match (% (list 'pure (thk)))
      [(list 'pure tm) tm]
      [(list 'raise _ k cps)
       (loop (λ () (cps k)))])))

;; Smart constructors
;; bound : (Listof Var) Annotated-Node -> (Listof Var)
;; Free variables of the annotated node `t` (its first component) with every
;; variable in `vars` removed, i.e. the free variables of `t` once `vars`
;; are bound around it.
;; Fix: the original returned (list remaining-vars t), but the result is
;; spliced with `append` into a variable list (see slet), which would put the
;; whole node `t` into the list of variables.
(define (bound vars t)
  (remove* vars (first t)))
;; appl-app : Symbol Annotated-Node ... -> Annotated-Node
;; Generic constructor for forms that bind nothing: the free variables are
;; the concatenation of the arguments' free-variable lists (duplicates are
;; not removed), and the node is `ctor` followed by the arguments.
(define (appl-app ctor . args)
  (let ([free-vars (apply append (map first args))]
        [node (cons ctor args)])
    (list free-vars node)))
;; var : Symbol -> Annotated-Node
;; A variable occurrence: its only free variable is itself.
(define (var x)
  (list (list x) x))
;; scons : Annotated-Node Annotated-Node -> Annotated-Node
;; Annotated pair constructor, built through appl-app with the tag 'cons.
(define scons
  (λ (car cdr)
    (appl-app 'cons car cdr)))
;; snil : -> Annotated-Node
;; Annotated unit value: no free variables, node (nil).
(define (snil)
  (list '() (list 'nil)))
;; sinl : Annotated-Node -> Annotated-Node
;; Annotated left injection.
;; Fix: the original tagged the node with 'sinl, which injection? (and hence
;; the translators' injection clause) does not recognise; the term grammar
;; used throughout this file spells the tag 'inl.
(define (sinl v) (appl-app 'inl v))
;; sthunk : Annotated-Node -> Annotated-Node
;; Annotated thunk of a term, built through appl-app with the tag 'thunk.
(define sthunk
  (λ (t)
    (appl-app 'thunk t)))
;; sapp : Annotated-Node Annotated-Node -> Annotated-Node
;; Annotated application of term `t` to value `v`, tagged 'app.
(define sapp
  (λ (t v)
    (appl-app 'app t v)))
;; sforce : Annotated-Node -> Annotated-Node
;; Annotated force of a value, tagged 'force.
(define sforce
  (λ (v)
    (appl-app 'force v)))
;; slet : Var Annotated-Node Annotated-Node -> Annotated-Node
;; Annotated let: the free variables are those of `t1` followed by those of
;; `t2` with `x` removed (x is bound in t2 only).
;; Fix: the original appended the result of (bound (list x) t2), which is a
;; two-element list (vars node) rather than a variable list, so the node t2
;; itself ended up among the free variables.  The removal is done inline so
;; this definition does not depend on `bound`'s return shape.
(define (slet x t1 t2)
  (define vars (append (first t1)
                       (remove* (list x) (first t2))))
  (list vars
        `(let ,x ,t1 ,t2)))
;; sabort : Annotated-Node -> Annotated-Node
;; Annotated empty case on `v`: a 'case node with a discriminee and no
;; branches.
(define sabort
  (λ (v)
    (appl-app 'case v)))
;; scase : Annotated-Node Var Annotated-Node Var Annotated-Node -> Annotated-Node
;; Annotated two-branch case: `x1` is bound in the inl branch `t1`, `x2` in
;; the inr branch `t2`.  Free variables are those of the discriminee followed
;; by each branch's free variables minus the variable that branch binds.
;; Fixes:
;;   - the original called (bound (list x1) (first t1)), handing `bound` a
;;     variable list where it expects an annotated node, which fails as soon
;;     as it is run; the removal is now done inline with remove*;
;;   - branches were emitted as [x t], whereas the translators in this file
;;     match [(inl x) t] / [(inr x) t]; the branch patterns now carry the
;;     injection tags.
(define (scase discrim x1 t1 x2 t2)
  (define vars
    (append (first discrim)
            (remove* (list x1) (first t1))
            (remove* (list x2) (first t2))))
  (list vars
        `(case ,discrim [(inl ,x1) ,t1] [(inr ,x2) ,t2])))
;; ssplit2 : Annotated-Node Var Var Annotated-Node -> Annotated-Node
;; Annotated pair split: `x1` and `x2` are bound in the body `t`.  Free
;; variables are those of the discriminee followed by the body's free
;; variables minus x1 and x2.
;; Fix: the original called (bound (list x1 x2) (first t)), handing `bound`
;; a variable list where it expects an annotated node, which fails as soon as
;; it is run; the removal is now done inline with remove*.
(define (ssplit2 discrim x1 x2 t)
  (define vars
    (append (first discrim)
            (remove* (list x1 x2) (first t))))
  (list vars `(split ,discrim (cons ,x1 ,x2) ,t)))
;; ssplit0 : Annotated-Node Annotated-Node -> Annotated-Node
;; Annotated unit split: nothing is bound, so the free variables are those of
;; the discriminee followed by those of the body.
(define (ssplit0 discrim t)
  (list (append (first discrim) (first t))
        (list 'split discrim (list 'nil) t)))