from datasets import load_dataset
from openai import AsyncOpenAI
from tqdm import tqdm
import asyncio
import json
import os

# OpenAI-compatible client pointed at a locally hosted inference server;
# the local server does not check credentials, so any placeholder key works.
client = AsyncOpenAI(api_key="no-need", base_url="http://localhost:8000/v1")

async def generate_answer(messages):
    """Request a chat completion for *messages* and return the reply text.

    On any failure, returns the literal sentinel string 'error' instead of
    raising, so one bad request cannot abort a whole gathered batch.
    """
    try:
        response = await client.chat.completions.create(
            model="outputs/out-alpha",
            messages=messages,
            max_tokens=2048,
        )
        return response.choices[0].message.content
    except Exception as e:
        # Surface the failure instead of swallowing it silently; callers
        # still receive the 'error' sentinel and continue processing.
        print(f"generate_answer failed: {type(e).__name__}: {e}")
        return 'error'

async def process_batch(questions, batch_num, all_qa_pairs,
                        output_path='qa_pairs_all-alpha1b_2.json'):
    """Answer one batch of conversations concurrently and checkpoint results.

    Args:
        questions: list of message lists (conversations without the final
            assistant turn). Each list is mutated in place: the generated
            reply is appended as an assistant message.
        batch_num: batch index, used only for the progress-bar label.
        all_qa_pairs: accumulator of completed conversations; extended in
            place and re-serialized to ``output_path`` after the batch.
        output_path: JSON checkpoint file (default matches the path used
            by ``main`` so existing callers are unaffected).

    Returns:
        The list of answer strings, in the same order as ``questions``.
    """
    # Create all coroutines up front so asyncio.gather runs them concurrently.
    tasks = [generate_answer(q)
             for q in tqdm(questions, desc=f"Batch {batch_num}", leave=False)]
    answers = await asyncio.gather(*tasks)

    # Attach each reply to its conversation (mutating the caller's lists)
    # and record the finished pair.
    for q, a in zip(questions, answers):
        q.append({'role': 'assistant', 'content': a})
        all_qa_pairs.append({"conversations": q})

    # Checkpoint after every batch so a crash loses at most one batch of work.
    with open(output_path, 'w') as f:
        json.dump(all_qa_pairs, f, indent=2)

    return answers

async def main():
    """Load the SFT dataset, resume any saved progress, and generate answers
    for the remaining prompts in fixed-size batches."""
    out_path = 'qa_pairs_all-alpha1b_2.json'
    dataset = load_dataset('qnguyen3/sft-r1')

    # Resume from a previous run's checkpoint if one exists.
    all_qa_pairs = []
    if os.path.exists(out_path):
        with open(out_path, 'r') as f:
            all_qa_pairs = json.load(f)

    # Keep only the prompt side (all messages except the final one) of each
    # example past the hard-coded resume offset.
    print("Preparing questions...")
    question_list = [
        item['messages'][:-1]
        for idx, item in tqdm(enumerate(dataset['train']), desc="Loading dataset")
        if idx >= 21600
    ]

    # Fan the prompts out to the model 200 at a time.
    batch_size = 200
    for start in tqdm(range(0, len(question_list), batch_size), desc="Processing batches"):
        await process_batch(
            question_list[start:start + batch_size],
            start // batch_size,
            all_qa_pairs,
        )

if __name__ == "__main__":
    # Drive the whole async pipeline from a fresh event loop when run as a script.
    asyncio.run(main())