/**
 * An immutable collection of numeric values indexed by key.
 *
 * Entries live in a `Map`, so keys are unique (when the constructor input
 * repeats a key, the last value wins) and iteration follows insertion order.
 * Every method that produces a different set of entries returns a new
 * `Mixture`; the receiver is never modified.
 */
export class Mixture<K> implements Iterable<[K, number]> {
  /** Backing store mapping each key to its value. */
  protected readonly map: ReadonlyMap<K, number>;
  /** Number of distinct keys held by this mixture. */
  readonly size: number;

  /**
   * @param m `[key, value]` pairs. The array is copied into an internal map,
   *   so later changes to `m` do not affect the mixture.
   */
  constructor(m: (readonly [K, number])[]) {
    this.map = new Map(m);
    this.size = this.map.size;
  }

  /** Iterates over the `[key, value]` entries in insertion order. */
  *[Symbol.iterator](): IterableIterator<[K, number]> {
    yield* this.map;
  }

  /**
   * Builds a `Composition` from this mixture's entries.
   *
   * @throws Error (raised by the `Composition` constructor) when this mixture
   *   has no entries.
   */
  composition(): Composition<K> {
    return new Composition([...this.map]);
  }

  /** Returns a fresh, mutable `Map` holding a copy of the entries. */
  toMap(): Map<K, number> {
    return new Map(this.map);
  }

  /**
   * Returns a mixture restricted to the selected keys.
   *
   * @param keys Either an iterable of keys or a predicate that is applied to
   *   every key of this mixture. Keys that are not present in this mixture
   *   are skipped silently.
   * @returns A new mixture whose entry order follows the order of the
   *   selected keys.
   */
  submixture(keys: Iterable<K> | ((k: K) => boolean)): Mixture<K> {
    const selected: Iterable<K> =
      typeof keys === 'function' ? [...this.keys()].filter(keys) : keys;

    const picked = new Map<K, number>();
    for (const k of selected) {
      const v = this.map.get(k);
      if (v !== undefined) {
        picked.set(k, v);
      }
    }

    return new Mixture([...picked]);
  }

  /** Sum of all values; `0` for an empty mixture. */
  total(): number {
    let sum = 0;
    for (const v of this.map.values()) {
      sum += v;
    }
    return sum;
  }

  /** Iterator over the keys in insertion order. */
  keys(): IterableIterator<K> {
    return this.map.keys();
  }

  /** Iterator over the values in insertion order. */
  values() {
    return this.map.values();
  }

  /** Iterator over the `[key, value]` entries in insertion order. */
  entries() {
    return this.map.entries();
  }

  /**
   * Returns the keys with the largest values, largest first.
   *
   * @param n Maximum number of keys to return. Values larger than `size`
   *   return every key; zero or negative values return an empty array.
   */
  topKeys(n: number): K[] {
    const ranked = [...this.map.entries()];
    ranked.sort(([, p1], [, p2]) => p2 - p1);
    // Clamp at zero: a negative `end` would make `slice` count from the back
    // of the array and return almost every key instead of none.
    return ranked.slice(0, Math.max(0, n)).map(([k]) => k);
  }

  /**
   * Returns a mixture with every key replaced by `f(key)`, keeping the values.
   *
   * @param f Key translation function; it is called once per key and must
   *   produce a distinct result for each key.
   * @throws Error when two keys are translated to the same new key.
   */
  mapKeys<KNew>(f: (k: K) => NonNullable<KNew>): Mixture<NonNullable<KNew>> {
    const renamed = new Map<NonNullable<KNew>, number>();
    for (const [k, v] of this.map) {
      renamed.set(f(k), v);
    }

    // A collision overwrites an entry, so fewer entries than we started with
    // means `f` mapped two different keys to the same new key.
    if (renamed.size !== this.map.size) {
      throw new Error('key map function did not form bijection');
    }

    return new Mixture([...renamed]);
  }

  /**
   * Adds another mixture to this one key by key.
   *
   * @param m Mixture to add. A key missing from either side counts as `0`.
   * @returns A new mixture over the union of both key sets; this mixture's
   *   keys come first, followed by keys that only `m` has.
   */
  add(m: Mixture<K>): Mixture<K> {
    const keyUnion = new Set<K>(this.keys());
    for (const k of m.keys()) {
      keyUnion.add(k);
    }

    const summed = new Map<K, number>();
    for (const k of keyUnion) {
      summed.set(k, (this.map.get(k) ?? 0) + (m.map.get(k) ?? 0));
    }

    return new Mixture([...summed]);
  }

  /**
   * Adds any number of mixtures together, left to right.
   *
   * @param mixtures Mixtures to sum.
   * @returns The key-wise sum. An empty list yields an empty mixture, and a
   *   single-element list yields that element itself.
   */
  static add<TKey>(mixtures: Mixture<TKey>[]): Mixture<TKey> {
    // `reduce` without an initial value throws on an empty array, so the
    // empty sum is handled explicitly.
    if (mixtures.length === 0) {
      return new Mixture<TKey>([]);
    }
    return mixtures.reduce((a, b) => a.add(b));
  }

  /** Returns the value stored for `k`, or `undefined` when `k` is absent. */
  get(k: K): number | undefined {
    return this.map.get(k);
  }
}

/**
 * Object-literal form of mixture data. It exists only for string key types:
 * for such `K` it is `Record<K, number>`, for any other `K` it resolves to
 * `never`. Because `K` is a bare type parameter, the conditional distributes
 * over a union of string literal types.
 */
type MixtureRecord<K> = K extends string ? Record<K, number> : never;
/** Entry-list form of mixture data: an array of read-only `[key, number]` pairs. */
type MixtureArray<K> = (readonly [K, number])[];
/** Either input form accepted by the `mixture` and `composition` factory functions. */
type MixtureData<K> = MixtureRecord<K> | MixtureArray<K>;

/**
 * Factory that builds a `Mixture` from either accepted input form.
 *
 * @param values An array of `[key, number]` pairs, or (for string keys) a
 *   plain object whose own enumerable properties become the entries.
 * @returns A `Mixture` holding the supplied entries.
 */
export function mixture<K>(values: MixtureData<K>): Mixture<K> {
  if (!Array.isArray(values)) {
    // The record form is only reachable when K is a string type, so the
    // string keys produced by Object.entries can be treated as K.
    const pairs = Object.entries(values) as [K, number][];
    return new Mixture<K>(pairs);
  }

  return new Mixture(values);
}

/**
 * An immutable collection of proportions indexed by key.
 *
 * The constructor divides every supplied weight by the sum of all weights.
 * Keys are unique (when the input repeats a key, the last weight wins) and
 * iteration follows insertion order.
 *
 * NOTE: weights are not validated. If they sum to zero the division yields
 * non-finite proportions (`NaN` or `±Infinity`).
 */
export class Composition<K> implements Iterable<[K, number]> {
  /** Backing store mapping each key to its normalised proportion. */
  protected readonly proportions: ReadonlyMap<K, number>;

  /**
   * @param c `[key, weight]` pairs to normalise.
   * @throws Error when `c` is empty.
   */
  constructor(c: [K, number][]) {
    if (c.length === 0) {
      throw new Error('compositions must have at least one component');
    }

    // Collapse duplicate keys BEFORE summing. Otherwise the total would count
    // weights that the map later discards, and the stored proportions would
    // no longer add up to one.
    const weights = new Map(c);

    let total = 0;
    for (const w of weights.values()) {
      total += w;
    }

    const normalised = new Map<K, number>();
    for (const [k, w] of weights) {
      normalised.set(k, w / total);
    }
    this.proportions = normalised;
  }

  /** Iterates over the `[key, proportion]` entries in insertion order. */
  *[Symbol.iterator](): IterableIterator<[K, number]> {
    yield* this.proportions;
  }

  /**
   * Returns a composition restricted to the selected keys and renormalised
   * over them.
   *
   * @param keys Either an iterable of keys or a predicate that is applied to
   *   every key of this composition. Keys that are not present are skipped
   *   silently.
   * @throws Error when none of the selected keys is present.
   */
  subcomposition(keys: Iterable<K> | ((k: K) => boolean)): Composition<K> {
    const selected: Iterable<K> =
      typeof keys === 'function' ? [...this.keys()].filter(keys) : keys;

    const picked = new Map<K, number>();
    for (const k of selected) {
      const p = this.proportions.get(k);
      if (p !== undefined) {
        picked.set(k, p);
      }
    }
    return new Composition([...picked]);
  }

  /**
   * Returns a composition containing only the strictly positive components.
   *
   * @throws Error when no component is strictly positive.
   */
  nonzero(): Composition<K> {
    return new Composition([...this].filter(([, v]) => v > 0));
  }

  /** Returns a fresh, mutable `Map` holding a copy of the proportions. */
  toMap(): Map<K, number> {
    return new Map(this.proportions);
  }

  /** Iterator over the keys in insertion order. */
  keys(): IterableIterator<K> {
    return this.proportions.keys();
  }

  /** Iterator over the proportions in insertion order. */
  values(): IterableIterator<number> {
    return this.proportions.values();
  }

  /**
   * Returns the keys with the largest proportions, largest first.
   *
   * @param n Maximum number of keys to return. Values larger than the number
   *   of components return every key; zero or negative values return an
   *   empty array.
   */
  topKeys(n: number): K[] {
    const ranked = [...this.proportions.entries()];
    ranked.sort(([, p1], [, p2]) => p2 - p1);
    // Clamp at zero: a negative `end` would make `slice` count from the back
    // of the array and return almost every key instead of none.
    return ranked.slice(0, Math.max(0, n)).map(([k]) => k);
  }

  /** Returns the proportion stored for `k`, or `undefined` when `k` is absent. */
  get(k: K): number | undefined {
    return this.proportions.get(k);
  }
}

/**
 * Factory that builds a `Composition` from either accepted input form by
 * first collecting the data into a `Mixture`.
 *
 * @param values An array of `[key, number]` pairs, or (for string keys) a
 *   plain object of numbers.
 * @returns The composition produced by the intermediate mixture.
 */
export function composition<K>(values: MixtureData<K>): Composition<K> {
  const collected = mixture(values);
  return collected.composition();
}
