From 1cbebf78700379976ed81fc0807b1b6f56a577bb Mon Sep 17 00:00:00 2001 From: M66B Date: Tue, 14 May 2024 09:21:38 +0200 Subject: [PATCH] OpenAI: prepare multi-modal support --- .../eu/faircode/email/FragmentCompose.java | 27 +++++++--- .../email/FragmentDialogSummarize.java | 18 ++++--- .../main/java/eu/faircode/email/OpenAI.java | 53 ++++++++++++++++--- 3 files changed, 77 insertions(+), 21 deletions(-) diff --git a/app/src/main/java/eu/faircode/email/FragmentCompose.java b/app/src/main/java/eu/faircode/email/FragmentCompose.java index 1411d4de46..8a9b2709fe 100644 --- a/app/src/main/java/eu/faircode/email/FragmentCompose.java +++ b/app/src/main/java/eu/faircode/email/FragmentCompose.java @@ -2690,11 +2690,13 @@ public class FragmentCompose extends FragmentBase { Document parsed = JsoupEx.parse(inreplyto.get(0).getFile(context)); Document document = HtmlHelper.sanitizeView(context, parsed, false); Spanned spanned = HtmlHelper.fromDocument(context, document, null, null); - result.add(new OpenAI.Message(role, OpenAI.truncateParagraphs(spanned.toString()))); + result.add(new OpenAI.Message(role, new OpenAI.Content[]{ + new OpenAI.Content(OpenAI.CONTENT_TEXT, OpenAI.truncateParagraphs(spanned.toString()))})); } if (!TextUtils.isEmpty(body)) - result.add(new OpenAI.Message(OpenAI.ASSISTANT, OpenAI.truncateParagraphs(body))); + result.add(new OpenAI.Message(OpenAI.ASSISTANT, new OpenAI.Content[]{ + new OpenAI.Content(OpenAI.CONTENT_TEXT, OpenAI.truncateParagraphs(body))})); if (result.size() == 0) return null; @@ -2706,7 +2708,9 @@ public class FragmentCompose extends FragmentBase { if (moderation) for (OpenAI.Message message : result) - OpenAI.checkModeration(context, message.getContent()); + for (OpenAI.Content content : message.getContent()) + if (OpenAI.CONTENT_TEXT.equals(content.getContent())) + OpenAI.checkModeration(context, content.getContent()); OpenAI.Message[] completions = OpenAI.completeChat(context, model, result.toArray(new OpenAI.Message[0]), temperature, 1); @@ -2719,8 +2723,15 @@ public class FragmentCompose extends FragmentBase { if (messages == null || messages.length == 0) return; - String text = messages[0].getContent() - .replaceAll("^\\n+", "").replaceAll("\\n+$", ""); + StringBuilder sb = new StringBuilder(); + for (OpenAI.Message message : messages) + for (OpenAI.Content content : message.getContent()) + if (OpenAI.CONTENT_TEXT.equals(content.getType())) { + if (sb.length() > 0) + sb.append('\n'); + sb.append(content.getContent().replaceAll("^\\n+", "").replaceAll("\\n+$", "")); + } + Editable edit = etBody.getText(); int start = etBody.getSelectionStart(); @@ -2738,10 +2749,10 @@ public class FragmentCompose extends FragmentBase { if (index > 0 && edit.charAt(index - 1) != '\n') edit.insert(index++, "\n"); - edit.insert(index, text + "\n"); - etBody.setSelection(index + text.length() + 1); + edit.insert(index, sb + "\n"); + etBody.setSelection(index + sb.length() + 1); - StyleHelper.markAsInserted(edit, index, index + text.length() + 1); + StyleHelper.markAsInserted(edit, index, index + sb.length() + 1); if (args.containsKey("used") && args.containsKey("granted")) { double used = args.getDouble("used"); diff --git a/app/src/main/java/eu/faircode/email/FragmentDialogSummarize.java b/app/src/main/java/eu/faircode/email/FragmentDialogSummarize.java index af8dd48dcc..6d037dfbff 100644 --- a/app/src/main/java/eu/faircode/email/FragmentDialogSummarize.java +++ b/app/src/main/java/eu/faircode/email/FragmentDialogSummarize.java @@ -108,16 +108,20 @@ public class FragmentDialogSummarize extends FragmentDialogBase { String prompt = prefs.getString("openai_summarize", OpenAI.SUMMARY_PROMPT); List result = new ArrayList<>(); - result.add(new OpenAI.Message(OpenAI.ASSISTANT, prompt)); - result.add(new OpenAI.Message(OpenAI.USER, text)); + result.add(new OpenAI.Message(OpenAI.ASSISTANT, + new OpenAI.Content[]{new OpenAI.Content(OpenAI.CONTENT_TEXT, prompt)})); + result.add(new OpenAI.Message(OpenAI.USER, + new OpenAI.Content[]{new OpenAI.Content(OpenAI.CONTENT_TEXT, text)})); OpenAI.Message[] completions = OpenAI.completeChat(context, model, result.toArray(new OpenAI.Message[0]), temperature, 1); StringBuilder sb = new StringBuilder(); - for (OpenAI.Message completion : completions) { - if (sb.length() != 0) - sb.append('\n'); - sb.append(completion.getContent()); - } + for (OpenAI.Message completion : completions) + for (OpenAI.Content content : completion.getContent()) + if (OpenAI.CONTENT_TEXT.equals(content.getType())) { + if (sb.length() != 0) + sb.append('\n'); + sb.append(content.getContent()); + } return sb.toString(); } else if (Gemini.isAvailable(context)) { String model = prefs.getString("gemini_model", "gemini-pro"); diff --git a/app/src/main/java/eu/faircode/email/OpenAI.java b/app/src/main/java/eu/faircode/email/OpenAI.java index d079b0dcbc..724a564dd8 100644 --- a/app/src/main/java/eu/faircode/email/OpenAI.java +++ b/app/src/main/java/eu/faircode/email/OpenAI.java @@ -46,6 +46,10 @@ public class OpenAI { static final String USER = "user"; static final String SUMMARY_PROMPT = "Summarize the following text:"; + // https://cookbook.openai.com/examples/gpt4o/introduction_to_gpt4o + static final String CONTENT_TEXT = "text"; + static final String CONTENT_IMAGE = "image_url"; + private static final int MAX_OPENAI_LEN = 1000; // characters private static final int TIMEOUT = 45; // seconds @@ -108,7 +112,18 @@ public class OpenAI { for (Message message : messages) { JSONObject jmessage = new JSONObject(); jmessage.put("role", message.role); - jmessage.put("content", message.content); + + JSONArray jcontents = new JSONArray(); + + for (Content content : message.content) { + JSONObject jcontent = new JSONObject(); + jcontent.put("type", content.type); + jcontent.put(content.type, content.content); + jcontents.put(jcontent); + } + + jmessage.put("content", jcontents); + jmessages.put(jmessage); } @@ -125,7 +140,8 @@ public class OpenAI { for (int i = 0; i < jchoices.length(); i++) { JSONObject jchoice = jchoices.getJSONObject(i); JSONObject jmessage = jchoice.getJSONObject("message"); - choices[i] = new Message(jmessage.getString("role"), jmessage.getString("content")); + choices[i] = new Message(jmessage.getString("role"), + new Content[]{new Content(CONTENT_TEXT, jmessage.getString("content"))}); } return choices; @@ -232,11 +248,29 @@ public class OpenAI { return sb.toString(); } + static class Content { + private String type; + private String content; + + public Content(String type, String content) { + this.type = type; + this.content = content; + } + + public String getType() { + return this.type; + } + + public String getContent() { + return this.content; + } + } + static class Message { private final String role; // system, user, assistant - private final String content; + private final Content[] content; - public Message(String role, String content) { + public Message(String role, Content[] content) { this.role = role; this.content = content; } @@ -245,14 +279,21 @@ public class OpenAI { return this.role; } - public String getContent() { + public Content[] getContent() { return this.content; } @NonNull @Override public String toString() { - return this.role + ": " + this.content; + StringBuilder sb = new StringBuilder(); + if (this.content != null) + for (Content c : this.content) { + if (sb.length() > 0) + sb.append(", "); + sb.append(c.type).append(':').append(c.content); + } + return this.role + ": " + sb; } }