diff --git a/app/src/main/java/eu/faircode/email/AI.java b/app/src/main/java/eu/faircode/email/AI.java index a340ce79a9..46e709e33d 100644 --- a/app/src/main/java/eu/faircode/email/AI.java +++ b/app/src/main/java/eu/faircode/email/AI.java @@ -61,20 +61,27 @@ public class AI { messages.add(new OpenAI.Message(OpenAI.SYSTEM, new OpenAI.Content[]{ new OpenAI.Content(OpenAI.CONTENT_TEXT, systemPrompt)})); - messages.add(new OpenAI.Message(OpenAI.USER, new OpenAI.Content[]{ - new OpenAI.Content(OpenAI.CONTENT_TEXT, prompt == null ? defaultPrompt : prompt)})); + if (body instanceof Spannable && multimodal) { + messages.add(new OpenAI.Message(OpenAI.USER, new OpenAI.Content[]{ + new OpenAI.Content(OpenAI.CONTENT_TEXT, prompt == null ? defaultPrompt : prompt)})); - if (!TextUtils.isEmpty(body)) - if (body instanceof Spannable && multimodal) + if (!TextUtils.isEmpty(body)) messages.add(new OpenAI.Message(OpenAI.USER, OpenAI.Content.get((Spannable) body, id, context))); - else - messages.add(new OpenAI.Message(OpenAI.USER, new OpenAI.Content[]{ - new OpenAI.Content(OpenAI.CONTENT_TEXT, body.toString())})); - if (!TextUtils.isEmpty(reply)) + if (!TextUtils.isEmpty(reply)) + messages.add(new OpenAI.Message(OpenAI.USER, new OpenAI.Content[]{ + new OpenAI.Content(OpenAI.CONTENT_TEXT, reply)})); + } else { + List contents = new ArrayList<>(); + contents.add(prompt == null ? defaultPrompt : prompt); + if (!TextUtils.isEmpty(body)) + contents.add(body.toString()); + if (!TextUtils.isEmpty(reply)) + contents.add(reply); messages.add(new OpenAI.Message(OpenAI.USER, new OpenAI.Content[]{ - new OpenAI.Content(OpenAI.CONTENT_TEXT, reply)})); + new OpenAI.Content(OpenAI.CONTENT_TEXT, TextUtils.join("\n", contents))})); + } OpenAI.Message[] completions = OpenAI.completeChat(context, model, messages.toArray(new OpenAI.Message[0]), temperature, 1); @@ -182,17 +189,22 @@ public class AI { boolean multimodal = prefs.getBoolean("openai_multimodal", false); List input = new ArrayList<>(); - input.add(new OpenAI.Message(OpenAI.USER, - new OpenAI.Content[]{new OpenAI.Content(OpenAI.CONTENT_TEXT, - templatePrompt == null ? defaultPrompt : templatePrompt)})); if (multimodal) { + input.add(new OpenAI.Message(OpenAI.USER, + new OpenAI.Content[]{new OpenAI.Content(OpenAI.CONTENT_TEXT, + templatePrompt == null ? defaultPrompt : templatePrompt)})); + SpannableStringBuilder ssb = HtmlHelper.fromDocument(context, d, null, null); input.add(new OpenAI.Message(OpenAI.USER, OpenAI.Content.get(ssb, message.id, context))); - } else + } else { + List contents = new ArrayList<>(); + contents.add(templatePrompt == null ? defaultPrompt : templatePrompt); + contents.add(body); input.add(new OpenAI.Message(OpenAI.USER, new OpenAI.Content[]{ - new OpenAI.Content(OpenAI.CONTENT_TEXT, body)})); + new OpenAI.Content(OpenAI.CONTENT_TEXT, TextUtils.join("\n", contents))})); + } OpenAI.Message[] completions = OpenAI.completeChat(context, model, input.toArray(new OpenAI.Message[0]), temperature, 1);