package eu.faircode.email; /* This file is part of FairEmail. FairEmail is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. FairEmail is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with FairEmail. If not, see . Copyright 2018-2024 by Marcel Bokhorst (M66B) */ import android.content.Context; import android.content.SharedPreferences; import android.graphics.Bitmap; import android.net.Uri; import android.text.Spannable; import android.text.TextUtils; import androidx.annotation.NonNull; import androidx.preference.PreferenceManager; import org.json.JSONArray; import org.json.JSONException; import org.json.JSONObject; import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.net.HttpURLConnection; import java.net.URL; import java.util.ArrayList; import java.util.Date; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; public class OpenAI { static final String DEFAULT_MODEL = "gpt-4o"; static final float DEFAULT_TEMPERATURE = 0.5f; static final String DEFAULT_SUMMARY_PROMPT = "Summarize the following text:"; static final String ASSISTANT = "assistant"; static final String USER = "user"; // https://cookbook.openai.com/examples/gpt4o/introduction_to_gpt4o static final String CONTENT_TEXT = "text"; static final String CONTENT_IMAGE = "image_url"; private static final int MAX_OPENAI_LEN = 1000; // characters private static final int TIMEOUT = 45; // seconds private static final int SCALE2PIXELS = 1440; // medium static boolean isAvailable(Context context) { if (TextUtils.isEmpty(BuildConfig.OPENAI_ENDPOINT)) return false; SharedPreferences prefs = PreferenceManager.getDefaultSharedPreferences(context); boolean enabled = prefs.getBoolean("openai_enabled", false); String apikey = prefs.getString("openai_apikey", null); return (enabled && (!TextUtils.isEmpty(apikey) || !Objects.equals(getUri(context), BuildConfig.OPENAI_ENDPOINT))); } static void checkModeration(Context context, String text) throws JSONException, IOException { // https://platform.openai.com/docs/api-reference/moderations/create JSONObject jrequest = new JSONObject(); jrequest.put("input", text); JSONObject jresponse = call(context, "POST", "moderations", jrequest); JSONArray jresults = jresponse.getJSONArray("results"); for (int i = 0; i < jresults.length(); i++) { JSONObject jresult = jresults.getJSONObject(i); if (jresult.getBoolean("flagged")) { List violations = new ArrayList<>(); JSONObject jcategories = jresult.getJSONObject("categories"); JSONObject jcategory_scores = jresult.getJSONObject("category_scores"); Iterator keys = jcategories.keys(); while (keys.hasNext()) { String key = keys.next(); Object value = jcategories.get(key); if (Boolean.TRUE.equals(value)) { Double score = (jcategories.has(key) ? jcategory_scores.getDouble(key) : null); violations.add(key + (score == null ? "" : ":" + Math.round(score * 100) + "%")); } } throw new IllegalArgumentException(TextUtils.join(", ", violations)); } } } static double[] getEmbedding(Context context, String text, String model) throws JSONException, IOException { // https://platform.openai.com/docs/api-reference/embeddings JSONObject jrequest = new JSONObject(); jrequest.put("input", text); jrequest.put("model", model == null ? "text-embedding-ada-002" : model); JSONObject jresponse = call(context, "POST", "embeddings", jrequest); JSONObject jdata = jresponse.getJSONArray("data").getJSONObject(0); JSONArray jembedding = jdata.getJSONArray("embedding"); double[] result = new double[jembedding.length()]; for (int i = 0; i < jembedding.length(); i++) result[i] = jembedding.getDouble(i); return result; } static Message[] completeChat(Context context, String model, Message[] messages, Float temperature, int n) throws JSONException, IOException { // https://platform.openai.com/docs/guides/chat/introduction // https://platform.openai.com/docs/api-reference/chat/create JSONArray jmessages = new JSONArray(); for (Message message : messages) { JSONObject jmessage = new JSONObject(); jmessage.put("role", message.role); if (message.content.length == 1 && CONTENT_TEXT.equals(message.content[0].type)) jmessage.put("content", message.content[0].content); else { JSONArray jcontents = new JSONArray(); for (Content content : message.content) { JSONObject jcontent = new JSONObject(); jcontent.put("type", content.type); if (CONTENT_IMAGE.equals(content.type)) { JSONObject jimage = new JSONObject(); jimage.put("url", content.content); jcontent.put(content.type, jimage); } else jcontent.put(content.type, content.content); jcontents.put(jcontent); } jmessage.put("content", jcontents); } jmessages.put(jmessage); } JSONObject jquestion = new JSONObject(); jquestion.put("model", model); jquestion.put("messages", jmessages); if (temperature != null) jquestion.put("temperature", temperature); jquestion.put("n", n); JSONObject jresponse = call(context, "POST", "chat/completions", jquestion); JSONArray jchoices = jresponse.getJSONArray("choices"); Message[] choices = new Message[jchoices.length()]; for (int i = 0; i < jchoices.length(); i++) { JSONObject jchoice = jchoices.getJSONObject(i); JSONObject jmessage = jchoice.getJSONObject("message"); choices[i] = new Message(jmessage.getString("role"), new Content[]{new Content(CONTENT_TEXT, jmessage.getString("content"))}); } return choices; } private static String getUri(Context context) { SharedPreferences prefs = PreferenceManager.getDefaultSharedPreferences(context); String endpoint = prefs.getString("openai_uri", BuildConfig.OPENAI_ENDPOINT); if (!endpoint.endsWith("/")) endpoint += "/"; return endpoint; } private static JSONObject call(Context context, String method, String path, JSONObject args) throws JSONException, IOException { SharedPreferences prefs = PreferenceManager.getDefaultSharedPreferences(context); String apikey = prefs.getString("openai_apikey", null); // https://platform.openai.com/docs/api-reference/introduction Uri uri = Uri.parse(getUri(context)).buildUpon().appendEncodedPath(path).build(); Log.i("OpenAI uri=" + uri); long start = new Date().getTime(); URL url = new URL(uri.toString()); HttpURLConnection connection = (HttpURLConnection) url.openConnection(); connection.setRequestMethod(method); connection.setDoOutput(args != null); connection.setDoInput(true); connection.setReadTimeout(TIMEOUT * 1000); connection.setConnectTimeout(TIMEOUT * 1000); ConnectionHelper.setUserAgent(context, connection); connection.setRequestProperty("Accept", "application/json"); connection.setRequestProperty("Content-Type", "application/json"); connection.setRequestProperty("Authorization", "Bearer " + apikey); connection.connect(); try { if (args != null) { String json = args.toString(); Log.i("OpenAI request=" + json); connection.getOutputStream().write(json.getBytes()); } int status = connection.getResponseCode(); if (status != HttpURLConnection.HTTP_OK) { // https://platform.openai.com/docs/guides/error-codes/api-errors String error = "Error " + status + ": " + connection.getResponseMessage(); try { // HTTP 429 // { // "error": { // "message": "You exceeded your current quota, please check your plan and billing details.", // "type": "insufficient_quota", // "param": null, // "code": null // } //} InputStream is = connection.getErrorStream(); if (is != null) { String err = Helper.readStream(is); if (BuildConfig.DEBUG) error += "\n" + err; else { Log.w(new Throwable(err)); try { JSONObject jerror = new JSONObject(err).getJSONObject("error"); error += "\n" + jerror.getString("type") + ": " + jerror.getString("message"); } catch (JSONException ignored) { error += "\n" + err; } } } } catch (Throwable ex) { Log.w(ex); } throw new IOException(error); } String response = Helper.readStream(connection.getInputStream()); Log.i("OpenAI response=" + response); try { // https://platform.openai.com/docs/guides/rate-limits/rate-limits-in-headers for (Map.Entry> entries : connection.getHeaderFields().entrySet()) { String key = entries.getKey(); if (key != null && key.startsWith("x-ratelimit")) for (String value : entries.getValue()) Log.i("OpenAI", key + "=" + value); } } catch (Throwable ex) { Log.w(ex); } return new JSONObject(response); } finally { connection.disconnect(); long elapsed = new Date().getTime() - start; Log.i("OpenAI elapsed=" + (elapsed / 1000f)); } } static class Content { private String type; private String content; public Content(String type, String content) { this.type = type; this.content = content; } public String getType() { return this.type; } public String getContent() { return this.content; } static Content[] get(Spannable ssb, long id, boolean multimodal, Context context) { DB db = DB.getInstance(context); List contents = new ArrayList<>(); int start = 0; while (start < ssb.length()) { int end = ssb.nextSpanTransition(start, ssb.length(), ImageSpanEx.class); String text = ssb.subSequence(start, end).toString().trim() .replace("\u00a0", "") .replace("\ufffc", ""); Log.i("OpenAI content " + start + "..." + end + " text=[" + Helper.getPrintableString(text, true) + "]"); if (!TextUtils.isEmpty(text)) contents.add(new OpenAI.Content(OpenAI.CONTENT_TEXT, text)); if (end < ssb.length()) { ImageSpanEx[] spans = ssb.getSpans(end, end, ImageSpanEx.class); Log.i("OpenAI images=" + (spans == null ? null : spans.length)); if (spans != null && spans.length == 1) { int e = ssb.getSpanEnd(spans[0]); if (multimodal) { String url = null; String src = spans[0].getSource(); Log.i("OpenAI image url=" + src); if (src != null && src.startsWith("cid:")) { String cid = '<' + src.substring(4) + '>'; EntityAttachment attachment = db.attachment().getAttachment(id, cid); if (attachment != null && attachment.available) { File file = attachment.getFile(context); try (InputStream is = new FileInputStream(file)) { Bitmap bm = ImageHelper.getScaledBitmap(is, null, null, SCALE2PIXELS); Helper.ByteArrayInOutStream bos = new Helper.ByteArrayInOutStream(); bm.compress(Bitmap.CompressFormat.PNG, 90, bos); url = ImageHelper.getDataUri(bos.getInputStream(), "image/png"); } catch (Throwable ex) { Log.w(ex); } } } else url = src; if (url != null) contents.add(new OpenAI.Content(OpenAI.CONTENT_IMAGE, url)); } end = e; } } start = (end > start ? end : start + 1); } return contents.toArray(new OpenAI.Content[0]); } } static class Message { private final String role; // system, user, assistant private final Content[] content; public Message(String role, Content[] content) { this.role = role; this.content = content; } public String getRole() { return this.role; } public Content[] getContent() { return this.content; } @NonNull @Override public String toString() { StringBuilder sb = new StringBuilder(); if (this.content != null) for (Content c : this.content) { if (sb.length() > 0) sb.append(", "); sb.append(c.type).append(':').append(c.content); } return this.role + ": " + sb; } } static class Embedding { public static double getSimilarity(double[] v1, double[] v2) { if (v1.length != v2.length) throw new IllegalArgumentException("Invalid vector length=" + v1.length + "/" + v2.length); double dotProduct = dotProduct(v1, v2); double magV1 = magnitude(v1); double magV2 = magnitude(v2); return dotProduct / (magV1 * magV2); } private static double dotProduct(double[] v1, double[] v2) { float val = 0; for (int i = 0; i <= v1.length - 1; i++) val += v1[i] * v2[i]; return val; } private static double magnitude(double[] v) { return Math.sqrt(dotProduct(v, v)); } } }