""" | |
augmentor.py | |
This module implements a robust and scalable pipeline for finetuning data augmentation. | |
It supports generating augmented data in either OpenAI, Gemini, Mistral, or LLama fineβtuning JSONL format. | |
Users may optionally override metric thresholds and load existing examples from a JSONL file. | |
The LangChain Groq API key is now provided via the configuration rather than the .env file. | |
""" | |
import os
import json
import uuid
import logging
import re
import random
import ast
from typing import List, Dict, Any, Optional

# Removed dotenv load for GROQ_API_KEY since it is now provided in config

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("FinetuningAugmentor")

# Environment tokens (kept for HF_TOKEN if needed)
HF_TOKEN = os.getenv("HF_TOKEN")
# GROQ_API_KEY will now be provided in the configuration

# -----------------------------
# Data Models and Preprocessing
# -----------------------------
from pydantic import BaseModel, field_validator, ValidationError

class AugmentationExample(BaseModel):
    """
    An input/output example for augmentation.
    """
    input_text: str
    output_text: str

    @field_validator("input_text", "output_text")
    @classmethod
    def non_empty(cls, v: str) -> str:
        if not v.strip():
            raise ValueError("Text fields must be non-empty")
        return v.strip()

class AugmentationConfig(BaseModel):
    """
    Configuration for the augmentation process.
    """
    target_model: str  # e.g., "mixtral-8x7b-32768" or any Groq-supported model name
    examples: List[AugmentationExample]
    finetuning_goal: str
    groq_api_key: str
    system_message: Optional[str] = "Marv is a factual chatbot that is also sarcastic."
    # Optional metric thresholds (if not provided, defaults are used)
    min_semantic_similarity: Optional[float] = 0.80
    max_semantic_similarity: Optional[float] = 0.95
    min_diversity_score: Optional[float] = 0.70
    min_fluency_score: Optional[float] = 0.80

    @field_validator("examples")
    @classmethod
    def check_examples_length(cls, v: List[AugmentationExample]) -> List[AugmentationExample]:
        if len(v) < 3:
            raise ValueError("Provide at least 3 examples")
        return v

class StandardExample(BaseModel):
    """
    Standardized format for input examples.
    """
    id: str
    input_text: str
    output_text: str
    metadata: Dict[str, Any] = {}

def normalize_examples(examples: List[AugmentationExample]) -> List[StandardExample]:
    """
    Normalize and standardize input examples.
    """
    normalized = []
    for ex in examples:
        norm_ex = StandardExample(
            id=str(uuid.uuid4()),
            input_text=ex.input_text.lower(),
            output_text=ex.output_text.lower(),
            metadata={"original_word_count": len(ex.input_text.split())}
        )
        normalized.append(norm_ex)
    logger.info(f"Normalized {len(normalized)} examples.")
    return normalized

# -----------------------------
# Dynamic Strategy Selection
# -----------------------------
def determine_augmentation_strategy(config: AugmentationConfig) -> Dict[str, Any]:
    """
    Determine the augmentation strategy based on the finetuning goal.
    """
    goal = config.finetuning_goal.lower()
    strategy = {}
    if any(word in goal for word in ["dialogue", "q&a", "conversation", "chat"]):
        strategy["methods"] = ["llm_paraphrasing", "back_translation"]
    else:
        strategy["methods"] = ["eda_synonym_replacement", "llm_paraphrasing", "synthetic_noise"]
    strategy["diversity_threshold"] = 0.7
    logger.info(f"Determined augmentation strategy: {strategy}")
    return strategy
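
# Illustrative outcome (not executed): a goal such as "sarcastic Q&A chatbot"
# matches "q&a", so the dialogue-oriented methods are selected:
#   strategy["methods"] == ["llm_paraphrasing", "back_translation"]
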
# -----------------------------
# Helper Functions
# -----------------------------
def extract_json(text: str) -> dict:
    """
    Extract the first valid JSON object from a given text.
    """
    json_pattern = re.compile(r'\{.*\}', re.DOTALL)
    match = json_pattern.search(text)
    if match:
        json_str = match.group()
        try:
            return json.loads(json_str)
        except json.JSONDecodeError as e:
            raise ValueError(f"JSON decoding error: {e}")
    else:
        raise ValueError("No valid JSON found in text.")
def make_hashable(item: Any) -> Any:
    """
    Recursively convert unhashable types (lists/dicts) into hashable tuples.
    """
    if isinstance(item, (list, tuple)):
        return tuple(make_hashable(i) for i in item)
    elif isinstance(item, dict):
        return tuple(sorted((k, make_hashable(v)) for k, v in item.items()))
    else:
        return item
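
# Illustrative behavior (not executed): nested containers become nested tuples,
# with dict items sorted by key so equal dicts hash identically.
#   make_hashable({"a": [1, 2], "b": {"c": 3}})
#   -> (("a", (1, 2)), ("b", (("c", 3),)))
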
def validate_jsonl_record(record: dict) -> bool:
    """
    Validates that the record follows the required OpenAI format:
    {"messages": [{"role": "system", "content": <str>},
                  {"role": "user", "content": <non-empty str>},
                  {"role": "assistant", "content": <non-empty str>}]}
    """
    if "messages" not in record:
        logger.error("Record missing 'messages' key.")
        return False
    messages = record["messages"]
    if not isinstance(messages, list) or len(messages) != 3:
        logger.error("Record 'messages' must be a list of 3 items.")
        return False
    expected_roles = ["system", "user", "assistant"]
    for msg, role in zip(messages, expected_roles):
        if not isinstance(msg, dict):
            logger.error("Each message must be a dictionary.")
            return False
        if msg.get("role") != role:
            logger.error(f"Expected role '{role}', but got '{msg.get('role')}'.")
            return False
        if "content" not in msg or not isinstance(msg["content"], str):
            logger.error("Each message must have a string 'content' field.")
            return False
        if role in ["user", "assistant"] and not msg["content"].strip():
            logger.error(f"Message for role '{role}' has empty content.")
            return False
    return True
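
# A record that passes validation (illustrative; the system content may be empty,
# but user and assistant content must be non-empty):
#   {"messages": [{"role": "system", "content": ""},
#                 {"role": "user", "content": "hi"},
#                 {"role": "assistant", "content": "hello"}]}
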
def get_text(value: Any) -> str:
    """
    Ensure the value is returned as a string.
    If it is a list, recursively return the first element.
    If it is a dict and contains a "text" key, return that.
    If it is a string that resembles a dict, try to parse it.
    """
    if isinstance(value, list):
        if value:
            return get_text(value[0])
        return ""
    elif isinstance(value, dict):
        if "text" in value:
            return str(value["text"])
        return str(value)
    elif isinstance(value, str):
        val = value.strip()
        if val.startswith("{") and val.endswith("}"):
            try:
                parsed = ast.literal_eval(val)
                if isinstance(parsed, dict) and "text" in parsed:
                    return str(parsed["text"])
            except Exception:
                pass
        return val
    else:
        return str(value)

# --- New helper: Fix content formatting ---
def fix_content(content: str) -> str:
    """
    If the content appears to be a Python dict (using single quotes), try to
    convert it to valid JSON (with double quotes). If parsing fails, return the original content.
    """
    content = content.strip()
    if content.startswith("{") and content.endswith("}") and "'" in content:
        try:
            parsed = ast.literal_eval(content)
            return json.dumps(parsed)
        except Exception as e:
            logger.debug(f"Failed to fix content formatting: {e}")
    return content

def flatten_content(content: str) -> str:
    """
    If content (after fixing) is a JSON string representing a dictionary,
    flatten it by joining its values into a single plain-text string.
    """
    try:
        parsed = json.loads(content)
        if isinstance(parsed, dict):
            # Join values in sorted order by key
            values = [str(parsed[k]).strip() for k in sorted(parsed)]
            return " ".join(values)
    except Exception:
        pass
    return content
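
# Illustrative pipeline (not executed): a stringified Python dict is converted
# to JSON, then flattened to plain text with values joined in key order.
#   fix_content("{'greeting': 'hi', 'name': 'Marv'}") -> '{"greeting": "hi", "name": "Marv"}'
#   flatten_content('{"greeting": "hi", "name": "Marv"}') -> 'hi Marv'
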
# -----------------------------
# Augmentation Generation via LangChain Groq
# -----------------------------
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate

def instantiate_groq_llm(model: str, groq_api_key: str) -> ChatGroq:
    """
    Instantiate a ChatGroq LLM with the given model name and API key.
    """
    return ChatGroq(
        model=model,
        temperature=0.7,
        max_tokens=256,
        timeout=30,
        max_retries=2,
        groq_api_key=groq_api_key
    )

def generate_initial_augmentation(example: StandardExample,
                                  config: AugmentationConfig,
                                  strategy: Dict[str, Any]) -> dict:
    """
    Generate an initial candidate augmentation using an LLM prompt chain.
    """
    prompt_template = ChatPromptTemplate.from_messages([
        (
            "system",
            ("You are a creative augmentation assistant that produces diverse yet semantically consistent "
             "input/output pairs for finetuning tasks.")
        ),
        (
            "human",
            (
                "Augment the following example using the methods: {methods}. The finetuning goal is: {finetuning_goal}.\n"
                "Ensure your output is in valid JSON format with keys 'augmented_input' and 'augmented_output'.\n"
                "Input: {input_text}\n"
                "Output: {output_text}\n"
                "Return only the JSON response."
            )
        )
    ])
    prompt_vars = {
        "methods": ", ".join(strategy["methods"]),
        "finetuning_goal": config.finetuning_goal,
        "input_text": example.input_text,
        "output_text": example.output_text
    }
    chain = prompt_template | instantiate_groq_llm(config.target_model, config.groq_api_key)
    ai_msg = chain.invoke(prompt_vars)
    logger.info(f"Initial augmentation for {example.id}: {ai_msg.content.strip()}")
    return extract_json(ai_msg.content.strip())

def refine_augmentation(candidate: dict,
                        example: StandardExample,
                        config: AugmentationConfig,
                        strategy: Dict[str, Any]) -> dict:
    """
    Refine a candidate augmentation using a second LLM prompt chain.
    """
    refinement_template = ChatPromptTemplate.from_messages([
        (
            "system",
            "You are an expert data augmentation advisor who refines candidate augmentations to maximize semantic accuracy, diversity, and clarity."
        ),
        (
            "human",
            (
                "Review the candidate augmentation for the following input/output pair and refine it if needed.\n"
                "Finetuning Goal: {finetuning_goal}\n"
                "Original Input: {input_text}\n"
                "Original Output: {output_text}\n"
                "Candidate Augmentation: {candidate}\n"
                "Return a refined augmentation in valid JSON format with keys 'augmented_input' and 'augmented_output' only."
            )
        )
    ])
    refinement_vars = {
        "finetuning_goal": config.finetuning_goal,
        "input_text": example.input_text,
        "output_text": example.output_text,
        "candidate": json.dumps(candidate)
    }
    chain = refinement_template | instantiate_groq_llm(config.target_model, config.groq_api_key)
    ai_msg = chain.invoke(refinement_vars)
    try:
        refined = extract_json(ai_msg.content.strip())
        logger.info(f"Refined augmentation for {example.id}: {refined}")
        return refined
    except Exception as e:
        logger.error(f"Refinement failed for {example.id}: {e}. Using original candidate.")
        return candidate

def calculate_metrics(augmentation: dict, original: StandardExample) -> dict:
    """
    Simulate metric calculations for the candidate augmentation.
    """
    semantic_similarity = random.uniform(0.78, 0.97)
    diversity_score = random.uniform(0.65, 0.9)
    fluency_score = random.uniform(0.80, 0.95)
    metrics = {
        "semantic_similarity": semantic_similarity,
        "diversity_score": diversity_score,
        "fluency_score": fluency_score
    }
    logger.info(f"Metrics for candidate of {original.id}: {metrics}")
    return metrics

def metrics_valid(metrics: dict, config: AugmentationConfig) -> bool:
    """
    Validate metric thresholds using configuration values.
    """
    if metrics["semantic_similarity"] < config.min_semantic_similarity or metrics["semantic_similarity"] > config.max_semantic_similarity:
        return False
    if metrics["diversity_score"] < config.min_diversity_score:
        return False
    if metrics["fluency_score"] < config.min_fluency_score:
        return False
    return True

def quality_check(augmentation: Dict[str, Any], config: AugmentationConfig) -> bool:
    """
    Simulate an LLM-based QA check.
    """
    qa_prompt = (
        f"Verify that the following augmentation preserves the intended meaning and style for the input/output pair "
        f"given the finetuning goal '{config.finetuning_goal}':\n"
        f"{augmentation['augmentation']}\n"
        "Answer 'yes' if valid, otherwise 'no'."
    )
    logger.debug(f"QA Prompt: {qa_prompt}")
    return True  # Simulation: always passes

def generate_augmentations(normalized_examples: List[StandardExample],
                           config: AugmentationConfig,
                           strategy: Dict[str, Any],
                           target_count: int = 50) -> List[Dict[str, Any]]:
    """
    Repeatedly generate candidate augmentations until at least target_count valid candidates are collected.
    """
    augmented_candidates = []
    attempts = 0
    max_attempts = 100  # Safety valve
    while len(augmented_candidates) < target_count and attempts < max_attempts:
        for ex in normalized_examples:
            try:
                candidate = generate_initial_augmentation(ex, config, strategy)
                refined_candidate = refine_augmentation(candidate, ex, config, strategy)
                metrics = calculate_metrics(refined_candidate, ex)
                if not metrics_valid(metrics, config):
                    logger.info(f"Candidate for {ex.id} rejected by metrics: {metrics}")
                    continue
                if quality_check({"augmentation": refined_candidate}, config):
                    full_candidate = {
                        "original_id": ex.id,
                        "augmentation": refined_candidate,
                        "strategy": strategy,
                        "metrics": metrics
                    }
                    augmented_candidates.append(full_candidate)
                    logger.info(f"Accepted candidate for {ex.id} (Total accepted: {len(augmented_candidates)})")
                if len(augmented_candidates) >= target_count:
                    break
            except Exception as e:
                logger.error(f"Error generating augmentation for {ex.id}: {e}")
        attempts += 1
    if len(augmented_candidates) < target_count:
        logger.warning(f"Only {len(augmented_candidates)} candidates generated after {attempts} attempts.")
    return augmented_candidates

def deduplicate_augmentations(augmentations: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Remove duplicate augmentations based on hashable keys.
    """
    seen = set()
    unique_aug = []
    for aug in augmentations:
        key = (make_hashable(aug["augmentation"].get("augmented_input")),
               make_hashable(aug["augmentation"].get("augmented_output")))
        if key not in seen:
            seen.add(key)
            unique_aug.append(aug)
    logger.info(f"Deduplicated to {len(unique_aug)} unique augmentations.")
    return unique_aug

def format_for_openai(augmentations: List[Dict[str, Any]], system_message: str) -> str:
    """
    Format augmentations in OpenAI fine-tuning JSONL format.
    """
    output_lines = []
    sys_msg = system_message.strip() if system_message and system_message.strip() else ""
    for aug in augmentations:
        user_val = flatten_content(fix_content(get_text(aug["augmentation"].get("augmented_input", "")).strip()))
        assistant_val = flatten_content(fix_content(get_text(aug["augmentation"].get("augmented_output", "")).strip()))
        record = {
            "messages": [
                {"role": "system", "content": sys_msg},
                {"role": "user", "content": user_val},
                {"role": "assistant", "content": assistant_val}
            ]
        }
        if validate_jsonl_record(record):
            output_lines.append(json.dumps(record))
        else:
            logger.error(f"Record validation failed: {record}")
    logger.info(f"Formatted {len(output_lines)} records in OpenAI fine-tuning format.")
    return "\n".join(output_lines)
def format_for_gemini(augmentations: List[Dict[str, Any]]) -> str:
    """
    Format augmentations in Gemini fine-tuning JSONL format.
    """
    output_lines = []
    for aug in augmentations:
        user_val = flatten_content(fix_content(get_text(aug["augmentation"].get("augmented_input", "")).strip()))
        assistant_val = flatten_content(fix_content(get_text(aug["augmentation"].get("augmented_output", "")).strip()))
        record = {
            "contents": [
                {"role": "user", "parts": [{"text": user_val}]},
                {"role": "model", "parts": [{"text": assistant_val}]}
            ]
        }
        if user_val and assistant_val:
            output_lines.append(json.dumps(record))
        else:
            logger.error(f"Gemini record validation failed: {record}")
    logger.info(f"Formatted {len(output_lines)} records in Gemini fine-tuning format.")
    return "\n".join(output_lines)

def format_for_common(augmentations: List[Dict[str, Any]]) -> str:
    """
    Format augmentations in a common JSONL format for both Mistral and Llama.
    """
    output_lines = []
    for aug in augmentations:
        user_val = flatten_content(fix_content(get_text(aug["augmentation"].get("augmented_input", "")).strip()))
        assistant_val = flatten_content(fix_content(get_text(aug["augmentation"].get("augmented_output", "")).strip()))
        record = {
            "messages": [
                {"role": "user", "content": user_val},
                {"role": "assistant", "content": assistant_val}
            ]
        }
        if user_val and assistant_val:
            output_lines.append(json.dumps(record))
        else:
            logger.error(f"Common format record validation failed: {record}")
    logger.info(f"Formatted {len(output_lines)} records in common JSONL format for Mistral/Llama.")
    return "\n".join(output_lines)

def format_for_mistral(augmentations: List[Dict[str, Any]]) -> str:
    """
    Format augmentations in Mistral fine-tuning JSONL format.
    Uses the common format.
    """
    return format_for_common(augmentations)

def format_for_llama(augmentations: List[Dict[str, Any]]) -> str:
    """
    Format augmentations in Llama fine-tuning JSONL format.
    Uses the common format.
    """
    return format_for_common(augmentations)

# -----------------------------
# Optional: Load Existing Examples from JSONL
# -----------------------------
def load_examples_from_file(file_path: str, format_type: str = "openai") -> List[AugmentationExample]:
    """
    Load input/output pairs from a JSONL file.
    """
    examples = []
    try:
        with open(file_path, "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                record = json.loads(line)
                if format_type.lower() == "openai":
                    msgs = record.get("messages", [])
                    if len(msgs) == 3:
                        user_text = msgs[1].get("content", "").strip()
                        assistant_text = msgs[2].get("content", "").strip()
                        if user_text and assistant_text:
                            examples.append(AugmentationExample(input_text=user_text, output_text=assistant_text))
                elif format_type.lower() == "gemini":
                    contents = record.get("contents", [])
                    if len(contents) >= 2:
                        user_parts = contents[0].get("parts", [])
                        model_parts = contents[1].get("parts", [])
                        user_text = get_text(user_parts[0]) if user_parts else ""
                        assistant_text = get_text(model_parts[0]) if model_parts else ""
                        if user_text and assistant_text:
                            examples.append(AugmentationExample(input_text=user_text, output_text=assistant_text))
    except Exception as e:
        logger.error(f"Error loading examples from file: {e}")
    logger.info(f"Loaded {len(examples)} examples from {file_path}")
    return examples

# -----------------------------
# Pipeline Class
# -----------------------------
class FinetuningDataAugmentor:
    """
    Encapsulates the entire augmentation pipeline.
    """
    def __init__(self, config: AugmentationConfig):
        self.config = config
        self.normalized_examples = normalize_examples(config.examples)
        self.strategy = determine_augmentation_strategy(config)
        self.augmentations = []

    def run_augmentation(self, target_count: int = 50) -> List[Dict[str, Any]]:
        """
        Generate candidate augmentations, deduplicate, and store results.
        """
        logger.info("Starting augmentation generation via LangChain Groq...")
        candidates = generate_augmentations(self.normalized_examples, self.config, self.strategy, target_count=target_count)
        logger.info(f"Generated {len(candidates)} candidate augmentations before deduplication.")
        unique_candidates = deduplicate_augmentations(candidates)
        logger.info(f"{len(unique_candidates)} unique augmentations after deduplication.")
        self.augmentations = unique_candidates
        return unique_candidates

    def get_formatted_output(self, format_type: str = "openai") -> str:
        """
        Return the final augmented data in the desired finetuning format.
        """
        fmt = format_type.lower()
        if fmt == "openai":
            return format_for_openai(self.augmentations, self.config.system_message)
        elif fmt == "gemini":
            return format_for_gemini(self.augmentations)
        elif fmt == "mistral":
            return format_for_mistral(self.augmentations)
        elif fmt == "llama":
            return format_for_llama(self.augmentations)
        else:
            logger.error(f"Unknown format type: {format_type}. Defaulting to OpenAI format.")
            return format_for_openai(self.augmentations, self.config.system_message)

    def save_to_file(self, filename: str = "train.jsonl") -> None:
        """
        Save the formatted augmented data to a file.
        """
        output = self.get_formatted_output()
        with open(filename, "w") as f:
            f.write(output)
        logger.info(f"Final augmented data saved to {filename}")
    def run_review_interface(self) -> None:
        """
        Launch the interactive review interface.
        """
        # NOTE: launch_review_app is assumed to be provided by a separate
        # Streamlit review module; it is not defined in this file. The unused
        # `from streamlit import runtime` import has been removed.
        formatted_data = self.get_formatted_output()
        launch_review_app(formatted_data)
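
# -----------------------------
# Example Usage (illustrative)
# -----------------------------
# A minimal sketch of running the pipeline end to end. The model name is the one
# cited in the AugmentationConfig comment above; the API key is a placeholder and
# must be replaced (or exported as GROQ_API_KEY) before running.
if __name__ == "__main__":
    demo_config = AugmentationConfig(
        target_model="mixtral-8x7b-32768",  # any Groq-supported model name
        examples=[
            AugmentationExample(input_text="What is the capital of France?",
                                output_text="Paris. Obviously."),
            AugmentationExample(input_text="How many moons does Earth have?",
                                output_text="Just the one. Try to keep up."),
            AugmentationExample(input_text="What color is the sky?",
                                output_text="Blue, when it feels like it."),
        ],
        finetuning_goal="Sarcastic factual Q&A chatbot",
        groq_api_key=os.getenv("GROQ_API_KEY", "YOUR_GROQ_API_KEY"),  # placeholder
    )
    augmentor = FinetuningDataAugmentor(demo_config)
    augmentor.run_augmentation(target_count=10)
    augmentor.save_to_file("train.jsonl")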