|
|
|
""" A collection of utility functions for data manipulation and formatting. """ |
|
import base64
import io
import logging
import re
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
from PIL import Image

from config import SEARCH_START, DIVIDER, REPLACE_END, GRADIO_SUPPORTED_LANGUAGES
|
|
|
# One chat turn is a ``(user, assistant)`` pair; either side may be ``None``.
# The user side is either plain text or a list of dict parts, because
# ``history_to_chatbot_messages`` explicitly handles list-valued user entries
# whose items carry "type"/"text" keys.
History = List[Tuple[Optional[Union[str, List[Dict[str, Any]]]], Optional[str]]]

# Flat list of message dicts; the converters in this module emit
# "role" and "content" keys.
Messages = List[Dict[str, Any]]
|
|
|
def history_to_messages(history: History, system_prompt: str) -> Messages:
    """Flatten chat history into role/content messages, system prompt first.

    Args:
        history: Sequence of ``(user_msg, assistant_msg)`` pairs.
        system_prompt: Content of the leading ``system`` message.

    Returns:
        A list that opens with the system message, followed by each turn's
        ``user`` and ``assistant`` entries in order. Falsy entries (``None``
        or an empty string, for example) are skipped.
    """
    conversation: Messages = [{'role': 'system', 'content': system_prompt}]
    for user_msg, assistant_msg in history:
        turn = (('user', user_msg), ('assistant', assistant_msg))
        conversation.extend(
            {'role': role, 'content': content}
            for role, content in turn
            if content
        )
    return conversation
|
|
|
def history_to_chatbot_messages(history: History) -> Messages:
    """Convert chat history into role/content messages for display.

    A list-valued user entry is reduced to the ``"text"`` value of its first
    dict item whose ``"type"`` is ``"text"``; any other user entry is used
    as-is. Turns whose resulting user text or assistant reply is falsy
    contribute no message for that side.

    Args:
        history: Sequence of ``(user_msg, assistant_msg)`` pairs.

    Returns:
        Alternating ``user`` / ``assistant`` message dicts in history order.
    """
    chat_messages: Messages = []
    for user_msg, assistant_msg in history:
        if isinstance(user_msg, list):
            # Multi-part entry: only the first text part is shown.
            shown_text = ""
            for part in user_msg:
                if isinstance(part, dict) and part.get("type") == "text":
                    shown_text = part.get("text", "")
                    break
        else:
            shown_text = user_msg or ""

        if shown_text:
            chat_messages.append({"role": "user", "content": shown_text})
        if assistant_msg:
            chat_messages.append({"role": "assistant", "content": assistant_msg})
    return chat_messages
|
|
|
def process_image_for_model(image_data: np.ndarray) -> str:
    """Encode an image array as a base64 PNG ``data:`` URL.

    Args:
        image_data: Array handed to ``PIL.Image.fromarray``.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.
    """
    with io.BytesIO() as png_stream:
        Image.fromarray(image_data).save(png_stream, format="PNG")
        payload = base64.b64encode(png_stream.getvalue())
    return "data:image/png;base64," + payload.decode('utf-8')
|
|
|
# Opening fence, then one of two headers:
#   1. a single-token info string that ends its line (``python3``, ``c++``,
#      ``objective-c``, or nothing at all), or
#   2. the legacy letters-only tag, for code that shares the fence line.
# The first alternative keeps non-letter tag characters out of the payload.
_CODE_FENCE_RE = re.compile(
    r'```(?:[^\s`]*[^\S\n]*\n|[a-zA-Z]*\s*\n?)(.*?)\n?```',
    re.DOTALL,
)


def remove_code_block(text: str) -> str:
    """Return the contents of the first fenced code block in ``text``.

    The language tag after the opening fence is dropped. A tag may contain
    any non-whitespace characters (digits, ``+``, ``-``, ...) as long as it
    is the only token on the fence line; a letters-only tag is also dropped
    when the code continues on the same line.

    Args:
        text: Text that may contain a triple-backtick fenced block.

    Returns:
        The stripped body of the first fenced block, or ``text.strip()``
        when no complete fenced block is present.
    """
    match = _CODE_FENCE_RE.search(text)
    return match.group(1).strip() if match else text.strip()
|
|
|
def apply_search_replace_changes(
    original_code: str,
    changes_text: str,
    *,
    search_start: Optional[str] = None,
    divider: Optional[str] = None,
    replace_end: Optional[str] = None,
) -> str:
    """Apply search/replace blocks from ``changes_text`` to ``original_code``.

    Each block has the form::

        <search_start>
        ...text to find...
        <divider>
        ...replacement text...
        <replace_end>

    with every marker at the start of its own line. Blocks are applied in
    order, each replacing only the first occurrence of its search text in
    the code as modified so far. Either section may be empty: an empty
    replacement deletes the search text, and an empty search text matches
    at the very start of the code, so its replacement is inserted there.

    Args:
        original_code: Code to modify.
        changes_text: Text containing zero or more search/replace blocks.
        search_start: Marker opening a block; defaults to ``SEARCH_START``.
        divider: Marker between the sections; defaults to ``DIVIDER``.
        replace_end: Marker closing a block; defaults to ``REPLACE_END``.

    Returns:
        The modified code. Blocks whose search text is not found are
        skipped after logging a warning.
    """
    # Resolve defaults at call time; the module constants are only read when
    # the caller did not override the marker.
    search_start = SEARCH_START if search_start is None else search_start
    divider = DIVIDER if divider is None else divider
    replace_end = REPLACE_END if replace_end is None else replace_end

    # Markers are literal text, so escape them before building the pattern.
    # The sections are captured together with their final line break, which
    # lets an empty section (marker directly followed by the next marker)
    # match instead of running on into a later block.
    block_pattern = re.compile(
        "^" + re.escape(search_start) + r"\n(.*?)"
        "^" + re.escape(divider) + r"\n(.*?)"
        "^" + re.escape(replace_end),
        re.DOTALL | re.MULTILINE,
    )

    modified_code = original_code
    for match in block_pattern.finditer(changes_text):
        # Drop the single line break that separates a section from the
        # marker line following it.
        search_content, replace_content = (
            section[:-1] if section.endswith("\n") else section
            for section in match.groups()
        )
        if search_content in modified_code:
            modified_code = modified_code.replace(search_content, replace_content, 1)
        else:
            logging.warning("Search block not found: %s", search_content[:100])
    return modified_code
|
|
|
def get_gradio_language(language: str) -> Optional[str]:
    """Return ``language`` if ``GRADIO_SUPPORTED_LANGUAGES`` contains it.

    Args:
        language: Language identifier to check.

    Returns:
        ``language`` unchanged when it is a member of
        ``GRADIO_SUPPORTED_LANGUAGES``, otherwise ``None``.
    """
    if language in GRADIO_SUPPORTED_LANGUAGES:
        return language
    return None