/*
    scb.cc

    Substitution Cipher Breaker
    Copyright (C) 1997,1998 Robert Muth <muth@cs.arizona.edu>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; version 2 of June 1991.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program in form of the file COPYING; 
    if not, write to the 

    Free Software Foundation, Inc. <http://www.fsf.org>
    59 Temple Place, Suite 330, 
    Boston, MA 02111-1307  USA
*/



//=========================================================
//
//  Substitution Cipher Breaker
//
//=========================================================

#include <stdlib.h>
#include <unistd.h>

#include "text.hh"
#include "perm.hh"
#include "stat.hh"
#include "dict.hh"

// Column separator used when printing scores and the usage text.
// A typed constant rather than a macro, so it obeys normal scope
// rules and cannot silently rewrite any other identifier named `tab`.
static const char tab[] = "\t";


// Run-time configuration.  The initializer below holds the defaults;
// main() overwrites fields from the command-line flags noted per field.
// NOTE: usage() repeats these defaults in its help text - keep in sync.
struct
{
    int DebugFlag;   // -d: nonzero => progress/debug output on cerr
    int nRounds;     // -r: iterations of the randomized search loop in main()
    int Weight;      // -w: weight argument to DICT::EvalTextPerm
    int MaxLength;   // -m: max word length for DICT ctor and EvalTextPerm
    int LineLength;  // -l: line width passed to TEXT::PrintPerm
    int Penalty;     // -p: penalty argument to the STAT constructor
}
Global =
{ 0,100,2,9,75,-1000};


//=========================================================
// class score
//=========================================================

class SCORE
{
public:
    PERM perm;    // the permutation
    int  stat1;    // monogram score
    int  stat3;   // trigram score
    int  dict;    // dict score


    void Update(TEXT *t, STAT *s, DICT *d)
    {
        stat1 = s->EvalTextPerm1(t,&perm);
        stat3 = s->EvalTextPerm3(t,&perm);
        dict  = d->EvalTextPerm(t,&perm,Global.Weight,Global.MaxLength);
    }
    
    void Print(ostream& o)
    {
        perm.Print(o);
        o << "Mono:" << tab << stat1 << tab <<
             "Trig:" << tab << stat3 << tab <<
             "Dict:" << dict << endl;
    }
    
};

//=========================================================
// class history
//=========================================================

//---------------------------------------------------------
// HISTORY records the three stages of one search attempt:
//   Original  - the starting permutation given to Init()
//   AfterStat - Original improved by trigram hill climbing
//   AfterDict - AfterStat improved by dictionary hill climbing
// Call order matters: Init(), ImproveTrigramms(), ImproveDictionary(),
// because each stage starts from the previous stage's perm and score.
//---------------------------------------------------------
class HISTORY
{
public:
    // Inputs used to score permutations.  Not owned: HISTORY never
    // deletes them, and the implicit copy shares the same pointers.
    TEXT *text;
    STAT *stat;
    DICT *dict;
    
    SCORE Original;     // starting point of this attempt
    SCORE AfterStat;    // result of ImproveTrigramms()
    SCORE AfterDict;    // result of ImproveDictionary()

    HISTORY(TEXT *t, STAT *s, DICT *d) : text(t),stat(s),dict(d) {}
            
    // Begin a new attempt from a copy of *p and compute its scores.
    void Init(PERM *p)
    {
        Original.perm = *p;
        Original.Update(text,stat,dict);
    }
    
    // First-improvement hill climb on the trigram score.
    // For every index pair (i,j) in 0..26 the swap is tried; it is kept
    // only if it strictly raises the trigram score, otherwise undone.
    // After an accepted swap the scan continues with the next i.
    // Passes repeat until a full pass accepts no swap.
    // Requires Init() first (reads Original.perm and Original.stat3).
    void ImproveTrigramms()
    {
        int i,j,change,best;

        AfterStat.perm = Original.perm;
        best = Original.stat3;
    
        do
        {
            change = 0;

            for( i=0; i<27; i++ )
                for( j=0; j<27; j++ )
                {
                    AfterStat.perm.Swap(i,j);
                    int cur = stat->EvalTextPerm3(text,&AfterStat.perm);
                    if( cur > best )
                    {
                        best = cur;
                        change = 1;
                        break;      // leaves only the j loop; i continues
                    }
                    else
                    {
                        AfterStat.perm.Swap(i,j);   // revert trial swap
                    }
                }
        }while( change );
        AfterStat.Update(text,stat,dict);   // refresh all three scores
    }
    
    // Best-improvement hill climb on the dictionary score, starting
    // from AfterStat.  For each i in 0..25 every j in i+1..26 is tried.
    // The j whose trial swap gives the highest score is then applied,
    // but only if that score exceeds the running best.  If none improves,
    // Swap(i,i) is applied; presumably a no-op, confirm in PERM.
    // Passes repeat until a pass leaves the best score unchanged.
    // Requires ImproveTrigramms() first (reads AfterStat.perm and .dict).
    void ImproveDictionary()
    {
        int i,j,oldValue,bestValue,curValue,bestSwap;

        AfterDict.perm = AfterStat.perm;
        bestValue = AfterStat.dict;
        
        do
        {
            oldValue = bestValue;
            
            for( i=0; i<26; i++ )
            {
                bestSwap = i;   // default: no swap for this i
                
                for( j=i+1; j<27; j++ )
                {
                    // Skip pairs involving image 0; per the original
                    // "ws" note this is presumably the whitespace symbol.
                    // NOTE(review): this tests AfterStat.perm, not the
                    // AfterDict.perm being modified.  That is equivalent
                    // only if Swap exchanges images, so the slot mapped
                    // to 0 never moves -- confirm against PERM.
                    if( AfterStat.perm.Translate(i) == 0 ||
                        AfterStat.perm.Translate(j) == 0)
                        continue;   // do not swap ws
                    
                    AfterDict.perm.Swap(i,j);   // trial swap

                    curValue = dict->EvalTextPerm(text,&AfterDict.perm,
                                                  Global.Weight,Global.MaxLength);

                    if( curValue > bestValue )
                    {
                        bestValue = curValue;
                        bestSwap = j;
                    }
                    
                    AfterDict.perm.Swap(i,j);   // undo trial swap
                }
                
                AfterDict.perm.Swap(i,bestSwap);   // commit best j (if any)
            }
        
        }while( oldValue != bestValue );
        AfterDict.Update(text,stat,dict);        
    }

    // For each stage, print a label, its scores, and the ciphertext
    // rendered by TEXT::PrintPerm with that stage's permutation.
    void Print(ostream& o)
    {
        o << "ORIGINAL  " << tab;
        Original.Print(o);
        text->PrintPerm(o,&Original.perm,Global.LineLength);
        o << "AFTER STAT" << tab;
        AfterStat.Print(o);
        text->PrintPerm(o,&AfterStat.perm,Global.LineLength);
        o << "AFTER DICT" << tab;
        AfterDict.Print(o);
        text->PrintPerm(o,&AfterDict.perm,Global.LineLength);
        o << endl;
    }
    
};

//=========================================================
// usage
//=========================================================

//---------------------------------------------------------
// Print the command-line synopsis and flag list to cerr.
// Returns -1 so main() can simply `return usage();`.
// The defaults shown are hard-coded on purpose: by the time usage()
// runs, Global may already hold values from flags parsed earlier.
// Keep them in sync with the Global initializer.
//---------------------------------------------------------
int usage()
{
    cerr << "usage: scb [flags] trigram dictionary ciphertext" << endl
         << "flags" << endl
         << tab << "-d" << tab << "Print Debug Info" << endl
         << tab << "-m maxlength" << tab
         << "Set Dictionary Word Length (default: 9)" << endl
         << tab << "-w weight" << tab
         << "Set Dictionary Weight (default: 2)" << endl
         << tab << "-p penalty" << tab
         << "Set Trigram Penalty (default: -1000)" << endl
         << tab << "-l linelength" << tab
         << "Set Line Length (default: 75)" << endl
         << tab << "-r rounds" << tab
         << "Try n times (default: 100)" << endl;

    return -1;
}

//=========================================================
// OptimizeFrequency
//=========================================================

//---------------------------------------------------------
// Greedy first-improvement hill climb of *p on the monogram
// score.  Every index pair (a,b) in 0..26 is trial-swapped.
// A swap that strictly raises the score is kept and the scan
// moves to the next a; any other swap is reverted.  Repeats
// until a whole pass keeps no swap.  *p is modified in place.
//---------------------------------------------------------
void OptimizeFrequency(PERM *p, TEXT *t, STAT *s)
{
    int bestScore = s->EvalTextPerm1(t,p);
    bool improved = true;

    while( improved )
    {
        improved = false;

        for( int a = 0; a < 27; a++ )
        {
            for( int b = 0; b < 27; b++ )
            {
                p->Swap(a,b);
                const int score = s->EvalTextPerm1(t,p);

                if( score <= bestScore )
                {
                    p->Swap(a,b);   // no gain: undo the trial swap
                    continue;
                }

                bestScore = score;  // keep it and advance to the next a
                improved = true;
                break;
            }
        }
    }
}

//=========================================================
// main
//=========================================================


int main(int argc, char *argv[])
{
    int i;
    
    //=========================================================
    // Parse flags
    //=========================================================
    // 'h' is in the option string so that "-h" reaches the usage case
    // below instead of being reported by getopt as an invalid option.
    // POSIX getopt() returns -1 (not EOF) once the options are exhausted.
    while( -1 != (i = getopt(argc,argv,"l:p:m:w:r:dh")) )
    {
        switch(i)
        {
        case 'd':
            Global.DebugFlag = 1;
            break;
        case 'w':
            Global.Weight = atoi(optarg);
            break;
        case 'm':
            Global.MaxLength = atoi(optarg);
            break;
        case 'l':
            Global.LineLength = atoi(optarg);
            break;
        case 'r':
            Global.nRounds = atoi(optarg);
            break;
        case 'p':
            Global.Penalty = atoi(optarg);
            break;
        case '?':
        case 'h':
        default:
            return usage();
        }
    }

    // Remaining positional arguments:
    //   argv[0] trigram file, argv[1] dictionary file, argv[2] ciphertext.
    argc -= optind;
    argv += optind;

    if( argc < 3 )
        return usage();
    
    //=========================================================
    // Read Files
    //=========================================================
        
    if( Global.DebugFlag )
        cerr << "Reading Trigrams from: " << argv[0] << endl;

    STAT s(argv[0],Global.Penalty);

    if( Global.DebugFlag )
        cerr << "Reading Dictionary from: " << argv[1] << endl;

    DICT d(argv[1],Global.MaxLength);

    if( Global.DebugFlag )
        cerr << "Reading Text from: " << argv[2] << endl;

    TEXT t(argv[2]);

    //=========================================================
    // Search Best Substitution
    //=========================================================
    // bestTrigram keeps the attempt with the highest trigram score
    // after the statistics stage.  bestDictionary keeps the attempt
    // with the highest dictionary score after the dictionary stage.

    HISTORY bestTrigram(&t,&s,&d);
    HISTORY bestDictionary(&t,&s,&d);
    HISTORY Current(&t,&s,&d);

    // Baseline: start from the monogram-optimized permutation.
    PERM p;
    OptimizeFrequency(&p,&t,&s);

    Current.Init(&p);
    Current.ImproveTrigramms();
    Current.ImproveDictionary();
    bestTrigram = Current;
    bestDictionary = Current;
    
    // Further attempts, each starting from a copy of p perturbed by
    // PERM::LocalRandomize (presumably a small random change).
    for( i=0; i<Global.nRounds; i++ )
    {
        PERM q = p;
        q.LocalRandomize();

        Current.Init(&q);
        Current.ImproveTrigramms();
        Current.ImproveDictionary();
        
        if( Current.AfterStat.stat3 > bestTrigram.AfterStat.stat3 )
        {
            bestTrigram = Current;
            
            if( Global.DebugFlag )
            {
                cerr << "ROUND: " << (i+1) << "/" << Global.nRounds << "  TRIG" << endl;
                bestTrigram.Print(cerr);
            }
        }

        if( Current.AfterDict.dict > bestDictionary.AfterDict.dict )
        {
            bestDictionary = Current;
            
            if( Global.DebugFlag )
            {
                cerr << "ROUND: " << (i+1) << "/" << Global.nRounds << "  DICT" << endl;
                bestDictionary.Print(cerr);
            }
        }
    }
    
    cout << "BEST TRIGRAM" << endl;
    bestTrigram.Print(cout);
    
    cout << "BEST DICTIONARY" << endl;
    bestDictionary.Print(cout);
    
    return 0;
}

//=========================================================
// eof
//=========================================================









