001    /****************************************************************
002     * Licensed to the Apache Software Foundation (ASF) under one   *
003     * or more contributor license agreements.  See the NOTICE file *
004     * distributed with this work for additional information        *
005     * regarding copyright ownership.  The ASF licenses this file   *
006     * to you under the Apache License, Version 2.0 (the            *
007     * "License"); you may not use this file except in compliance   *
008     * with the License.  You may obtain a copy of the License at   *
009     *                                                              *
010     *   http://www.apache.org/licenses/LICENSE-2.0                 *
011     *                                                              *
012     * Unless required by applicable law or agreed to in writing,   *
013     * software distributed under the License is distributed on an  *
014     * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY       *
015     * KIND, either express or implied.  See the License for the    *
016     * specific language governing permissions and limitations      *
017     * under the License.                                           *
018     ****************************************************************/
019    
020    package org.apache.james.mime4j.codec;
021    
022    import java.io.IOException;
023    import java.io.InputStream;
024    
025    import org.apache.commons.logging.Log;
026    import org.apache.commons.logging.LogFactory;
027    
028    /**
029     * Performs Base-64 decoding on an underlying stream.
030     */
031    public class Base64InputStream extends InputStream {
032        private static Log log = LogFactory.getLog(Base64InputStream.class);
033    
034        private static final int ENCODED_BUFFER_SIZE = 1536;
035    
036        private static final int[] BASE64_DECODE = new int[256];
037    
038        static {
039            for (int i = 0; i < 256; i++)
040                BASE64_DECODE[i] = -1;
041            for (int i = 0; i < Base64OutputStream.BASE64_TABLE.length; i++)
042                BASE64_DECODE[Base64OutputStream.BASE64_TABLE[i] & 0xff] = i;
043        }
044    
045        private static final byte BASE64_PAD = '=';
046    
047        private static final int EOF = -1;
048    
049        private final byte[] singleByte = new byte[1];
050    
051        private boolean strict;
052    
053        private final InputStream in;
054        private boolean closed = false;
055    
056        private final byte[] encoded = new byte[ENCODED_BUFFER_SIZE];
057        private int position = 0; // current index into encoded buffer
058        private int size = 0; // current size of encoded buffer
059    
060        private final ByteQueue q = new ByteQueue();
061    
062        private boolean eof; // end of file or pad character reached
063    
064        public Base64InputStream(InputStream in) {
065            this(in, false);
066        }
067    
068        public Base64InputStream(InputStream in, boolean strict) {
069            if (in == null)
070                throw new IllegalArgumentException();
071    
072            this.in = in;
073            this.strict = strict;
074        }
075    
076        @Override
077        public int read() throws IOException {
078            if (closed)
079                throw new IOException("Base64InputStream has been closed");
080    
081            while (true) {
082                int bytes = read0(singleByte, 0, 1);
083                if (bytes == EOF)
084                    return EOF;
085    
086                if (bytes == 1)
087                    return singleByte[0] & 0xff;
088            }
089        }
090    
091        @Override
092        public int read(byte[] buffer) throws IOException {
093            if (closed)
094                throw new IOException("Base64InputStream has been closed");
095    
096            if (buffer == null)
097                throw new NullPointerException();
098    
099            if (buffer.length == 0)
100                return 0;
101    
102            return read0(buffer, 0, buffer.length);
103        }
104    
105        @Override
106        public int read(byte[] buffer, int offset, int length) throws IOException {
107            if (closed)
108                throw new IOException("Base64InputStream has been closed");
109    
110            if (buffer == null)
111                throw new NullPointerException();
112    
113            if (offset < 0 || length < 0 || offset + length > buffer.length)
114                throw new IndexOutOfBoundsException();
115    
116            if (length == 0)
117                return 0;
118    
119            return read0(buffer, offset, offset + length);
120        }
121    
122        @Override
123        public void close() throws IOException {
124            if (closed)
125                return;
126    
127            closed = true;
128        }
129    
130        private int read0(final byte[] buffer, final int from, final int to)
131                throws IOException {
132            int index = from; // index into given buffer
133    
134            // check if a previous invocation left decoded bytes in the queue
135    
136            int qCount = q.count();
137            while (qCount-- > 0 && index < to) {
138                buffer[index++] = q.dequeue();
139            }
140    
141            // eof or pad reached?
142    
143            if (eof)
144                return index == from ? EOF : index - from;
145    
146            // decode into given buffer
147    
148            int data = 0; // holds decoded data; up to four sextets
149            int sextets = 0; // number of sextets
150    
151            while (index < to) {
152                // make sure buffer not empty
153    
154                while (position == size) {
155                    int n = in.read(encoded, 0, encoded.length);
156                    if (n == EOF) {
157                        eof = true;
158    
159                        if (sextets != 0) {
160                            // error in encoded data
161                            handleUnexpectedEof(sextets);
162                        }
163    
164                        return index == from ? EOF : index - from;
165                    } else if (n > 0) {
166                        position = 0;
167                        size = n;
168                    } else {
169                        assert n == 0;
170                    }
171                }
172    
173                // decode buffer
174    
175                while (position < size && index < to) {
176                    int value = encoded[position++] & 0xff;
177    
178                    if (value == BASE64_PAD) {
179                        index = decodePad(data, sextets, buffer, index, to);
180                        return index - from;
181                    }
182    
183                    int decoded = BASE64_DECODE[value];
184                    if (decoded < 0) // -1: not a base64 char
185                        continue;
186    
187                    data = (data << 6) | decoded;
188                    sextets++;
189    
190                    if (sextets == 4) {
191                        sextets = 0;
192    
193                        byte b1 = (byte) (data >>> 16);
194                        byte b2 = (byte) (data >>> 8);
195                        byte b3 = (byte) data;
196    
197                        if (index < to - 2) {
198                            buffer[index++] = b1;
199                            buffer[index++] = b2;
200                            buffer[index++] = b3;
201                        } else {
202                            if (index < to - 1) {
203                                buffer[index++] = b1;
204                                buffer[index++] = b2;
205                                q.enqueue(b3);
206                            } else if (index < to) {
207                                buffer[index++] = b1;
208                                q.enqueue(b2);
209                                q.enqueue(b3);
210                            } else {
211                                q.enqueue(b1);
212                                q.enqueue(b2);
213                                q.enqueue(b3);
214                            }
215    
216                            assert index == to;
217                            return to - from;
218                        }
219                    }
220                }
221            }
222    
223            assert sextets == 0;
224            assert index == to;
225            return to - from;
226        }
227    
228        private int decodePad(int data, int sextets, final byte[] buffer,
229                int index, final int end) throws IOException {
230            eof = true;
231    
232            if (sextets == 2) {
233                // one byte encoded as "XY=="
234    
235                byte b = (byte) (data >>> 4);
236                if (index < end) {
237                    buffer[index++] = b;
238                } else {
239                    q.enqueue(b);
240                }
241            } else if (sextets == 3) {
242                // two bytes encoded as "XYZ="
243    
244                byte b1 = (byte) (data >>> 10);
245                byte b2 = (byte) ((data >>> 2) & 0xFF);
246    
247                if (index < end - 1) {
248                    buffer[index++] = b1;
249                    buffer[index++] = b2;
250                } else if (index < end) {
251                    buffer[index++] = b1;
252                    q.enqueue(b2);
253                } else {
254                    q.enqueue(b1);
255                    q.enqueue(b2);
256                }
257            } else {
258                // error in encoded data
259                handleUnexpecedPad(sextets);
260            }
261    
262            return index;
263        }
264    
265        private void handleUnexpectedEof(int sextets) throws IOException {
266            if (strict)
267                throw new IOException("unexpected end of file");
268            else
269                log.warn("unexpected end of file; dropping " + sextets
270                        + " sextet(s)");
271        }
272    
273        private void handleUnexpecedPad(int sextets) throws IOException {
274            if (strict)
275                throw new IOException("unexpected padding character");
276            else
277                log.warn("unexpected padding character; dropping " + sextets
278                        + " sextet(s)");
279        }
280    }