import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.security.PublicKey;
import java.security.InvalidKeyException;
import java.security.spec.InvalidKeySpecException;
import javax.swing.JOptionPane;
import javax.swing.JTextArea;
import javax.swing.JScrollPane;
import java.math.BigInteger;
import javax.crypto.Cipher;
import javax.crypto.KeyGenerator;
import javax.crypto.SecretKey;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.BadPaddingException;
import javax.crypto.IllegalBlockSizeException;
import java.security.KeyFactory;
import java.security.spec.RSAPublicKeySpec;
import java.io.ByteArrayOutputStream;
import java.io.BufferedReader;
import java.io.FileReader;

/**
 * Interactive demo of hybrid encryption.
 *
 * <p>A random AES secret key is generated once and wrapped (encrypted) with an RSA public key read
 * from the file {@code BillsRSAPublicKey}. For every file name entered by the user, an output file
 * {@code "Encrypted" + fileName} is written. It holds, all as upper-case hex text:
 * <ol>
 *   <li>the wrapped-key length as 4 bytes, least significant byte first;</li>
 *   <li>the wrapped key itself;</li>
 *   <li>the AES-encrypted file contents, 40 ciphertext bytes (80 hex digits) per line.</li>
 * </ol>
 * The plaintext and the encrypted text are then shown side by side in a dialog.
 */
public class JavaSecretKeyPublicKeyEncryption {

    /** Upper-case hex digits, indexed by nibble value. */
    private static final String HEX_DIGITS = "0123456789ABCDEF";

    /** Number of ciphertext bytes written per output line. */
    private static final int BYTES_PER_LINE = 40;

    /**
     * Runs the whole interactive session: builds the key material, then loops prompting for file
     * names until the user cancels the input dialog.
     */
    public JavaSecretKeyPublicKeyEncryption() throws IOException, NoSuchAlgorithmException, NoSuchPaddingException,
            InvalidKeySpecException,BadPaddingException, ClassNotFoundException, InvalidKeyException, IllegalBlockSizeException {

        String algorithmName = "AES";
        // NOTE(review): a bare "AES" transformation uses the provider's default mode and padding
        // (AES/ECB/PKCS5Padding with SunJCE). ECB leaks plaintext patterns. It is kept only so the
        // output stays compatible with the matching decryption side. Prefer AES/GCM with a per-file
        // IV for any new format.
        Cipher cipher = Cipher.getInstance(algorithmName);

        // Generate a random secret key; every file encrypted in this session shares it.
        KeyGenerator secretKeyGen = KeyGenerator.getInstance(algorithmName);
        secretKeyGen.init(new SecureRandom());
        SecretKey secretKey = secretKeyGen.generateKey();

        // This is an alternate method to generate a secret key
        //byte[] key = getKey();
        //SecretKeySpec secretKey = new SecretKeySpec(key,algorithmName);

        // Wrap the secret key with the recipient's RSA public key.
        PublicKey pubKey = readPublicKey("BillsRSAPublicKey");
        // NOTE(review): bare "RSA" means RSA/ECB/PKCS1Padding with SunJCE; OAEP padding would be
        // stronger but changes what the decryption side must use.
        Cipher keyCipher = Cipher.getInstance("RSA");
        keyCipher.init(Cipher.WRAP_MODE, pubKey);
        byte[] wrappedKey = keyCipher.wrap(secretKey);

        // The header (wrapped-key length followed by the wrapped key) is the same for every file,
        // so build it once outside the loop.
        ByteArrayOutputStream header = new ByteArrayOutputStream();
        int len = wrappedKey.length;
        for (int i = 0; i < 4; i++) { // 4-byte length, least significant byte first
            appendHex(header, (byte) (len & 0xFF));
            len >>>= 8;
        }
        for (byte b : wrappedKey) {
            appendHex(header, b);
        }

        while (true) {
            String fileName = JOptionPane.showInputDialog("Enter filename","GenerateRSAPublicPrivateKeyPair.java");
            if (fileName == null) {
                break; // user cancelled
            }

            // Read the input before creating the output, so a bad name neither leaves a partial
            // Encrypted* file behind nor aborts the whole session.
            String message;
            try {
                message = getFile(fileName);
            } catch (IOException e) {
                JOptionPane.showMessageDialog(null, "Cannot read " + fileName + ": " + e.getMessage(),
                        "Error", JOptionPane.ERROR_MESSAGE);
                continue;
            }

            String encryptedName = "Encrypted" + fileName;
            try (FileOutputStream out = new FileOutputStream(encryptedName)) {
                header.writeTo(out);
                cipher.init(Cipher.ENCRYPT_MODE, secretKey);
                // encryptFile closes 'out' itself; the second close by try-with-resources is a no-op.
                encryptFile(fileName, cipher, out);
            }
            String code = getFile(encryptedName);

            // Show input file and encrypted file
            String outString = "Message =\n" + message + "\n\nCoded message =\n"+code;
            JTextArea outArea = new JTextArea(outString,30,60);
            JOptionPane.showMessageDialog(null,new JScrollPane(outArea));
        }
    }

    /**
     * Entry point: runs the interactive session and then exits. The explicit exit presumably ends
     * any lingering Swing/AWT threads.
     */
    public static void main(String[] a) throws IOException, NoSuchAlgorithmException, NoSuchPaddingException,
            InvalidKeySpecException,BadPaddingException, ClassNotFoundException, InvalidKeyException, IllegalBlockSizeException {
        new JavaSecretKeyPublicKeyEncryption();
        System.exit(0);
    }

    /**
     * Encrypts a whole file with {@code cipher} and writes the ciphertext to {@code out} as
     * upper-case hex, with a newline after every 40 ciphertext bytes.
     *
     * @param inFileName file to encrypt; it is read completely into memory
     * @param cipher an already-initialized cipher; {@code doFinal} is applied to the entire file
     * @param out destination stream; it is always closed by this method, even on failure
     * @throws IOException if the input cannot be read or the output cannot be written
     * @throws IllegalBlockSizeException if the cipher rejects the input length
     * @throws BadPaddingException if the cipher reports a padding error
     */
    public void encryptFile(String inFileName, Cipher cipher, FileOutputStream out) throws IOException, NoSuchAlgorithmException, NoSuchPaddingException,
            BadPaddingException, ClassNotFoundException, InvalidKeyException, IllegalBlockSizeException  {
        try {
            byte[] messageBytes = readFileBytes(inFileName);
            byte[] outBytes = cipher.doFinal(messageBytes);

            // Build the hex text in memory and write it with one call rather than byte by byte.
            ByteArrayOutputStream hexText = new ByteArrayOutputStream();
            for (int i = 0; i < outBytes.length; i++) {
                appendHex(hexText, outBytes[i]);
                if ((i + 1) % BYTES_PER_LINE == 0) {
                    hexText.write('\n');
                }
            }
            hexText.writeTo(out);
        } finally {
            out.close();
        }
    }

    /**
     * Reads a file and returns its bytes as text, mapping each byte to the char with the same
     * value (ISO-8859-1 style).
     */
    String getFile(String fileName) throws IOException {
        byte[] bytes = readFileBytes(fileName);
        StringBuilder text = new StringBuilder(bytes.length);
        for (byte b : bytes) {
            // Mask before widening. A plain (char) cast of a negative byte would yield U+FF80..U+FFFF
            // instead of U+0080..U+00FF.
            text.append((char) (b & 0xFF));
        }
        return text.toString();
    }

    /** Returns the two-digit upper-case hex representation of {@code b}, e.g. {@code "0A"}. */
    public String convertToHex(byte b) {
        int value = b & 0xFF;
        return new String(new char[] {HEX_DIGITS.charAt(value >>> 4), HEX_DIGITS.charAt(value & 0x0F)});
    }

    /**
     * Prompts for a key size in bits and returns that many random bits as bytes, least significant
     * byte first. Returns {@code null} if the dialog is cancelled. Within this file it is referenced
     * only from commented-out code in the constructor.
     *
     * @throws NumberFormatException if the entered size is not an integer
     */
    public byte[] getKey() {
        String keySizeString = JOptionPane.showInputDialog("Enter Key Size","128");
        if (keySizeString == null) {
            return null;
        }
        int keySize = Integer.parseInt(keySizeString);
        // NOTE(review): a size that is not a multiple of 8 silently drops its extra bits.
        int numKeyBytes = keySize / 8;
        BigInteger key = new BigInteger(keySize, new SecureRandom()); // non-negative, < 2^keySize
        // NOTE(review): this prints secret key material to stdout; acceptable only in a demo.
        System.out.println("BigInteger secret key = " + key);
        byte[] keyBytes = new byte[numKeyBytes];
        for (int i = 0; i < numKeyBytes; i++) {
            keyBytes[i] = (byte) key.intValue(); // low-order 8 bits
            key = key.shiftRight(8);
        }
        return keyBytes;
    }

    /** Appends the two upper-case hex digits of {@code b} to {@code sink}. */
    private void appendHex(ByteArrayOutputStream sink, byte b) {
        String hex = convertToHex(b);
        sink.write(hex.charAt(0));
        sink.write(hex.charAt(1));
    }

    /** Reads the complete contents of a file, always closing it. */
    private static byte[] readFileBytes(String fileName) throws IOException {
        try (FileInputStream in = new FileInputStream(fileName)) {
            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
            byte[] chunk = new byte[8192];
            int count;
            while ((count = in.read(chunk)) != -1) {
                buffer.write(chunk, 0, count);
            }
            return buffer.toByteArray();
        }
    }

    /**
     * Reads an RSA public key stored as two decimal lines: the modulus, then the public exponent.
     */
    private static PublicKey readPublicKey(String keyFileName)
            throws IOException, NoSuchAlgorithmException, InvalidKeySpecException {
        try (BufferedReader keyIn = new BufferedReader(new FileReader(keyFileName))) {
            BigInteger modulus = parseKeyLine(keyIn.readLine(), keyFileName, "modulus");
            BigInteger publicExponent = parseKeyLine(keyIn.readLine(), keyFileName, "public exponent");
            RSAPublicKeySpec pubKeySpec = new RSAPublicKeySpec(modulus, publicExponent);
            return KeyFactory.getInstance("RSA").generatePublic(pubKeySpec);
        }
    }

    /** Parses one decimal key line, reporting a truncated key file instead of failing with an NPE. */
    private static BigInteger parseKeyLine(String line, String keyFileName, String what) throws IOException {
        if (line == null) {
            throw new IOException(keyFileName + " is missing the " + what + " line");
        }
        return new BigInteger(line.trim());
    }

} // end JavaSecretKeyPublicKeyEncryption
