File size: 940 Bytes
fec11df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import pandas as pd
from .embedding import Embedder
from .retrieval import RetrievalDB
from .codegen import generate_code_with_context
class CodeGenPipeline:
    """Retrieval-backed code generation over a parquet dataset of tasks.

    On construction the dataset is read, its ``prompt`` column is embedded
    with an ``Embedder``, and prompts, embeddings and the
    ``canonical_solution`` column are handed to a ``RetrievalDB``.
    """

    def __init__(self, parquet_path):
        """Read *parquet_path* and prepare the embedder and retrieval store.

        The parquet file must provide ``prompt`` and ``canonical_solution``
        columns; both are kept on the instance as plain lists.
        """
        frame = pd.read_parquet(parquet_path)
        self.df = frame
        self.prompts = frame["prompt"].tolist()
        self.solutions = frame["canonical_solution"].tolist()

        # Every dataset prompt is encoded once, up front, in batches of 32.
        self.embedder = Embedder()
        self.embeddings = self.embedder.encode(
            self.prompts,
            batch_size=32,
            show_progress_bar=True,
        )
        self.retrieval_db = RetrievalDB(
            self.prompts,
            self.embeddings,
            self.solutions,
        )

    def generate_code_from_prompt(self, user_prompt, k=1):
        """Generate code for *user_prompt* using retrieved examples as context.

        The prompt is embedded, ``k`` is forwarded to the retrieval store,
        and each returned record (read via its ``'prompt'`` and
        ``'solution'`` keys) is rendered as a ``# Task:`` section. The
        sections, separated by blank lines, are passed together with the
        prompt to ``generate_code_with_context``, whose result is returned.
        """
        # encode() receives a one-element batch; keep only its single row.
        encoded_batch = self.embedder.encode([user_prompt])
        query_vector = encoded_batch[0]

        # Only the first entry of the retrieval result is used.
        lookup = self.retrieval_db.retrieve_similar_context(query_vector, k=k)
        neighbours = lookup[0]

        sections = []
        for record in neighbours:
            sections.append(f"# Task:\n{record['prompt']}\n{record['solution']}")
        context = "\n\n".join(sections)

        return generate_code_with_context(user_prompt, context)