
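// o1rc0: order-1 context-mixing compressor driving a binary range coder.
// Build: g++ -O3 o1rc0.cpp
// Usage: <prog> c <input> <output>   (compress)
//        <prog> d <input> <output>   (decompress; any mode other than 'c')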

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#pragma pack(1)

// Short unsigned integer names used throughout the file. On the usual
// x86 / x86-64 targets these are 1, 2, 4 and 8 bytes wide.
typedef unsigned char      byte;
typedef unsigned short     word;
typedef unsigned int       uint;
typedef unsigned long long qword;

// Compiler-portability helpers.
//   INLINE / NOINLINE : force / forbid inlining. The coroutine code relies on
//                       NOINLINE functions getting their own stack frames.
//   ALIGN(n)          : alignment attribute.
//   if_e0 / if_e1     : `if` with a branch hint (condition expected false /
//                       true). This is a plain `if` where __builtin_expect
//                       (a GNU extension) is unavailable.
#ifdef __GNUC__
 #define INLINE   __attribute__((always_inline)) inline
 #define NOINLINE __attribute__((noinline))
 #define ALIGN(n) __attribute__((aligned(n)))
 #define __assume_aligned(x,y) (x=decltype(x)(__builtin_assume_aligned((void*)x,y)))
 #define restrict __restrict
 // GNU has no __assume(). This only evaluates x and gives the optimizer no hint.
 #define __assume(x) (x)
 #define if_e0(x) if(__builtin_expect((x),0))
 #define if_e1(x) if(__builtin_expect((x),1))
#else
 #define INLINE   __forceinline
 #define NOINLINE __declspec(noinline)
 #define ALIGN(n) __declspec(align(n))
 // No __builtin_expect here: fall back to an unhinted branch.
 #define if_e0(x) if(x)
 #define if_e1(x) if(x)
#endif

#if defined(__x86_64) || defined(_M_X64)
 #define X64
#else
 #undef X64
#endif

// Smaller / larger of two values. The result is converted to the type of the
// first argument.
template <typename T1, typename T2> T1 Min( T1 t1, T2 t2 ) {
  if( t1 < t2 ) return t1;
  return t2;
}
template <typename T1, typename T2> T1 Max( T1 t1, T2 t2 ) {
  if( t1 > t2 ) return t1;
  return t2;
}
// Element count of a built-in array (not a pointer); usable in constant expressions.
template <class T,int N> constexpr int DIM( T (&)[N] ) { return N; }

// Returns the size in bytes of the seekable stream f and leaves its file
// position where it was.
// Returns 0 if the size cannot be determined, e.g. ftell/fseek fail on a pipe.
// Previously ftell's -1 was silently turned into 0xFFFFFFFF.
// NOTE: sizes of 4 GiB or more do not fit the 32-bit result (or the 4-byte
// header main() writes) and are truncated.
uint flen( FILE* f ) {
  long pos = ftell( f );
  if( pos<0 ) return 0;
  if( fseek( f, 0, SEEK_END )!=0 ) return 0;
  long len = ftell( f );
  fseek( f, pos, SEEK_SET );
  return (len<0) ? 0 : (uint)len;
}

// Fixed-point scale shared by probabilities, counters and mixer weights:
// SCALE stands for 1.0.
enum {
  SCALElog = 15,
  SCALE    = 1 << SCALElog,  // 1.0
  hSCALE   = SCALE / 2,      // 0.5, the initial Counter value
  mSCALE   = SCALE - 1,      // upper clamp applied by Counter::Update
  eSCALE   = SCALE * 16      // numerator for update rates: eSCALE/k is a rate of 16/k
};

// Model parameters.
//   *wr : Counter update rates, eSCALE/(k+16), i.e. a fraction 16/(k+16) of SCALE.
//   *mw : Counter margins Mw; estimates stay roughly within [Mw, SCALE-Mw].
//   M_mW0 : LinearMixer weight given to the order-1 estimate.
// M_mwr, M_mW1, M_mmw and M_mlim are not referenced elsewhere in this file.
static const int M_mwr  = eSCALE / (31 + 16);
static const int M_mW1  = 31616;
static const int M_mmw  = 1;
static const int M_mlim = 0;
static const int M_mW0  = 13109;
static const int M_f0wr = eSCALE / (72 + 16);
static const int M_f0mw = 84 + 1;
static const int M_f1wr = eSCALE / (296 + 16);
static const int M_f1mw = 24 + 1;

// Fixed-weight interpolation between two estimates:
//   Mixup(p1, p2) = p1 + ((p2 - p1) * w >> SCALElog)
// so w/SCALE is the share given to p2.
struct LinearMixer {
  word w;   // weight of the second input, in 1/SCALE units
  LinearMixer( int W=SCALE/2 ) : w(W) {}
  int Mixup( int p1, int p2 ) {
    const int delta  = p2 - p1;
    const int scaled = (delta * w) >> SCALElog;
    return p1 + scaled;
  }
};

// Adaptive binary estimate. P is, in SCALE units, the probability that the
// next bit is 0 (as consumed by Model / rc_BProcess).
// Update(c, wr, Mw) moves P toward a target by the fraction wr/SCALE of the
// distance. The target is SCALE-Mw after a 0 bit and Mw after a 1 bit.
// The result is clamped to mSCALE.
struct Counter {
  word P;
  void Update( const int c, const int wr, const int Mw ) {
    // -c is an all-ones mask for c==1 and zero for c==0.
    const int target = SCALE - Mw - ((SCALE - 2*Mw) & -c);
    const int step   = ((P - target) * wr) >> SCALElog;
    const int q      = P - step;
    P = (mSCALE < q) ? mSCALE : q;
  }
};

//--- #include "coro3b.inc"

// Turn off runtime checks, stack probes and GS checks. The coroutine code
// below copies raw stack memory in and out, which such instrumentation
// presumably interferes with (NOTE(review): confirm). These are MSVC pragmas;
// clang is expected to ignore them as unknown pragmas.
#if defined(_MSC_VER) || defined(__clang__)
#pragma runtime_checks( "scu", off )
#pragma check_stack(off)
#pragma strict_gs_check(off)
#endif

// setjmp/longjmp used by the coroutine.
// With a GNU-compatible compiler, and unless CORO_NOASM is defined, minimal
// inline-asm versions are used. They record only the instruction and stack
// pointers. At the setjmp site every other general-purpose register except the
// accumulator (output) and %rbx/%ebx (input) is declared clobbered.
// Otherwise the standard <setjmp.h> is used.
#if ((defined __GNUC__) || (defined __INTEL_COMPILER) || (defined __clang__)) && (!defined CORO_NOASM)
  #ifdef X64
//---     #include "coro3_setjmp_x64.h"

// Resume point: rip = address of label 1 inside my_setjmp; rsp = the stack
// pointer when my_setjmp ran.
struct my_jmpbuf {
  qword rip,rsp;
};

#define ASM __asm__ volatile

// Stores %rsp and the address of label 1 into *regs, then returns 0 (%rax is
// preloaded with 0). my_jmp() later resumes just past label 1 with %rax = 1,
// so this call site then sees a return value of 1.
// NOTE(review): %rbx is only an input here, but on the resumed path it holds
// the jump target set by my_jmp rather than regs. Code after the setjmp must
// not rely on %rbx (confirm in the generated code).
// NOTE(review): `call 1f` pushes below %rsp, which may overwrite x86-64 SysV
// red-zone data. The asm also stores through regs without a "memory" clobber.
INLINE
static int my_setjmp( my_jmpbuf* regs ) {
  qword r;

  ASM ("\
   movq %%rsp,8(%1); \
   call 1f; \
1: popq 0(%1); \
  " : "=a"(r) : "b"(regs),"a"(0) : "%rcx","%rdx","%rsi","%rdi","%rbp","%r8","%r9","%r10","%r11","%r12","%r13","%r14","%r15"
  );

  return r;
}

// Loads the saved %rsp and jumps to the saved rip + 2 with %rax = 1.
// The +2 presumably skips the 2-byte `popq 0(%rbx)` at label 1, so the
// resumed path does not pop again (confirm against the assembled bytes).
// The int argument (longjmp's value) is ignored.
INLINE
static void my_jmp( my_jmpbuf* regs, int ) {
  ASM ("\
  xchg %0,%%rsp; \
  jmp *%1; \
  " :  : "d"(regs->rsp),"b"( ((byte*)regs->rip)+2 ),"a"(1) : 
  );
}

// Route the standard names used by the coroutine code to the asm versions.
typedef my_jmpbuf m_jmp_buf[1];
#define jmp_buf m_jmp_buf
#define longjmp my_jmp
#define setjmp  my_setjmp

//    #include "coro3_setjmp_x64b.h"
  #else
//---     #include "coro3_setjmp_x32.h"

// 32-bit x86 version of the same scheme (eip/esp instead of rip/rsp).
struct my_jmpbuf {
  uint eip,esp;
};

#define ASM __asm__ volatile

// Same contract as the x64 version. The NOTE(review) about %ebx applies equally.
INLINE
static int my_setjmp( my_jmpbuf* regs ) {
  int r;

  ASM ("\
   movl %%esp,4(%1); \
   call 1f; \
1: popl 0(%1); \
  " : "=a"(r) : "b"(regs),"a"(0) : "%ecx","%edx","%esi","%edi","%ebp"
  );

  return r;
}

// Loads the saved %esp and jumps to the saved eip + 2 with %eax = 1.
// The +2 presumably skips the `popl 0(%ebx)` at label 1 (confirm).
INLINE
static void my_jmp( my_jmpbuf* regs, int ) {
  ASM ("\
  xchg %0,%%esp; \
  jmp *%1; \
  " :  : "d"(regs->esp),"b"( ((byte*)regs->eip)+2 ),"a"(1) : 
  );
}

typedef my_jmpbuf m_jmp_buf[1];
#define jmp_buf m_jmp_buf
#define longjmp my_jmp
#define setjmp  my_setjmp

//    #include "coro3_setjmp_x32b.h"
  #endif
#else 
  // Other compilers (or CORO_NOASM): use the standard setjmp/longjmp.
  #ifndef CORO_NOASM
  #define CORO_NOASM 1
  #endif
  #include <setjmp.h>
#endif

// coro3_pin (below) refers to both of these before they are defined.
static void yield( void* p, int value );
struct Coroutine;

//--- #include "coro3_pin.inc"

//--- #include "coro3_pin_0.inc"

// One I/O buffer window. The field order and sizes must stay in sync with the
// named inp*/out* structs inside Coroutine, which alias this layout.
struct coro3_pin_0 {
  byte* ptr;        // current read/write position
  byte* beg;        // start of the current buffer
  byte* end;        // one past the last usable byte
  uint  f_EOF;      // end-of-stream flag, consulted by coro3_pin::f_quit()
  word  base_offs;  // byte offset of this pin inside its owning Coroutine
  word  r_code;     // value passed to yield() when this pin blocks

  // Bytes between ptr and end: unread input, or free output space.
  uint getinplen()  { return uint( end - ptr ); } //-V110
  uint getinpleft() { return uint( end - ptr ); } //-V110
  uint getoutlen()  { return uint( end - ptr ); } //-V110
  uint getoutleft() { return uint( end - ptr ); } //-V110

  // Bytes already consumed (input) or produced (output) in this buffer.
  uint getinpsize() { return uint( ptr - beg ); } //-V110
  uint getoutsize() { return uint( ptr - beg ); } //-V110

  // Installs buf[0..len) as the new window, positioned at its start.
  void addbuf( byte* buf,uint len ) {
    end = buf + len;
    ptr = buf;
    beg = buf;
  }

  void addinp( byte* inp,uint inplen ) { addbuf(inp,inplen); }
  void addout( byte* out,uint outlen ) { addbuf(out,outlen); }

};

struct coro3_pin: coro3_pin_0 {
  typedef Coroutine wrap;

  void pin_init( wrap* that, uint _r_code ) {
    ptr=beg=end=0; f_EOF=0;
    base_offs = ((char*)this) - ((char*)that);
    r_code = _r_code;
  }

  void yield_r( void ) {
    wrap& W = *(wrap*)(((char*)this) - base_offs );
    yield( (void*)&W, r_code );
  }

  inline uint f_quit( void );

#define coro3_pin_DEFINE_f_quit                     \
  inline uint coro3_pin::f_quit( void ) {           \
    wrap& W = *(wrap*)(((char*)this) - base_offs ); \
    return f_EOF | W.f_quit;                        \
  }                                                 

  void chkinp( void ) { if_e0( ptr>=end ) yield_r(); }

  void chkout( uint d=0 ) { if_e0( ptr>=end-d ) yield_r(); }

  byte get0( void ) { return *ptr++; }
  void put0( uint c ) { *ptr++ = c; }

  uint get( void ) { 
    m0:
    if_e0( ptr>=end ) {
      if_e0( f_quit() ) return uint(-1);
      yield_r();
      goto m0;
    }
    return *ptr++; 
  }

  void put( uint c ) { 
    *ptr++ = c; chkout(); 
  }

};

// Defined after Coroutine. coro_call0 runs or resumes the coroutine until its
// next yield; call_do_process0 starts do_process beneath a stack pad.
static void call_do_process0( Coroutine* that );
static uint coro_call0( Coroutine* that );


// Stack-copying coroutine base. A derived class provides do_process(), which
// runs on saved-and-restored stack memory and talks to the frontend through
// the pin[] buffers.
struct Coroutine {

  // Buffer state, reachable either as pin[0]/pin[1] or through the named
  // inp*/out* fields. Each named group mirrors coro3_pin_0 (three pointers
  // followed by 4+4 bytes), so inp* aliases pin[0] and out* aliases pin[1].
  union {
    struct {
    byte* inpptr;
    byte* inpbeg;
    byte* inpend;
    uint  inp_f_EOF;
    uint  inp_pad_x64_;

    byte* outptr;
    byte* outbeg;
    byte* outend;
    uint  out_f_EOF;
    uint  out_pad_x64_;
    };
    coro3_pin pin[4];
  };

  volatile uint  state;   // value of the last yield(); 0 = not started / finished
  volatile uint  f_quit;  // OR'ed with each pin's f_EOF by coro3_pin::f_quit()

  ALIGN(32) jmp_buf PointA;   // set in coro_call0, target of yield()
  ALIGN(32) jmp_buf PointB;   // set in yield(), target of coro_call0()

  ALIGN(8)
  volatile char* stkptrH;     // upper end of the saved stack range (set by call_do_process0)
  volatile char* stkptrL; // remembered _sp value for this instance

  typedef void (Coroutine::*t_do_process)( void );
  t_do_process p_do_process;  // body started by call_do_process0()

  enum{ STKPAD=4096*4+24 }; // coroutine stack size
  enum{ STKPAD0=1<<16 }; // stack padding from frontend to coroutine

  ALIGN(8) byte stk[STKPAD];  // saved copy of the coroutine's stack

  // Prepares for a fresh start; pin k reports k+1 when it yields.
  void coro_init( void ) {
    f_quit = 0;
    state = 0;
    for( int k=0; k<DIM(pin); ++k ) pin[k].pin_init( this, k+1 );
  }

  // Runs or resumes that->do_process() until its next yield and returns the
  // yielded state. T is expected to derive from Coroutine.
  template <typename T> 
  INLINE
  uint coro_call( T* that ) {
    t_do_process entry = (t_do_process)&T::do_process;
    p_do_process = entry;
    return coro_call0( that );
  }

//---------------------

  // Input side: pin[0].
  void addinp( byte* inp,uint inplen ) { pin[0].addinp(inp,inplen); }
  void chkinp( void ) { pin[0].chkinp(); }
  uint get( void ) { return pin[0].get(); }
  byte get0( void ) { return pin[0].get0(); }
  uint getinplen() { return pin[0].getinplen(); } //-V110
  uint getinpleft() { return pin[0].getinpleft(); } //-V524 //-V110
  uint getinpsize() { return pin[0].getinpsize(); } //-V110

  // Output side: pin[1].
  void addout( byte* out,uint outlen ) { pin[1].addout(out,outlen); }
  void chkout( uint d=0 ) { pin[1].chkout(d); }
  void put( uint c ) { pin[1].put(c); }
  void put0( uint c ) { pin[1].put0(c); }
  uint getoutlen() { return pin[1].getoutlen(); } //-V110
  uint getoutleft() { return pin[1].getoutleft(); } //-V524 //-V110
  uint getoutsize() { return pin[1].getoutsize(); } //-V110

};


// Suspends coroutine p (a Coroutine*) and hands `value` back to the frontend
// waiting in coro_call0(), which returns it as the coroutine state.
// The live coroutine stack, from 16 bytes below curtmp up to q.stkptrH (set by
// call_do_process0), is copied into q.stk. To resume, coro_call0() copies it
// back to the same addresses and jumps to PointB. setjmp then returns nonzero
// and yield() simply returns into the coroutine.
NOINLINE
static void yield( void* p, int value ) {
  Coroutine& q = *(Coroutine*)p;
  // The -16 presumably covers a little of this frame below curtmp (TODO confirm).
  char curtmp; q.stkptrL=(&curtmp)-16;
  if( setjmp(q.PointB)==0 ) {
    q.state=value;
    // q.stk is followed by the derived classes' fields (range coder state,
    // model tables). A stack image larger than q.stk, or a negative span,
    // would silently corrupt them, so fail loudly instead.
    if( size_t(q.stkptrH-q.stkptrL) > sizeof(q.stk) ) abort();
    memcpy( q.stk, (char*)q.stkptrL, q.stkptrH-q.stkptrL );
    longjmp(q.PointA,1);
    __assume(0);
  }
}

// Runs coroutine `that` until it next yields, then returns the state it
// yielded (0 once do_process has finished via yield(this,0)).
// If state is 0, the body is (re)started through call_do_process0().
// Otherwise the stack image saved by yield() is copied back and execution
// jumps into yield() at PointB.
// NOTE(review): the restore writes to the same absolute addresses it was
// saved from, so this presumably assumes coro_call0 is always entered at the
// same stack depth. The STKPAD0 pad keeps that range below this frame (TODO confirm).
NOINLINE
static uint coro_call0( Coroutine* that ) {
  // yield() longjmps back to PointA; setjmp then returns nonzero and we fall
  // through to return the state yield() stored.
  if_e1( setjmp(that->PointA)==0 ) {
    if_e1( that->state ) { // calls usually take this path, since other runs only on init
      memcpy( (char*)that->stkptrL, that->stk, that->stkptrH-that->stkptrL );
      longjmp(that->PointB,1); 
      // With the GNU __assume(x) definition this is a no-op expression, not an unreachable hint.
      __assume(0);
    }
    call_do_process0(that);
    __assume(0);
  }
  return that->state;
}


// Starts the coroutine body. stktmp reserves STKPAD0 bytes of stack between
// coro_call0's frame and the frames of do_process. Only its address is used:
// it becomes stkptrH, the upper bound of the range yield() saves and
// coro_call0() restores (assuming a downward-growing stack, as on x86).
NOINLINE
static void call_do_process0( Coroutine* that ) {
  // call_do_process0 needs to be an actual separate function to allocate stack pad in its frame
  byte stktmp[Coroutine::STKPAD0]; 
  that->stkptrH = ((char*)stktmp);

  // do_process also needs a separate stack frame, to avoid merging stktmp into it, but ptr call is ok
  (that->*(that->p_do_process))();

  // do_process ends with yield(0) (can't normally return to changed frontend stack)
  // so tell compiler that this point can't be reached
  // NOTE(review): with the GNU __assume(x) definition this is just the expression (0).
  __assume(0);
}


// Expand the deferred definition of coro3_pin::f_quit(). It reads
// Coroutine::f_quit, so it can only be compiled once struct Coroutine is complete.
coro3_pin_DEFINE_f_quit
#undef coro3_pin_DEFINE_f_quit

//#include "coro3_init.inc"


// Binary range coder running inside a Coroutine.
//   mode 0 = encoder: bytes leave through put() (see ShiftLow).
//   mode 1 = decoder: bytes arrive through get().
template< int mode >
struct Rangecoder : Coroutine {

  enum {
    NUM   = 4,            // rc_Init / rc_Quit transfer NUM+1 bytes
    sTOP  = 0x01000000U,  // renormalize while range < 2^24
    gTOP  = 0x00010000U,  // not referenced in this file
    Thres = 0xFF000000U,  // low >= Thres: its top byte may still change by a carry
    Threg = 0x00FF0000U   // not referenced in this file
  };

  uint  range;  // current interval width
  uint  rprec;  // range >> SCALElog: interval width per probability unit
  qword lowc;   // encoder: interval start; bit 32 holds a pending carry
  uint  code;   // decoder: coded value minus the interval start
  uint  FFNum;  // encoder: number of deferred 0xFF bytes
  uint  Cache;  // encoder: deferred byte preceding the 0xFF run

  // Moves whole bytes out (encoder) or in (decoder) until range >= 2^24.
  void rc_Renorm( void ) {
    if( mode ) {
      while( range<sTOP ) range<<=8, (code<<=8)+=get();
    } else {
      while( range<sTOP ) { range<<=8; ShiftLow(); }
    }
  }

  // Codes one bit whose probability of being 0 is freq/SCALE.
  // Encoder: b is the bit to encode. Decoder: b receives the decoded bit.
  // 0 < freq < SCALE keeps both sub-intervals non-empty. For example, freq==0
  // with b==0 would set range to 0, and rc_Renorm() would never terminate.
  void rc_BProcess( uint freq, int& b ) {

    uint rnew = rprec*freq;          // width of the "bit is 0" sub-interval

    if( mode ) b = (code>=rnew);

    // b==0 keeps [0, rnew); b==1 keeps [rnew, range).
    range = ((range-rnew-rnew)&(-b)) + rnew;
    rnew &= -b;                      // start offset of the kept sub-interval

    if( mode ) code -= rnew; else lowc += rnew;

    rc_Renorm();
    rprec = range>>SCALElog;
  }

  // Encoder: moves the top byte of the 32-bit low out to the output. A byte
  // that could still be changed by a later carry is held back: the last
  // non-0xFF byte in Cache, the 0xFF bytes after it counted in FFNum.
  void ShiftLow( void ) {
    uint Carry = uint(lowc>>32);
    uint low = uint(lowc);
    if( low<Thres || Carry ) {
      put( Cache+Carry );
      for (;FFNum != 0;FFNum--) put( Carry-1 );   // 0xFF, or 0x00 after a carry
      Cache = low>>24;
    } else FFNum++;
    lowc=(low<<8);
  }

  // Resets the coder state. code is cleared too: the decoder's rc_Init shifts
  // bytes into it, which would otherwise read an indeterminate value whenever
  // the object is not zero-initialized.
  void rc_Init0( void ) {
    range = 0xFFFFFFFF;
    rprec = range>>SCALElog;
    lowc  = 0;
    code  = 0;
    FFNum = 0;
    Cache = 0;
  }

  // Resets the coder. The decoder also preloads NUM+1 bytes into the 32-bit
  // code; the first one (the encoder's initial Cache byte) is shifted out again.
  void rc_Init( void ) {
    rc_Init0();
    if( mode==1 ) {
      for(int _=0; _<NUM+1; _++) (code<<=8)+=get(); 
    }
  }

  // Encoder: flushes low (NUM+1 byte shifts), so that pending bytes are written.
  void rc_Quit( void ) {
    if( mode==0 ) {
      for(int _=0; _<NUM+1; _++) ShiftLow(); 
    }
  }

};


// Order-1 context-mixing model on top of the range coder.
// Each byte is coded MSB-first as 8 binary decisions. ctx starts at 1 and
// accumulates the bits coded so far in the current byte (values 1..255 while
// coding). Two estimates are mixed with a fixed-weight LinearMixer:
//   f0[ctx]        order-0, depends only on the partial byte
//   f1[prev][ctx]  order-1, also conditioned on the previous byte
template< int mode >
struct Model : public Rangecoder<mode> {

  uint f_len;   // number of bytes to code

  enum{ CNUM=256 };
  ALIGN(32)
  Counter f0[CNUM];       
  ALIGN(32)
  Counter f1[CNUM][CNUM]; 

  enum {
    inpbufsize = 1<<20,
    outbufsize = 1<<16
  };
  ALIGN(4096)
  byte inpbuf[inpbufsize];
  byte outbuf[outbufsize];

  // Coroutine body: codes f_len bytes, flushes the encoder, then yields 0.
  void do_process( void ) {

    // All estimates start at probability 1/2.
    for( uint i=0; i<CNUM; i++ ) {
      f0[i].P=hSCALE;
      for( uint j=0; j<CNUM; j++ ) f1[i][j].P=hSCALE; 
    }

    this->rc_Init();

    LinearMixer mix( M_mW0 );

    int p = 0;   // previous byte: selects the f1 row
    // uint counter: f_len is a uint, and an int index would overflow (UB)
    // for inputs of 2 GiB or more.
    for( uint i=0; i<f_len; i++ ) {
      int c = 0;     // encoder: byte being coded
      int bit = 0;   // encoder: input bit; decoder: set by rc_BProcess

      if( mode==0 ) c = this->get();

      uint ctx=1;
      for( int b=7; b>=0; b-- ) {

        if( mode==0 ) bit = (c>>b)&1;

        int p1 = mix.Mixup(f0[ctx].P,f1[p][ctx].P);

        this->rc_BProcess( p1, bit );

        f0[ctx].Update( bit, M_f0wr, M_f0mw );   
        f1[p][ctx].Update( bit, M_f1wr, M_f1mw ); 

        ctx += ctx + bit;
      }
      c = byte(ctx);   // ctx is now 256..511; its low 8 bits are the byte

      if( mode==1 ) this->put(c);
      p = c;
    }

    this->rc_Quit();

    yield(this,0);

  }

  // Frontend loop. Runs do_process and serves its requests until it finishes:
  //   state 1: refill inpbuf from f (stops early if f has no more data)
  //   state 2: outbuf is full; write it to g and reuse it
  // Finally writes whatever remains in outbuf.
  // NOTE: fwrite results are not checked here; callers can inspect ferror(g).
  void processfile( FILE* f, FILE* g, uint _f_len ) {
    f_len = _f_len;
    this->coro_init();
    this->addout( outbuf, outbufsize );
    while( 1 ) {
      int r = this->coro_call(this); 
      if( r==0 ) break;
      if( r==1 ) {
        uint l = fread( inpbuf, 1, inpbufsize, f );
        if( l==0 ) break; 
        this->addinp( inpbuf, l ); 
      } else if( r==2 ) {
        fwrite( outbuf, 1, outbufsize, g );
        this->addout( outbuf, outbufsize );
      }
    }
    fwrite( outbuf, 1,this->outptr-outbuf, g ); 
  }

};

// Encoder and decoder models. Each holds the 256x256 f1 table and the I/O
// buffers, and main() uses only one of them per run, so they share one
// statically allocated, 4096-byte-aligned storage block.
ALIGN(4096) 
static union {
  Model<0> M0;
  Model<1> M1;
};

// Usage: <prog> c <input> <output>   compress
//        <prog> d <input> <output>   decompress (any mode other than 'c')
// Compressed file: the 4-byte uncompressed length (raw uint, host byte order)
// followed by the range-coded stream.
// Exit codes: 0 ok, 1 bad usage or input cannot be opened, 2 output cannot be
// opened, 3 input too short for the length header, or an output write/close error.
int main( int argc, char** argv ) {

  if( argc<4 ) return 1;
  FILE* f = fopen( argv[2], "rb" ); if( f==0 ) return 1;
  FILE* g = fopen( argv[3], "wb" ); if( g==0 ) { fclose( f ); return 2; }
  uint f_len = 0;
  int rc = 0;

  if( argv[1][0]=='c' ) {
    f_len = flen( f );
    fwrite( &f_len, 1,4, g );
    M0.processfile( f, g, f_len );
  } else if( fread( &f_len, 1,4, f )==4 ) {
    M1.processfile( f, g, f_len );
  } else {
    rc = 3;   // truncated header: previously decoded as an empty file with status 0
  }

  fclose( f );
  if( ferror( g ) ) rc = 3;       // a write failed somewhere along the way
  if( fclose( g )!=0 ) rc = 3;    // final flush failed

  return rc;
}

