|
|
|
@ -109,6 +109,20 @@ public class OpenAI {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static double[] getEmbedding(Context context, String text, String model) throws JSONException, IOException {
|
|
|
|
|
// https://platform.openai.com/docs/api-reference/embeddings
|
|
|
|
|
JSONObject jrequest = new JSONObject();
|
|
|
|
|
jrequest.put("input", text);
|
|
|
|
|
jrequest.put("model", model == null ? "text-embedding-ada-002" : model);
|
|
|
|
|
JSONObject jresponse = call(context, "POST", "v1/embeddings", jrequest);
|
|
|
|
|
JSONObject jdata = jresponse.getJSONArray("data").getJSONObject(0);
|
|
|
|
|
JSONArray jembedding = jdata.getJSONArray("embedding");
|
|
|
|
|
double[] result = new double[jembedding.length()];
|
|
|
|
|
for (int i = 0; i < jembedding.length(); i++)
|
|
|
|
|
result[i] = jembedding.getDouble(i);
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static Message[] completeChat(Context context, String model, Message[] messages, Float temperature, int n) throws JSONException, IOException {
|
|
|
|
|
// https://platform.openai.com/docs/guides/chat/introduction
|
|
|
|
|
// https://platform.openai.com/docs/api-reference/chat/create
|
|
|
|
@ -238,4 +252,26 @@ public class OpenAI {
|
|
|
|
|
return this.role + ": " + this.content;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static class Embedding {
|
|
|
|
|
public static double getSimilarity(double[] v1, double[] v2) {
|
|
|
|
|
if (v1.length != v2.length)
|
|
|
|
|
throw new IllegalArgumentException("Invalid vector length=" + v1.length + "/" + v2.length);
|
|
|
|
|
double dotProduct = dotProduct(v1, v2);
|
|
|
|
|
double magV1 = magnitude(v1);
|
|
|
|
|
double magV2 = magnitude(v2);
|
|
|
|
|
return dotProduct / (magV1 * magV2);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private static double dotProduct(double[] v1, double[] v2) {
|
|
|
|
|
float val = 0;
|
|
|
|
|
for (int i = 0; i <= v1.length - 1; i++)
|
|
|
|
|
val += v1[i] * v2[i];
|
|
|
|
|
return val;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private static double magnitude(double[] v) {
|
|
|
|
|
return Math.sqrt(dotProduct(v, v));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|