Persist classifier model

pull/190/head
M66B 4 years ago
parent 1666326c59
commit 5fdc634b2c

@ -26,6 +26,9 @@ import android.text.TextUtils;
import androidx.preference.PreferenceManager; import androidx.preference.PreferenceManager;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
@ -38,6 +41,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
public class MessageClassifier { public class MessageClassifier {
private static boolean loaded = false;
private static Map<String, Integer> classMessages = new HashMap<>(); private static Map<String, Integer> classMessages = new HashMap<>();
private static Map<String, Map<String, Integer>> wordClassFrequency = new HashMap<>(); private static Map<String, Map<String, Integer>> wordClassFrequency = new HashMap<>();
@ -45,10 +49,15 @@ public class MessageClassifier {
private static final double CHANCE_THRESHOLD = 2.0; private static final double CHANCE_THRESHOLD = 2.0;
static String classify(EntityMessage message, boolean added, Context context) { static String classify(EntityMessage message, boolean added, Context context) {
SharedPreferences prefs = PreferenceManager.getDefaultSharedPreferences(context); if (!isEnabled(context))
if (!prefs.getBoolean("classify", BuildConfig.DEBUG))
return null; return null;
try {
load(context);
} catch (Throwable ex) {
Log.e(ex);
}
DB db = DB.getInstance(context); DB db = DB.getInstance(context);
EntityFolder folder = db.folder().getFolder(message.folder); EntityFolder folder = db.folder().getFolder(message.folder);
@ -198,6 +207,87 @@ public class MessageClassifier {
return classification; return classification;
} }
static synchronized void save(Context context) throws JSONException, IOException {
if (!isEnabled(context))
return;
JSONArray jmessages = new JSONArray();
for (String clazz : classMessages.keySet()) {
JSONObject jmessage = new JSONObject();
jmessage.put("class", clazz);
jmessage.put("count", classMessages.get(clazz));
jmessages.put(jmessage);
}
JSONArray jwords = new JSONArray();
for (String word : wordClassFrequency.keySet())
for (String clazz : wordClassFrequency.get(word).keySet()) {
JSONObject jword = new JSONObject();
jword.put("word", word);
jword.put("class", clazz);
jword.put("frequency", wordClassFrequency.get(word).get(clazz));
jwords.put(jword);
}
JSONObject jroot = new JSONObject();
jroot.put("messages", jmessages);
jroot.put("words", jwords);
File file = getFile(context);
Helper.writeText(file, jroot.toString(2));
loaded = false;
Log.i("Classifier saved classes=" + classMessages.size() + " words=" + wordClassFrequency.size());
}
private static synchronized void load(Context context) throws IOException, JSONException {
if (loaded)
return;
if (!isEnabled(context))
return;
classMessages.clear();
wordClassFrequency.clear();
File file = getFile(context);
if (file.exists()) {
String json = Helper.readText(file);
JSONObject jroot = new JSONObject(json);
JSONArray jmessages = jroot.getJSONArray("messages");
for (int m = 0; m < jmessages.length(); m++) {
JSONObject jmessage = (JSONObject) jmessages.get(m);
classMessages.put(jmessage.getString("class"), jmessage.getInt("count"));
}
JSONArray jwords = jroot.getJSONArray("words");
for (int w = 0; w < jwords.length(); w++) {
JSONObject jword = (JSONObject) jwords.get(w);
String word = jword.getString("word");
Map<String, Integer> classFrequency = wordClassFrequency.get(word);
if (classFrequency == null) {
classFrequency = new HashMap<>();
wordClassFrequency.put(word, classFrequency);
}
classFrequency.put(jword.getString("class"), jword.getInt("frequency"));
}
}
loaded = true;
Log.i("Classifier loaded classes=" + classMessages.size() + " words=" + wordClassFrequency.size());
}
private static boolean isEnabled(Context context) {
SharedPreferences prefs = PreferenceManager.getDefaultSharedPreferences(context);
return prefs.getBoolean("classify", BuildConfig.DEBUG);
}
private static File getFile(Context context) {
return new File(context.getFilesDir(), "classifier.json");
}
private static class Stat { private static class Stat {
int matchedWords = 0; int matchedWords = 0;
int totalFrequency = 0; int totalFrequency = 0;

@ -735,6 +735,17 @@ public class ServiceSynchronize extends ServiceBase implements SharedPreferences
liveAccountNetworkState.postDestroy(); liveAccountNetworkState.postDestroy();
executor.submit(new Runnable() {
@Override
public void run() {
try {
MessageClassifier.save(ServiceSynchronize.this);
} catch (Throwable ex) {
Log.e(ex);
}
}
});
TTSHelper.shutdown(); TTSHelper.shutdown();
try { try {

Loading…
Cancel
Save