OpenAI: added moderation

pull/212/head
M66B 2 years ago
parent 2dc5946802
commit bc2eda5038

@ -36,6 +36,10 @@ import java.io.IOException;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.ArrayList;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
public class OpenAI {
static final String URI_ENDPOINT = "https://api.openai.com/";
@ -82,12 +86,39 @@ public class OpenAI {
grants.getDouble("total_granted"));
}
static void checkModeration(Context context, String text) throws JSONException, IOException {
// https://platform.openai.com/docs/api-reference/moderations/create
JSONObject jrequest = new JSONObject();
jrequest.put("input", text);
JSONObject jresponse = call(context, "POST", "v1/moderations", jrequest);
JSONArray jresults = jresponse.getJSONArray("results");
for (int i = 0; i < jresults.length(); i++) {
JSONObject jresult = jresults.getJSONObject(i);
if (jresult.getBoolean("flagged")) {
List<String> violations = new ArrayList<>();
JSONObject jcategories = jresult.getJSONObject("categories");
JSONObject jcategory_scores = jresult.getJSONObject("category_scores");
Iterator<String> keys = jcategories.keys();
while (keys.hasNext()) {
String key = keys.next();
Object value = jcategories.get(key);
if (Boolean.TRUE.equals(value)) {
Double score = (jcategories.has(key) ? jcategory_scores.getDouble(key) : null);
violations.add(key + (score == null ? "" : ":" + Math.round(score * 100) + "%"));
}
}
throw new IllegalArgumentException(TextUtils.join(", ", violations));
}
}
}
static Message[] completeChat(Context context, String model, Message[] messages, Float temperature, int n) throws JSONException, IOException {
// https://platform.openai.com/docs/guides/chat/introduction
// https://platform.openai.com/docs/api-reference/chat/create
JSONArray jmessages = new JSONArray();
for (Message message : messages) {
checkModeration(context, message.content);
JSONObject jmessage = new JSONObject();
jmessage.put("role", message.role);
jmessage.put("content", message.content);
@ -121,6 +152,8 @@ public class OpenAI {
Uri uri = Uri.parse(URI_ENDPOINT).buildUpon().appendEncodedPath(path).build();
Log.i("OpenAI uri=" + uri);
long start = new Date().getTime();
URL url = new URL(uri.toString());
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
connection.setRequestMethod(method);
@ -161,6 +194,8 @@ public class OpenAI {
return new JSONObject(response);
} finally {
connection.disconnect();
long elapsed = new Date().getTime() - start;
Log.i("OpenAI elapsed=" + (elapsed / 1000f));
}
}

Loading…
Cancel
Save