From b075824fd76ad0fd9a55710cd1504036156b1f4e Mon Sep 17 00:00:00 2001 From: M66B Date: Fri, 10 Mar 2023 09:07:45 +0100 Subject: [PATCH] OpenAI: embeddings --- .../main/java/eu/faircode/email/OpenAI.java | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/app/src/main/java/eu/faircode/email/OpenAI.java b/app/src/main/java/eu/faircode/email/OpenAI.java index 4e42f004d4..5266f4c2a8 100644 --- a/app/src/main/java/eu/faircode/email/OpenAI.java +++ b/app/src/main/java/eu/faircode/email/OpenAI.java @@ -109,6 +109,20 @@ public class OpenAI { } } + static double[] getEmbedding(Context context, String text, String model) throws JSONException, IOException { + // https://platform.openai.com/docs/api-reference/embeddings + JSONObject jrequest = new JSONObject(); + jrequest.put("input", text); + jrequest.put("model", model == null ? "text-embedding-ada-002" : model); + JSONObject jresponse = call(context, "POST", "v1/embeddings", jrequest); + JSONObject jdata = jresponse.getJSONArray("data").getJSONObject(0); + JSONArray jembedding = jdata.getJSONArray("embedding"); + double[] result = new double[jembedding.length()]; + for (int i = 0; i < jembedding.length(); i++) + result[i] = jembedding.getDouble(i); + return result; + } + static Message[] completeChat(Context context, String model, Message[] messages, Float temperature, int n) throws JSONException, IOException { // https://platform.openai.com/docs/guides/chat/introduction // https://platform.openai.com/docs/api-reference/chat/create @@ -238,4 +252,26 @@ public class OpenAI { return this.role + ": " + this.content; } } + + static class Embedding { + public static double getSimilarity(double[] v1, double[] v2) { + if (v1.length != v2.length) + throw new IllegalArgumentException("Invalid vector length=" + v1.length + "/" + v2.length); + double dotProduct = dotProduct(v1, v2); + double magV1 = magnitude(v1); + double magV2 = magnitude(v2); + return dotProduct / (magV1 * magV2); + } + + private static double dotProduct(double[] v1, double[] v2) { + float val = 0; + for (int i = 0; i <= v1.length - 1; i++) + val += v1[i] * v2[i]; + return val; + } + + private static double magnitude(double[] v) { + return Math.sqrt(dotProduct(v, v)); + } + } }