Classifier: improved context

pull/191/head
M66B 4 years ago
parent f9c09da505
commit d6dfe44a9c

@ -24,6 +24,7 @@ import android.content.SharedPreferences;
import android.os.Build; import android.os.Build;
import android.text.TextUtils; import android.text.TextUtils;
import androidx.annotation.NonNull;
import androidx.preference.PreferenceManager; import androidx.preference.PreferenceManager;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
@ -63,48 +64,11 @@ public class MessageClassifier {
if (target != null && !canClassify(target.type)) if (target != null && !canClassify(target.type))
return; return;
File file = message.getFile(context);
if (!file.exists())
return;
long start = new Date().getTime(); long start = new Date().getTime();
// Build text to classify // Build text to classify
StringBuilder sb = new StringBuilder(); List<String> texts = getTexts(message, context);
if (texts.size() == 0)
List<Address> addresses = new ArrayList<>();
if (message.from != null)
addresses.addAll(Arrays.asList(message.from));
if (message.to != null)
addresses.addAll(Arrays.asList(message.to));
if (message.cc != null)
addresses.addAll(Arrays.asList(message.cc));
if (message.bcc != null)
addresses.addAll(Arrays.asList(message.bcc));
if (message.reply != null)
addresses.addAll(Arrays.asList(message.reply));
for (Address address : addresses) {
String email = ((InternetAddress) address).getAddress();
String name = ((InternetAddress) address).getAddress();
if (!TextUtils.isEmpty(email)) {
sb.append(email).append('\n');
int at = email.indexOf('@');
String domain = (at < 0 ? null : email.substring(at + 1));
if (!TextUtils.isEmpty(domain))
sb.append(domain).append('\n');
}
if (!TextUtils.isEmpty(name))
sb.append(name).append('\n');
}
if (message.subject != null)
sb.append(message.subject).append('\n');
String text = HtmlHelper.getFullText(file);
sb.append(text);
if (sb.length() == 0)
return; return;
// Load data if needed // Load data if needed
@ -116,8 +80,8 @@ public class MessageClassifier {
if (!wordClassFrequency.containsKey(folder.account)) if (!wordClassFrequency.containsKey(folder.account))
wordClassFrequency.put(folder.account, new HashMap<>()); wordClassFrequency.put(folder.account, new HashMap<>());
// Classify text // Classify texts
String classified = classify(folder.account, folder.name, sb.toString(), target == null, context); String classified = classify(folder.account, folder.name, texts, target == null, context);
long elapsed = new Date().getTime() - start; long elapsed = new Date().getTime() - start;
EntityLog.log(context, "Classifier" + EntityLog.log(context, "Classifier" +
@ -139,7 +103,7 @@ public class MessageClassifier {
dirty = true; dirty = true;
// Auto classify // Auto classify message
if (classified != null && if (classified != null &&
!classified.equals(folder.name) && !classified.equals(folder.name) &&
!message.auto_classified && !message.auto_classified &&
@ -165,47 +129,90 @@ public class MessageClassifier {
} }
} }
private static String classify(long account, String currentClass, String text, boolean added, Context context) { @NonNull
int maxMessages = 0; private static List<String> getTexts(@NonNull EntityMessage message, @NonNull Context context) throws IOException {
for (String clazz : classMessages.get(account).keySet()) { List<String> texts = new ArrayList<>();
int count = classMessages.get(account).get(clazz);
if (count > maxMessages) File file = message.getFile(context);
maxMessages = count; if (!file.exists())
return texts;
List<Address> addresses = new ArrayList<>();
if (message.from != null)
addresses.addAll(Arrays.asList(message.from));
if (message.to != null)
addresses.addAll(Arrays.asList(message.to));
if (message.cc != null)
addresses.addAll(Arrays.asList(message.cc));
if (message.bcc != null)
addresses.addAll(Arrays.asList(message.bcc));
if (message.reply != null)
addresses.addAll(Arrays.asList(message.reply));
for (Address address : addresses) {
String email = ((InternetAddress) address).getAddress();
String name = ((InternetAddress) address).getPersonal();
if (!TextUtils.isEmpty(email))
texts.add(email);
if (!TextUtils.isEmpty(name))
texts.add(name);
} }
if (message.subject != null)
texts.add(message.subject);
String text = HtmlHelper.getFullText(file);
texts.add(text);
return texts;
}
private static String classify(long account, @NonNull String currentClass, @NonNull List<String> texts, boolean added, @NonNull Context context) {
State state = new State(); State state = new State();
// First word Log.i("Classifier texts=" + texts.size());
process(account, currentClass, added, null, state); for (String text : texts) {
// First word
// Process words processWord(account, added, null, state);
if (Build.VERSION.SDK_INT < Build.VERSION_CODES.N) {
java.text.BreakIterator boundary = java.text.BreakIterator.getWordInstance(); // Process words
boundary.setText(text); if (Build.VERSION.SDK_INT < Build.VERSION_CODES.N) {
int start = boundary.first(); java.text.BreakIterator boundary = java.text.BreakIterator.getWordInstance();
for (int end = boundary.next(); end != java.text.BreakIterator.DONE; end = boundary.next()) { boundary.setText(text);
String word = text.substring(start, end); int start = boundary.first();
process(account, currentClass, added, word, state); for (int end = boundary.next(); end != java.text.BreakIterator.DONE; end = boundary.next()) {
start = end; String word = text.substring(start, end);
} processWord(account, added, word, state);
} else { start = end;
// The ICU break iterator works better for Chinese texts }
android.icu.text.BreakIterator boundary = android.icu.text.BreakIterator.getWordInstance(); } else {
boundary.setText(text); // The ICU break iterator works better for Chinese texts
int start = boundary.first(); android.icu.text.BreakIterator boundary = android.icu.text.BreakIterator.getWordInstance();
for (int end = boundary.next(); end != android.icu.text.BreakIterator.DONE; end = boundary.next()) { boundary.setText(text);
String word = text.substring(start, end); int start = boundary.first();
process(account, currentClass, added, word, state); for (int end = boundary.next(); end != android.icu.text.BreakIterator.DONE; end = boundary.next()) {
start = end; String word = text.substring(start, end);
processWord(account, added, word, state);
start = end;
}
} }
} }
// Last word // final word
process(account, currentClass, added, null, state); processWord(account, added, null, state);
updateFrequencies(account, currentClass, added, state);
if (!added) if (!added)
return null; return null;
int maxMessages = 0;
for (String clazz : classMessages.get(account).keySet()) {
int count = classMessages.get(account).get(clazz);
if (count > maxMessages)
maxMessages = count;
}
if (maxMessages == 0) { if (maxMessages == 0) {
Log.i("Classifier no messages account=" + account); Log.i("Classifier no messages account=" + account);
return null; return null;
@ -213,7 +220,7 @@ public class MessageClassifier {
// Calculate chance per class // Calculate chance per class
DB db = DB.getInstance(context); DB db = DB.getInstance(context);
int words = state.words.size() - 2; int words = state.words.size() - texts.size() - 1;
List<Chance> chances = new ArrayList<>(); List<Chance> chances = new ArrayList<>();
for (String clazz : state.classStats.keySet()) { for (String clazz : state.classStats.keySet()) {
EntityFolder folder = db.folder().getFolderByName(account, clazz); EntityFolder folder = db.folder().getFolderByName(account, clazz);
@ -234,7 +241,7 @@ public class MessageClassifier {
} }
if (BuildConfig.DEBUG) if (BuildConfig.DEBUG)
Log.i("Classifier words=" + TextUtils.join(", ", state.words)); Log.i("Classifier words=" + state.words.size() + " " + TextUtils.join(", ", state.words));
if (chances.size() <= 1) if (chances.size() <= 1)
return null; return null;
@ -268,17 +275,20 @@ public class MessageClassifier {
return classification; return classification;
} }
private static void process(long account, String currentClass, boolean added, String word, State state) { private static void processWord(long account, boolean added, String word, State state) {
if (word != null) { if (word != null) {
word = word.trim().toLowerCase(); word = word.trim().toLowerCase();
if (word.length() < 2 || word.matches(".*\\d.*"))
if (word.length() < 2 ||
state.words.contains(word) ||
word.matches(".*\\d.*"))
return; return;
} }
state.words.add(word); if (word != null ||
state.words.size() == 0 ||
state.words.get(state.words.size() - 1) != null)
state.words.add(word);
if (!added)
return;
if (state.words.size() < 3) if (state.words.size() < 3)
return; return;
@ -287,50 +297,74 @@ public class MessageClassifier {
String current = state.words.get(state.words.size() - 2); String current = state.words.get(state.words.size() - 2);
String after = state.words.get(state.words.size() - 1); String after = state.words.get(state.words.size() - 1);
if (current == null)
return;
Map<String, Frequency> classFrequency = wordClassFrequency.get(account).get(current); Map<String, Frequency> classFrequency = wordClassFrequency.get(account).get(current);
if (added) { if (classFrequency == null)
if (classFrequency == null) { return;
classFrequency = new HashMap<>();
wordClassFrequency.get(account).put(current, classFrequency); for (String clazz : classFrequency.keySet()) {
Frequency frequency = classFrequency.get(clazz);
if (frequency.count <= 0)
continue;
Stat stat = state.classStats.get(clazz);
if (stat == null) {
stat = new Stat();
state.classStats.put(clazz, stat);
} }
for (String clazz : classFrequency.keySet()) { int c = (frequency.count - frequency.duplicates);
Frequency frequency = classFrequency.get(clazz); Integer b = (before == null ? null : frequency.before.get(before));
if (frequency.count > 0) { Integer a = (after == null ? null : frequency.after.get(after));
Stat stat = state.classStats.get(clazz); double f = (c +
if (stat == null) { (b == null ? 2 * c : 2.0 * b / frequency.count * c) +
stat = new Stat(); (a == null ? 2 * c : 2.0 * a / frequency.count * c)) / 5.0;
state.classStats.put(clazz, stat); //Log.i("Classifier " +
} // before + "/" + b + "/" + frequency.before.get(before) + " " +
// after + "/" + a + "/" + frequency.after.get(after) + " " +
// current + "/" + c + "=" + frequency.count + "-" + frequency.duplicates +
// " f=" + f);
stat.totalFrequency += f;
stat.matchedWords++;
if (BuildConfig.DEBUG)
stat.words.add(current + "=" + f);
}
}
int c = frequency.count; private static void updateFrequencies(long account, @NonNull String clazz, boolean added, @NonNull State state) {
Integer b = (before == null ? null : frequency.before.get(before)); for (int i = 1; i < state.words.size() - 1; i++) {
Integer a = (after == null ? null : frequency.after.get(after)); String before = state.words.get(i - 1);
double f = ((b == null ? 0 : 2 * b) + c + (a == null ? 0 : 2 * a)) / 5.0; String current = state.words.get(i);
stat.totalFrequency += f; String after = state.words.get(i + 1);
stat.matchedWords++; if (current == null)
if (stat.matchedWords > state.maxMatchedWords) continue;
state.maxMatchedWords = stat.matchedWords;
if (BuildConfig.DEBUG) Map<String, Frequency> classFrequency = wordClassFrequency.get(account).get(current);
stat.words.add(current); if (added) {
if (classFrequency == null) {
classFrequency = new HashMap<>();
wordClassFrequency.get(account).put(current, classFrequency);
} }
Frequency c = classFrequency.get(clazz);
if (c == null) {
c = new Frequency();
classFrequency.put(clazz, c);
}
c.add(before, after, 1, state.words.indexOf(current) < i);
} else {
Frequency c = (classFrequency == null ? null : classFrequency.get(clazz));
if (c != null)
c.add(before, after, -1, state.words.indexOf(current) < i);
} }
Frequency c = classFrequency.get(currentClass);
if (c == null)
c = new Frequency();
c.add(before, after, 1);
classFrequency.put(currentClass, c);
} else {
Frequency c = (classFrequency == null ? null : classFrequency.get(currentClass));
if (c != null)
c.add(before, after, -1);
} }
} }
static synchronized void save(Context context) throws JSONException, IOException { static synchronized void save(@NonNull Context context) throws JSONException, IOException {
if (!dirty) if (!dirty)
return; return;
@ -341,7 +375,7 @@ public class MessageClassifier {
Log.i("Classifier data saved"); Log.i("Classifier data saved");
} }
private static synchronized void load(Context context) throws IOException, JSONException { private static synchronized void load(@NonNull Context context) throws IOException, JSONException {
if (loaded || dirty) if (loaded || dirty)
return; return;
@ -357,27 +391,28 @@ public class MessageClassifier {
Log.i("Classifier data loaded"); Log.i("Classifier data loaded");
} }
static synchronized void clear(Context context) { static synchronized void clear(@NonNull Context context) {
wordClassFrequency.clear(); wordClassFrequency.clear();
dirty = true; dirty = true;
Log.i("Classifier data cleared"); Log.i("Classifier data cleared");
} }
static boolean isEnabled(Context context) { static boolean isEnabled(@NonNull Context context) {
SharedPreferences prefs = PreferenceManager.getDefaultSharedPreferences(context); SharedPreferences prefs = PreferenceManager.getDefaultSharedPreferences(context);
return prefs.getBoolean("classification", false); return prefs.getBoolean("classification", false);
} }
static boolean canClassify(String folderType) { static boolean canClassify(@NonNull String folderType) {
return EntityFolder.INBOX.equals(folderType) || return EntityFolder.INBOX.equals(folderType) ||
EntityFolder.JUNK.equals(folderType) || EntityFolder.JUNK.equals(folderType) ||
EntityFolder.USER.equals(folderType); EntityFolder.USER.equals(folderType);
} }
static File getFile(Context context) { static File getFile(@NonNull Context context) {
return new File(context.getFilesDir(), "classifier.json"); return new File(context.getFilesDir(), "classifier.json");
} }
@NonNull
static JSONObject toJson() throws JSONException { static JSONObject toJson() throws JSONException {
JSONArray jmessages = new JSONArray(); JSONArray jmessages = new JSONArray();
for (Long account : classMessages.keySet()) for (Long account : classMessages.keySet())
@ -399,7 +434,8 @@ public class MessageClassifier {
jword.put("account", account); jword.put("account", account);
jword.put("word", word); jword.put("word", word);
jword.put("class", clazz); jword.put("class", clazz);
jword.put("frequency", f.count); jword.put("count", f.count);
jword.put("dup", f.duplicates);
jword.put("before", from(f.before)); jword.put("before", from(f.before));
jword.put("after", from(f.after)); jword.put("after", from(f.after));
jwords.put(jword); jwords.put(jword);
@ -407,23 +443,24 @@ public class MessageClassifier {
} }
JSONObject jroot = new JSONObject(); JSONObject jroot = new JSONObject();
jroot.put("version", 1); jroot.put("version", 2);
jroot.put("messages", jmessages); jroot.put("messages", jmessages);
jroot.put("words", jwords); jroot.put("words", jwords);
return jroot; return jroot;
} }
private static JSONObject from(Map<String, Integer> map) throws JSONException { @NonNull
private static JSONObject from(@NonNull Map<String, Integer> map) throws JSONException {
JSONObject jmap = new JSONObject(); JSONObject jmap = new JSONObject();
for (String key : map.keySet()) for (String key : map.keySet())
jmap.put(key, map.get(key)); jmap.put(key, map.get(key));
return jmap; return jmap;
} }
static void fromJson(JSONObject jroot) throws JSONException { static void fromJson(@NonNull JSONObject jroot) throws JSONException {
int version = jroot.optInt("version"); int version = jroot.optInt("version");
if (version < 1) if (version < 2)
return; return;
JSONArray jmessages = jroot.getJSONArray("messages"); JSONArray jmessages = jroot.getJSONArray("messages");
@ -443,23 +480,28 @@ public class MessageClassifier {
long account = jword.getLong("account"); long account = jword.getLong("account");
if (!wordClassFrequency.containsKey(account)) if (!wordClassFrequency.containsKey(account))
wordClassFrequency.put(account, new HashMap<>()); wordClassFrequency.put(account, new HashMap<>());
String word = jword.getString("word"); if (jword.has("word")) {
Map<String, Frequency> classFrequency = wordClassFrequency.get(account).get(word); String word = jword.getString("word");
if (classFrequency == null) { Map<String, Frequency> classFrequency = wordClassFrequency.get(account).get(word);
classFrequency = new HashMap<>(); if (classFrequency == null) {
wordClassFrequency.get(account).put(word, classFrequency); classFrequency = new HashMap<>();
} wordClassFrequency.get(account).put(word, classFrequency);
Frequency f = new Frequency(); }
f.count = jword.getInt("frequency"); Frequency f = new Frequency();
if (jword.has("before")) f.count = jword.getInt("count");
f.before = from(jword.getJSONObject("before")); f.duplicates = jword.optInt("dup");
if (jword.has("after")) if (jword.has("before"))
f.after = from(jword.getJSONObject("after")); f.before = from(jword.getJSONObject("before"));
classFrequency.put(jword.getString("class"), f); if (jword.has("after"))
f.after = from(jword.getJSONObject("after"));
classFrequency.put(jword.getString("class"), f);
} else
Log.w("No words account=" + account);
} }
} }
private static Map<String, Integer> from(JSONObject jmap) throws JSONException { @NonNull
private static Map<String, Integer> from(@NonNull JSONObject jmap) throws JSONException {
Map<String, Integer> result = new HashMap<>(jmap.length()); Map<String, Integer> result = new HashMap<>(jmap.length());
Iterator<String> iterator = jmap.keys(); Iterator<String> iterator = jmap.keys();
while (iterator.hasNext()) { while (iterator.hasNext()) {
@ -470,22 +512,25 @@ public class MessageClassifier {
} }
private static class State { private static class State {
private int maxMatchedWords = 0; private final List<String> words = new ArrayList<>();
private List<String> words = new ArrayList<>(); private final Map<String, Stat> classStats = new HashMap<>();
private Map<String, Stat> classStats = new HashMap<>();
} }
private static class Frequency { private static class Frequency {
private int count = 0; private int count = 0;
private int duplicates = 0;
private Map<String, Integer> before = new HashMap<>(); private Map<String, Integer> before = new HashMap<>();
private Map<String, Integer> after = new HashMap<>(); private Map<String, Integer> after = new HashMap<>();
private void add(String b, String a, int c) { private void add(String b, String a, int c, boolean duplicate) {
if (count + c < 0) if (count + c < 0)
return; return;
count += c; count += c;
if (duplicate)
duplicates += c;
if (b != null) { if (b != null) {
Integer x = before.get(b); Integer x = before.get(b);
before.put(b, (x == null ? 0 : x) + c); before.put(b, (x == null ? 0 : x) + c);
@ -501,7 +546,7 @@ public class MessageClassifier {
private static class Stat { private static class Stat {
private int matchedWords = 0; private int matchedWords = 0;
private double totalFrequency = 0; private double totalFrequency = 0;
private List<String> words = new ArrayList<>(); private final List<String> words = new ArrayList<>();
} }
private static class Chance { private static class Chance {

Loading…
Cancel
Save