diff --git a/app/src/main/java/eu/faircode/email/MessageClassifier.java b/app/src/main/java/eu/faircode/email/MessageClassifier.java index 29ecf5726a..cb3faf852c 100644 --- a/app/src/main/java/eu/faircode/email/MessageClassifier.java +++ b/app/src/main/java/eu/faircode/email/MessageClassifier.java @@ -24,6 +24,7 @@ import android.content.SharedPreferences; import android.os.Build; import android.text.TextUtils; +import androidx.annotation.NonNull; import androidx.preference.PreferenceManager; import org.jetbrains.annotations.NotNull; @@ -63,48 +64,11 @@ public class MessageClassifier { if (target != null && !canClassify(target.type)) return; - File file = message.getFile(context); - if (!file.exists()) - return; - long start = new Date().getTime(); // Build text to classify - StringBuilder sb = new StringBuilder(); - - List
addresses = new ArrayList<>(); - if (message.from != null) - addresses.addAll(Arrays.asList(message.from)); - if (message.to != null) - addresses.addAll(Arrays.asList(message.to)); - if (message.cc != null) - addresses.addAll(Arrays.asList(message.cc)); - if (message.bcc != null) - addresses.addAll(Arrays.asList(message.bcc)); - if (message.reply != null) - addresses.addAll(Arrays.asList(message.reply)); - - for (Address address : addresses) { - String email = ((InternetAddress) address).getAddress(); - String name = ((InternetAddress) address).getAddress(); - if (!TextUtils.isEmpty(email)) { - sb.append(email).append('\n'); - int at = email.indexOf('@'); - String domain = (at < 0 ? null : email.substring(at + 1)); - if (!TextUtils.isEmpty(domain)) - sb.append(domain).append('\n'); - } - if (!TextUtils.isEmpty(name)) - sb.append(name).append('\n'); - } - - if (message.subject != null) - sb.append(message.subject).append('\n'); - - String text = HtmlHelper.getFullText(file); - sb.append(text); - - if (sb.length() == 0) + List texts = getTexts(message, context); + if (texts.size() == 0) return; // Load data if needed @@ -116,8 +80,8 @@ public class MessageClassifier { if (!wordClassFrequency.containsKey(folder.account)) wordClassFrequency.put(folder.account, new HashMap<>()); - // Classify text - String classified = classify(folder.account, folder.name, sb.toString(), target == null, context); + // Classify texts + String classified = classify(folder.account, folder.name, texts, target == null, context); long elapsed = new Date().getTime() - start; EntityLog.log(context, "Classifier" + @@ -139,7 +103,7 @@ public class MessageClassifier { dirty = true; - // Auto classify + // Auto classify message if (classified != null && !classified.equals(folder.name) && !message.auto_classified && @@ -165,47 +129,90 @@ public class MessageClassifier { } } - private static String classify(long account, String currentClass, String text, boolean added, Context context) { - int maxMessages = 0; - for (String clazz : classMessages.get(account).keySet()) { - int count = classMessages.get(account).get(clazz); - if (count > maxMessages) - maxMessages = count; + @NonNull + private static List getTexts(@NonNull EntityMessage message, @NonNull Context context) throws IOException { + List texts = new ArrayList<>(); + + File file = message.getFile(context); + if (!file.exists()) + return texts; + + List
addresses = new ArrayList<>(); + if (message.from != null) + addresses.addAll(Arrays.asList(message.from)); + if (message.to != null) + addresses.addAll(Arrays.asList(message.to)); + if (message.cc != null) + addresses.addAll(Arrays.asList(message.cc)); + if (message.bcc != null) + addresses.addAll(Arrays.asList(message.bcc)); + if (message.reply != null) + addresses.addAll(Arrays.asList(message.reply)); + + for (Address address : addresses) { + String email = ((InternetAddress) address).getAddress(); + String name = ((InternetAddress) address).getPersonal(); + if (!TextUtils.isEmpty(email)) + texts.add(email); + if (!TextUtils.isEmpty(name)) + texts.add(name); } + if (message.subject != null) + texts.add(message.subject); + + String text = HtmlHelper.getFullText(file); + texts.add(text); + + return texts; + } + + private static String classify(long account, @NonNull String currentClass, @NonNull List texts, boolean added, @NonNull Context context) { State state = new State(); - // First word - process(account, currentClass, added, null, state); - - // Process words - if (Build.VERSION.SDK_INT < Build.VERSION_CODES.N) { - java.text.BreakIterator boundary = java.text.BreakIterator.getWordInstance(); - boundary.setText(text); - int start = boundary.first(); - for (int end = boundary.next(); end != java.text.BreakIterator.DONE; end = boundary.next()) { - String word = text.substring(start, end); - process(account, currentClass, added, word, state); - start = end; - } - } else { - // The ICU break iterator works better for Chinese texts - android.icu.text.BreakIterator boundary = android.icu.text.BreakIterator.getWordInstance(); - boundary.setText(text); - int start = boundary.first(); - for (int end = boundary.next(); end != android.icu.text.BreakIterator.DONE; end = boundary.next()) { - String word = text.substring(start, end); - process(account, currentClass, added, word, state); - start = end; + Log.i("Classifier texts=" + texts.size()); + for (String text : texts) { + // First word + processWord(account, added, null, state); + + // Process words + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.N) { + java.text.BreakIterator boundary = java.text.BreakIterator.getWordInstance(); + boundary.setText(text); + int start = boundary.first(); + for (int end = boundary.next(); end != java.text.BreakIterator.DONE; end = boundary.next()) { + String word = text.substring(start, end); + processWord(account, added, word, state); + start = end; + } + } else { + // The ICU break iterator works better for Chinese texts + android.icu.text.BreakIterator boundary = android.icu.text.BreakIterator.getWordInstance(); + boundary.setText(text); + int start = boundary.first(); + for (int end = boundary.next(); end != android.icu.text.BreakIterator.DONE; end = boundary.next()) { + String word = text.substring(start, end); + processWord(account, added, word, state); + start = end; + } } } - // Last word - process(account, currentClass, added, null, state); + // final word + processWord(account, added, null, state); + + updateFrequencies(account, currentClass, added, state); if (!added) return null; + int maxMessages = 0; + for (String clazz : classMessages.get(account).keySet()) { + int count = classMessages.get(account).get(clazz); + if (count > maxMessages) + maxMessages = count; + } + if (maxMessages == 0) { Log.i("Classifier no messages account=" + account); return null; @@ -213,7 +220,7 @@ public class MessageClassifier { // Calculate chance per class DB db = DB.getInstance(context); - int words = state.words.size() - 2; + int words = state.words.size() - texts.size() - 1; List chances = new ArrayList<>(); for (String clazz : state.classStats.keySet()) { EntityFolder folder = db.folder().getFolderByName(account, clazz); @@ -234,7 +241,7 @@ public class MessageClassifier { } if (BuildConfig.DEBUG) - Log.i("Classifier words=" + TextUtils.join(", ", state.words)); + Log.i("Classifier words=" + state.words.size() + " " + TextUtils.join(", ", state.words)); if (chances.size() <= 1) return null; @@ -268,17 +275,20 @@ public class MessageClassifier { return classification; } - private static void process(long account, String currentClass, boolean added, String word, State state) { + private static void processWord(long account, boolean added, String word, State state) { if (word != null) { word = word.trim().toLowerCase(); - - if (word.length() < 2 || - state.words.contains(word) || - word.matches(".*\\d.*")) + if (word.length() < 2 || word.matches(".*\\d.*")) return; } - state.words.add(word); + if (word != null || + state.words.size() == 0 || + state.words.get(state.words.size() - 1) != null) + state.words.add(word); + + if (!added) + return; if (state.words.size() < 3) return; @@ -287,50 +297,74 @@ public class MessageClassifier { String current = state.words.get(state.words.size() - 2); String after = state.words.get(state.words.size() - 1); + if (current == null) + return; + Map classFrequency = wordClassFrequency.get(account).get(current); - if (added) { - if (classFrequency == null) { - classFrequency = new HashMap<>(); - wordClassFrequency.get(account).put(current, classFrequency); + if (classFrequency == null) + return; + + for (String clazz : classFrequency.keySet()) { + Frequency frequency = classFrequency.get(clazz); + if (frequency.count <= 0) + continue; + + Stat stat = state.classStats.get(clazz); + if (stat == null) { + stat = new Stat(); + state.classStats.put(clazz, stat); } - for (String clazz : classFrequency.keySet()) { - Frequency frequency = classFrequency.get(clazz); - if (frequency.count > 0) { - Stat stat = state.classStats.get(clazz); - if (stat == null) { - stat = new Stat(); - state.classStats.put(clazz, stat); - } + int c = (frequency.count - frequency.duplicates); + Integer b = (before == null ? null : frequency.before.get(before)); + Integer a = (after == null ? null : frequency.after.get(after)); + double f = (c + + (b == null ? 2 * c : 2.0 * b / frequency.count * c) + + (a == null ? 2 * c : 2.0 * a / frequency.count * c)) / 5.0; + //Log.i("Classifier " + + // before + "/" + b + "/" + frequency.before.get(before) + " " + + // after + "/" + a + "/" + frequency.after.get(after) + " " + + // current + "/" + c + "=" + frequency.count + "-" + frequency.duplicates + + // " f=" + f); + + stat.totalFrequency += f; + stat.matchedWords++; + + if (BuildConfig.DEBUG) + stat.words.add(current + "=" + f); + } + } - int c = frequency.count; - Integer b = (before == null ? null : frequency.before.get(before)); - Integer a = (after == null ? null : frequency.after.get(after)); - double f = ((b == null ? 0 : 2 * b) + c + (a == null ? 0 : 2 * a)) / 5.0; - stat.totalFrequency += f; + private static void updateFrequencies(long account, @NonNull String clazz, boolean added, @NonNull State state) { + for (int i = 1; i < state.words.size() - 1; i++) { + String before = state.words.get(i - 1); + String current = state.words.get(i); + String after = state.words.get(i + 1); - stat.matchedWords++; - if (stat.matchedWords > state.maxMatchedWords) - state.maxMatchedWords = stat.matchedWords; + if (current == null) + continue; - if (BuildConfig.DEBUG) - stat.words.add(current); + Map classFrequency = wordClassFrequency.get(account).get(current); + if (added) { + if (classFrequency == null) { + classFrequency = new HashMap<>(); + wordClassFrequency.get(account).put(current, classFrequency); } + Frequency c = classFrequency.get(clazz); + if (c == null) { + c = new Frequency(); + classFrequency.put(clazz, c); + } + c.add(before, after, 1, state.words.indexOf(current) < i); + } else { + Frequency c = (classFrequency == null ? null : classFrequency.get(clazz)); + if (c != null) + c.add(before, after, -1, state.words.indexOf(current) < i); } - - Frequency c = classFrequency.get(currentClass); - if (c == null) - c = new Frequency(); - c.add(before, after, 1); - classFrequency.put(currentClass, c); - } else { - Frequency c = (classFrequency == null ? null : classFrequency.get(currentClass)); - if (c != null) - c.add(before, after, -1); } } - static synchronized void save(Context context) throws JSONException, IOException { + static synchronized void save(@NonNull Context context) throws JSONException, IOException { if (!dirty) return; @@ -341,7 +375,7 @@ public class MessageClassifier { Log.i("Classifier data saved"); } - private static synchronized void load(Context context) throws IOException, JSONException { + private static synchronized void load(@NonNull Context context) throws IOException, JSONException { if (loaded || dirty) return; @@ -357,27 +391,28 @@ public class MessageClassifier { Log.i("Classifier data loaded"); } - static synchronized void clear(Context context) { + static synchronized void clear(@NonNull Context context) { wordClassFrequency.clear(); dirty = true; Log.i("Classifier data cleared"); } - static boolean isEnabled(Context context) { + static boolean isEnabled(@NonNull Context context) { SharedPreferences prefs = PreferenceManager.getDefaultSharedPreferences(context); return prefs.getBoolean("classification", false); } - static boolean canClassify(String folderType) { + static boolean canClassify(@NonNull String folderType) { return EntityFolder.INBOX.equals(folderType) || EntityFolder.JUNK.equals(folderType) || EntityFolder.USER.equals(folderType); } - static File getFile(Context context) { + static File getFile(@NonNull Context context) { return new File(context.getFilesDir(), "classifier.json"); } + @NonNull static JSONObject toJson() throws JSONException { JSONArray jmessages = new JSONArray(); for (Long account : classMessages.keySet()) @@ -399,7 +434,8 @@ public class MessageClassifier { jword.put("account", account); jword.put("word", word); jword.put("class", clazz); - jword.put("frequency", f.count); + jword.put("count", f.count); + jword.put("dup", f.duplicates); jword.put("before", from(f.before)); jword.put("after", from(f.after)); jwords.put(jword); @@ -407,23 +443,24 @@ public class MessageClassifier { } JSONObject jroot = new JSONObject(); - jroot.put("version", 1); + jroot.put("version", 2); jroot.put("messages", jmessages); jroot.put("words", jwords); return jroot; } - private static JSONObject from(Map map) throws JSONException { + @NonNull + private static JSONObject from(@NonNull Map map) throws JSONException { JSONObject jmap = new JSONObject(); for (String key : map.keySet()) jmap.put(key, map.get(key)); return jmap; } - static void fromJson(JSONObject jroot) throws JSONException { + static void fromJson(@NonNull JSONObject jroot) throws JSONException { int version = jroot.optInt("version"); - if (version < 1) + if (version < 2) return; JSONArray jmessages = jroot.getJSONArray("messages"); @@ -443,23 +480,28 @@ public class MessageClassifier { long account = jword.getLong("account"); if (!wordClassFrequency.containsKey(account)) wordClassFrequency.put(account, new HashMap<>()); - String word = jword.getString("word"); - Map classFrequency = wordClassFrequency.get(account).get(word); - if (classFrequency == null) { - classFrequency = new HashMap<>(); - wordClassFrequency.get(account).put(word, classFrequency); - } - Frequency f = new Frequency(); - f.count = jword.getInt("frequency"); - if (jword.has("before")) - f.before = from(jword.getJSONObject("before")); - if (jword.has("after")) - f.after = from(jword.getJSONObject("after")); - classFrequency.put(jword.getString("class"), f); + if (jword.has("word")) { + String word = jword.getString("word"); + Map classFrequency = wordClassFrequency.get(account).get(word); + if (classFrequency == null) { + classFrequency = new HashMap<>(); + wordClassFrequency.get(account).put(word, classFrequency); + } + Frequency f = new Frequency(); + f.count = jword.getInt("count"); + f.duplicates = jword.optInt("dup"); + if (jword.has("before")) + f.before = from(jword.getJSONObject("before")); + if (jword.has("after")) + f.after = from(jword.getJSONObject("after")); + classFrequency.put(jword.getString("class"), f); + } else + Log.w("No words account=" + account); } } - private static Map from(JSONObject jmap) throws JSONException { + @NonNull + private static Map from(@NonNull JSONObject jmap) throws JSONException { Map result = new HashMap<>(jmap.length()); Iterator iterator = jmap.keys(); while (iterator.hasNext()) { @@ -470,22 +512,25 @@ public class MessageClassifier { } private static class State { - private int maxMatchedWords = 0; - private List words = new ArrayList<>(); - private Map classStats = new HashMap<>(); + private final List words = new ArrayList<>(); + private final Map classStats = new HashMap<>(); } private static class Frequency { private int count = 0; + private int duplicates = 0; private Map before = new HashMap<>(); private Map after = new HashMap<>(); - private void add(String b, String a, int c) { + private void add(String b, String a, int c, boolean duplicate) { if (count + c < 0) return; count += c; + if (duplicate) + duplicates += c; + if (b != null) { Integer x = before.get(b); before.put(b, (x == null ? 0 : x) + c); @@ -501,7 +546,7 @@ public class MessageClassifier { private static class Stat { private int matchedWords = 0; private double totalFrequency = 0; - private List words = new ArrayList<>(); + private final List words = new ArrayList<>(); } private static class Chance {