
#include <setjmp.h>             // for setjmp(), longjmp(), and jmp_buf
#include <stdio.h>
#include <stdlib.h>
#include <sys/types.h>
#include <sys/stat.h>

// Short unsigned aliases used throughout the decoder.
typedef unsigned       uint;   // same type as `unsigned int`
typedef unsigned short word;
typedef unsigned char  byte;   // raw input/output bytes

// Maximum number of bits in a Huffman code; decode() gives up after this many.
#define MAXBITS 15

// Decoder state shared by all of the block decoders.
struct state {
  // Output: destination buffer (the decoders skip writes and only count
  // bytes when this is null), its capacity, and bytes produced so far.
  byte *out;
  uint  outlen, outcnt;

  // Input: compressed data, its length, and bytes consumed so far.
  byte *in;
  uint  inlen, incnt;

  // Bits already pulled from `in` but not yet consumed, and how many.
  int bitbuf, bitcnt;

  // Jump target that bits() longjmps to when the input runs out.
  jmp_buf env;
};

// Return the next `need` bits of input, least-significant bit first.
// Whole bytes are appended above the bits already buffered until enough are
// available; if the input is exhausted first, control returns to the
// setjmp() point stored in s->env (longjmp value 1).
int bits( state *s, int need ) {
  int acc = s->bitbuf;

  for (; s->bitcnt < need; s->bitcnt += 8) {
    if (s->incnt == s->inlen)
      longjmp(s->env, 1);
    acc |= (int)(s->in[s->incnt++]) << s->bitcnt;
  }

  // Keep the surplus bits for the next call and hand back the low `need`.
  s->bitbuf = (int)(acc >> need);
  s->bitcnt -= need;
  return (int)(acc & ((1L << need) - 1));
}

// Copy one stored (uncompressed) block from input to output.
// Returns 0 on success, 2 if the input is too short, -2 if the length does
// not match its one's complement, and 1 if the output buffer is too small.
// When s->out is null only the counters are advanced.
int stored( state *s) {
  uint len;

  // The block body starts on a byte boundary: discard buffered bits.
  s->bitcnt = 0;
  s->bitbuf = 0;

  // LEN (two bytes, little-endian) followed by its one's complement.
  if (s->incnt + 4 > s->inlen)
    return 2;
  len = s->in[s->incnt];
  len |= s->in[s->incnt + 1] << 8;
  s->incnt += 2;
  if (s->in[s->incnt++] != (~len & 0xff))
    return -2;
  if (s->in[s->incnt++] != ((~len >> 8) & 0xff))
    return -2;

  if (s->incnt + len > s->inlen)
    return 2;

  if (s->out == NULL) {
    // Size-only pass: skip over the data.
    s->outcnt += len;
    s->incnt += len;
    return 0;
  }

  if (s->outcnt + len > s->outlen)
    return 1;
  for (; len > 0; len--)
    s->out[s->outcnt++] = s->in[s->incnt++];
  return 0;
}

// Canonical Huffman decoding table, as filled in by construct().
struct huffman {
  short *count;   // count[len]: how many symbols have a code of len bits
  short *symbol;  // symbols sorted by code length, then by symbol value
};

// Decode one symbol using canonical Huffman table `h`, reading one bit at a
// time. For each code length, codes of that length form a contiguous range
// beginning at `first`; `base` is the index in h->symbol of the first
// symbol with that length. Returns the symbol, or -9 if no code of up to
// MAXBITS bits matches.
int decode( state *s, huffman *h ) {
  int code = 0;
  int first = 0;
  int base = 0;

  for (int len = 1; len <= MAXBITS; len++) {
    code |= bits(s, 1);
    int n = h->count[len];
    if (code - first < n)
      return h->symbol[base + (code - first)];
    base += n;
    first = (first + n) << 1;
    code <<= 1;
  }
  return -9;
}

// Build canonical Huffman table `h` from the code lengths length[0..n-1].
// Returns 0 for a complete code (and also when every length is zero),
// a negative value for an over-subscribed set, or a positive count of
// unused codes for an incomplete set.
int construct( huffman *h, short *length, int n ) {
  short offs[15+1];   // offs[len]: next free slot in h->symbol for length len
  int left;

  // Histogram of code lengths.
  for (int len = 0; len <= 15; len++)
    h->count[len] = 0;
  for (int sym = 0; sym < n; sym++)
    (h->count[length[sym]])++;

  if (h->count[0] == n)
    return 0;

  // `left` tracks how many codes remain unassigned at each length;
  // going negative means more codes were requested than exist.
  left = 1;
  for (int len = 1; len <= 15; len++) {
    left = (left << 1) - h->count[len];
    if (left < 0)
      return left;
  }

  // Start of each length's run within h->symbol.
  offs[1] = 0;
  for (int len = 2; len <= 15; len++)
    offs[len] = offs[len - 1] + h->count[len - 1];

  // Fill symbols in length order, ascending by symbol within a length.
  for (int sym = 0; sym < n; sym++) {
    if (length[sym] != 0)
      h->symbol[offs[length[sym]]++] = sym;
  }

  return left;
}

// Decode literal/length and distance codes until the end-of-block symbol
// (256), writing literals and back-reference copies to s->out. When s->out
// is null, only s->outcnt is advanced (size-only pass).
//
// Returns 0 at end of block, 1 if the output buffer is too small, -9 when a
// code cannot be decoded or a length/distance symbol is out of range, and
// -10 when a distance reaches back before the start of the output.
// Running out of input is reported by bits() via longjmp.
int codes( state *s, huffman *lencode, huffman *distcode ) {
  // Base lengths / extra-bit counts for length symbols 257..285, and base
  // distances / extra-bit counts for distance symbols 0..29. They are
  // read-only, so they are static: initialized once, not on every call.
  static const short lens[29] = {
    3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31,
    35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258};
  static const short lext[29] = {
    0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,
    3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0};
  static const short dists[30] = {
    1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193,
    257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145,
    8193, 12289, 16385, 24577};
  static const short dext[30] = {
    0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
    7, 7, 8, 8, 9, 9, 10, 10, 11, 11,
    12, 12, 13, 13};
  int symbol;

  do {
    symbol = decode(s, lencode);
    if (symbol < 0) return symbol;
    if (symbol < 256) {
      // Literal byte.
      if (s->out != ((byte *)0)) {
        if (s->outcnt == s->outlen) return 1;
        s->out[s->outcnt] = symbol;
      }
      s->outcnt++;
    }
    else if (symbol > 256) {
      // Length/distance pair. Dedicated locals keep `symbol` (which drives
      // the loop condition) equal to the literal/length symbol decoded.
      int lsym = symbol - 257;
      int dsym;
      int len;
      uint dist;

      if (lsym >= 29) return -9;        // symbols 286/287 have no length entry
      len = lens[lsym] + bits(s, lext[lsym]);

      dsym = decode(s, distcode);
      if (dsym < 0) return dsym;
      if (dsym >= 30) return -9;        // defensive: keep table index in range
      dist = dists[dsym] + bits(s, dext[dsym]);
      if (dist > s->outcnt)
        return -10;                     // before the start of the output

      if (s->out != ((byte *)0)) {
        if (s->outcnt + len > s->outlen) return 1;
        // Byte-by-byte on purpose: when dist < len the source overlaps the
        // bytes being written, which must be re-read as they are produced.
        while (len--) {
          s->out[s->outcnt] = s->out[s->outcnt - dist];
          s->outcnt++;
        }
      }
      else
        s->outcnt += len;
    }
  } while (symbol != 256);

  return 0;
}

// Decode one block that uses the fixed Huffman codes.
// The tables are locals, so they are rebuilt on every call; this is cheap
// and keeps the function free of shared mutable state.
// Returns whatever codes() returns.
int fixed( state *s ) {
  short lencnt[15+1], lensym[288];
  short distcnt[15+1], distsym[30];
  huffman lencode = {lencnt, lensym};
  huffman distcode = {distcnt, distsym};
  short lengths[288];
  int symbol;

  // Fixed literal/length code lengths:
  // 0-143: 8 bits, 144-255: 9 bits, 256-279: 7 bits, 280-287: 8 bits.
  for (symbol = 0; symbol < 144; symbol++)
    lengths[symbol] = 8;
  for (; symbol < 256; symbol++)
    lengths[symbol] = 9;
  for (; symbol < 280; symbol++)
    lengths[symbol] = 7;
  for (; symbol < 288; symbol++)
    lengths[symbol] = 8;
  // These lengths form a complete code, so construct()'s result is not checked.
  construct(&lencode, lengths, 288);

  // Fixed distance codes: all 30 symbols are 5 bits. This leaves two 5-bit
  // codes unused, so construct() returns a positive value here; that is
  // expected and not an error.
  for (symbol = 0; symbol < 30; symbol++)
    lengths[symbol] = 5;
  construct(&distcode, lengths, 30);

  return codes(s, &lencode, &distcode);
}

// Decode one block that carries its own (dynamic) Huffman code descriptions.
// Returns codes()'s result on success, a negative value from decode() for an
// undecodable code-length symbol, or:
//   -3 too many length or distance codes
//   -4 code-length code is not complete
//   -5 repeat-previous (16) with no previous length
//   -6 repeat runs past the number of lengths
//   -7 invalid literal/length code lengths
//   -8 invalid distance code lengths
int dynamic( state *s ) {
  int nlen, ndist, ncode;
  int index;
  int err;
  short lengths[(286+30)];
  short lencnt[15+1], lensym[286];
  short distcnt[15+1], distsym[30];
  huffman lencode = {lencnt, lensym};
  huffman distcode = {distcnt, distsym};
  const short order[19] =
    {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};

  // Header: numbers of literal/length, distance and code-length codes.
  nlen = bits(s, 5) + 257;
  ndist = bits(s, 5) + 1;
  ncode = bits(s, 4) + 4;
  if (nlen > 286 || ndist > 30)
    return -3;

  // Code-length code lengths, 3 bits each in `order`; the rest are zero.
  for (index = 0; index < ncode; index++)
    lengths[order[index]] = bits(s, 3);
  for (; index < 19; index++)
    lengths[order[index]] = 0;

  // Require a complete code-length code.
  err = construct(&lencode, lengths, 19);
  if (err != 0) return -4;

  // Read the literal/length and distance code lengths using that code.
  index = 0;
  while (index < nlen + ndist) {
    int symbol;
    int len;
    symbol = decode(s, &lencode);
    // decode() returns a negative value when no code matches. This happens,
    // for example, when all 19 lengths above were zero, because construct()
    // also reports that case as 0. Storing the negative value as a length
    // would make construct() index count[] out of bounds, so stop here.
    if (symbol < 0)
      return symbol;
    if (symbol < 16)
      lengths[index++] = symbol;        // literal length 0..15
    else {
      len = 0;                          // 17 and 18 repeat zero
      if (symbol == 16) {
        if (index == 0) return -5;
        len = lengths[index - 1];       // repeat the previous length
        symbol = 3 + bits(s, 2);
      }
      else if (symbol == 17)
        symbol = 3 + bits(s, 3);
      else
        symbol = 11 + bits(s, 7);
      if (index + symbol > nlen + ndist)
        return -6;
      while (symbol--)
        lengths[index++] = len;
    }
  }

  // Incomplete codes are accepted only when exactly one code is used.
  err = construct(&lencode, lengths, nlen);
  if (err < 0 || (err > 0 && nlen - lencode.count[0] != 1))
    return -7;

  err = construct(&distcode, lengths + nlen, ndist);
  if (err < 0 || (err > 0 && ndist - distcode.count[0] != 1))
    return -8;

  return codes(s, &lencode, &distcode);
}

// Decompress a raw deflate stream.
//
// dest, *destlen:     output buffer and its capacity. dest may be null, in
//                     which case nothing is written and only the output size
//                     is counted (*destlen is still read, into s.outlen).
// source, *sourcelen: compressed input and its length.
//
// Returns 0 on success, 1 if the output buffer is too small, 2 if the input
// ran out before the last block ended, or a negative code for invalid data
// (-1 for an invalid block type; other values come from stored/fixed/dynamic).
// When the result is <= 0, *destlen and *sourcelen are overwritten with the
// number of bytes produced and consumed.
int puff( byte *dest, uint *destlen, byte *source, uint *sourcelen ) {
  state s;
  int last, type;
  int err;

  s.out = dest;
  s.outlen = *destlen;
  s.outcnt = 0;

  s.in = source;
  s.inlen = *sourcelen;
  s.incnt = 0;
  s.bitbuf = 0;
  s.bitcnt = 0;

  // bits() longjmps back here when the input is exhausted; that path sets
  // err = 2. After a longjmp, non-volatile locals changed since setjmp()
  // (such as the fields of `s`) are indeterminate. `s` is only read below
  // when err <= 0, i.e. never on that path.
  if (setjmp(s.env) != 0)
    err = 2;
  else {

    // Each block starts with a 1-bit "last block" flag and a 2-bit type:
    // 0 stored, 1 fixed, 2 dynamic; 3 is rejected with -1.
    do {
      last = bits(&s, 1);
      type = bits(&s, 2);
      err = type == 0 ? stored(&s) :
          (type == 1 ? fixed(&s) :
           (type == 2 ? dynamic(&s) :
          -1));
      if (err != 0) break;
    } while (!last);
  }

  // Report progress on success and on data errors, but not for 1 or 2.
  if (err <= 0) {
    *destlen = s.outcnt;
    *sourcelen = s.incnt;
  }
  return err;
}

// Load the whole regular file `name` into a newly malloc'd buffer.
// On success returns the buffer (caller frees) and stores its size in *len.
// Returns NULL, with *len = 0, if the path is not a regular file, the file
// is empty or too large for `uint`, or opening, allocating or reading fails.
byte *yank( char *name, uint *len ) {
  uint size;
  byte *buf;
  FILE *in;
  struct stat s;

  *len = 0;
  if (stat(name, &s)) return NULL;
  if ((s.st_mode & S_IFMT) != S_IFREG) return NULL;
  size = (uint)(s.st_size);
  // Reject empty files and sizes that do not fit in a uint.
  if (size == 0 || (off_t)size != s.st_size) return NULL;

  // Binary mode: the contents are compressed data, not text; text mode may
  // translate line endings on some platforms and corrupt it.
  in = fopen(name, "rb");
  if (in == NULL) return NULL;
  buf = (byte*)malloc(size);
  if (buf != NULL && fread(buf, 1, size, in) != size) {
    free(buf);
    buf = NULL;
  }
  fclose(in);
  // Only report a length when a buffer is actually returned.
  if (buf != NULL)
    *len = size;
  return buf;
}

// Usage: prog FILE
// Runs puff() over FILE in size-only mode (null output buffer) and reports
// the decompressed size and any trailing unused input. Returns puff()'s
// result, or 2 if no file is given or the file cannot be loaded.
int main( int argc, char **argv ) {
  int ret;
  byte *source;
  uint len, sourcelen;
  uint destlen = 0;   // puff() reads *destlen (as the capacity) even when dest is null
  if (argc < 2) return 2;
  source = yank(argv[1], &len);
  if (source == NULL) return 2;
  sourcelen = len;
  ret = puff(((byte *)0), &destlen, source, &sourcelen );
  if (ret)
    printf("puff() failed with return code %d\n", ret);
  else {
    // %u: destlen, len and sourcelen are unsigned int, not unsigned long.
    printf("puff() succeeded uncompressing %u bytes\n", destlen);
    if (sourcelen < len) printf("%u compressed bytes unused\n", len - sourcelen );
  }
  free(source);
  return ret;
}
