Classifier: use frequency of word before/after

pull/191/head
M66B 4 years ago
parent 090f8d3a40
commit f7a58d9281

@ -39,7 +39,7 @@ import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
import java.util.Date; import java.util.Date;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -282,47 +282,37 @@ public class MessageClassifier {
for (String clazz : classFrequency.keySet()) { for (String clazz : classFrequency.keySet()) {
Frequency frequency = classFrequency.get(clazz); Frequency frequency = classFrequency.get(clazz);
if (frequency.count > 0) {
Stat stat = state.classStats.get(clazz);
if (stat == null) {
stat = new Stat();
state.classStats.put(clazz, stat);
}
Stat stat = state.classStats.get(clazz); int c = frequency.count;
if (stat == null) { Integer b = (before == null ? null : frequency.before.get(before));
stat = new Stat(); Integer a = (after == null ? null : frequency.after.get(after));
state.classStats.put(clazz, stat); stat.totalFrequency +=
} ((b == null ? 0.0 : (double) b / c) + c + (a == null ? 0.0 : (double) a / c)) / 3;
stat.matchedWords++;
boolean b = (before != null && frequency.before.contains(before));
boolean a = (after != null && frequency.after.contains(after));
if (b && a)
stat.totalFrequency += frequency.count;
else if (b || a)
stat.totalFrequency += frequency.count * 0.5;
else
stat.totalFrequency += frequency.count * 0.25;
if (BuildConfig.DEBUG) stat.matchedWords++;
stat.words.add(current); if (stat.matchedWords > state.maxMatchedWords)
state.maxMatchedWords = stat.matchedWords;
if (stat.matchedWords > state.maxMatchedWords) if (BuildConfig.DEBUG)
state.maxMatchedWords = stat.matchedWords; stat.words.add(current);
}
} }
Frequency c = classFrequency.get(currentClass); Frequency c = classFrequency.get(currentClass);
if (c == null) if (c == null)
c = new Frequency(); c = new Frequency();
c.count++; c.add(before, after, 1);
if (before != null && !c.before.contains(before))
c.before.add(before);
if (after != null && !c.after.contains(after))
c.after.add(after);
classFrequency.put(currentClass, c); classFrequency.put(currentClass, c);
} else { } else {
Frequency c = (classFrequency == null ? null : classFrequency.get(currentClass)); Frequency c = (classFrequency == null ? null : classFrequency.get(currentClass));
if (c != null) if (c != null)
if (c.count > 0) c.add(before, after, -1);
c.count--;
else
classFrequency.remove(currentClass);
} }
} }
@ -411,11 +401,11 @@ public class MessageClassifier {
return jroot; return jroot;
} }
private static JSONArray from(HashSet<String> list) { private static JSONObject from(Map<String, Integer> map) throws JSONException {
JSONArray jarray = new JSONArray(); JSONObject jmap = new JSONObject();
for (String item : list) for (String key : map.keySet())
jarray.put(item); jmap.put(key, map.get(key));
return jarray; return jmap;
} }
static void fromJson(JSONObject jroot) throws JSONException { static void fromJson(JSONObject jroot) throws JSONException {
@ -443,30 +433,50 @@ public class MessageClassifier {
Frequency f = new Frequency(); Frequency f = new Frequency();
f.count = jword.getInt("frequency"); f.count = jword.getInt("frequency");
if (jword.has("before")) if (jword.has("before"))
f.before = from(jword.getJSONArray("before")); f.before = from(jword.getJSONObject("before"));
if (jword.has("after")) if (jword.has("after"))
f.after = from(jword.getJSONArray("after")); f.after = from(jword.getJSONObject("after"));
classFrequency.put(jword.getString("class"), f); classFrequency.put(jword.getString("class"), f);
} }
} }
private static HashSet<String> from(JSONArray jarray) throws JSONException { private static Map<String, Integer> from(JSONObject jmap) throws JSONException {
HashSet<String> result = new HashSet<>(jarray.length()); Map<String, Integer> result = new HashMap<>(jmap.length());
for (int i = 0; i < jarray.length(); i++) Iterator<String> iterator = jmap.keys();
result.add((String) jarray.get(i)); while (iterator.hasNext()) {
String key = iterator.next();
result.put(key, jmap.getInt(key));
}
return result; return result;
} }
private static class State { private static class State {
int maxMatchedWords = 0; private int maxMatchedWords = 0;
List<String> words = new ArrayList<>(); private List<String> words = new ArrayList<>();
Map<String, Stat> classStats = new HashMap<>(); private Map<String, Stat> classStats = new HashMap<>();
} }
private static class Frequency { private static class Frequency {
int count; private int count = 0;
HashSet<String> before = new HashSet<>(); private Map<String, Integer> before = new HashMap<>();
HashSet<String> after = new HashSet<>(); private Map<String, Integer> after = new HashMap<>();
/**
 * Adjusts the occurrence count of this word and of its neighbouring words.
 *
 * The update is ignored entirely when it would make {@code count} negative.
 * Neighbour counts are kept strictly positive: an entry that would drop to
 * zero or below is removed instead of being stored, so un-training a
 * neighbour that was never recorded cannot leave a negative count behind.
 *
 * @param b the word preceding this one, or {@code null} if there is none
 * @param a the word following this one, or {@code null} if there is none
 * @param c the signed amount to add (visible callers pass 1 or -1)
 */
private void add(String b, String a, int c) {
    if (count + c < 0)
        return;

    count += c;
    adjust(before, b, c);
    adjust(after, a, c);
}

/**
 * Adds {@code delta} to the counter stored for {@code word}, dropping the
 * entry when the result is not positive. A {@code null} word is ignored.
 */
private static void adjust(Map<String, Integer> neighbours, String word, int delta) {
    if (word == null)
        return;

    Integer current = neighbours.get(word);
    int updated = (current == null ? 0 : current) + delta;

    // An absent entry is read as "never seen" by the lookup code, which is
    // exactly what a zero (or would-be negative) count means.
    if (updated > 0)
        neighbours.put(word, updated);
    else
        neighbours.remove(word);
}
} }
private static class Stat { private static class Stat {
@ -476,10 +486,10 @@ public class MessageClassifier {
} }
private static class Chance { private static class Chance {
String clazz; private String clazz;
Double chance; private Double chance;
Chance(String clazz, Double chance) { private Chance(String clazz, Double chance) {
this.clazz = clazz; this.clazz = clazz;
this.chance = chance; this.chance = chance;
} }

Loading…
Cancel
Save