File size: 20,168 Bytes
43b66f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
"""
Streamlit UI for fine-tuning code generation models.
"""
import streamlit as st
import pandas as pd
import numpy as np
import os
import time
from datetime import datetime
import torch
import plotly.express as px
import plotly.graph_objects as go
from pathlib import Path
import json
import uuid
import threading
from transformers import TrainingArguments
from datasets import Dataset

from components.fine_tuning.model_interface import (
    load_model_and_tokenizer,
    preprocess_code_dataset,
    setup_trainer,
    generate_code_comment,
    generate_code_from_comment,
    save_training_config,
    load_training_config
)

# Seed the fine-tuning workflow's session-state keys. Only keys that are not yet
# present are written, so values from earlier reruns survive.
# training_status values used in this file: idle, running, stopping, completed, failed.
for _state_key, _state_default in {
    "training_run_id": None,
    "training_status": "idle",
    "training_progress": 0.0,
    "trained_model": None,
    "trained_tokenizer": None,
    "training_logs": [],
    "fine_tuning_dataset": None,
}.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
del _state_key, _state_default

# Root folder whose sub-directories hold fine-tuned models; created eagerly.
MODELS_DIR = Path("./fine_tuned_models")
MODELS_DIR.mkdir(exist_ok=True)

# Handle reserved for a background training thread (nothing in this file starts one yet).
training_thread = None

def _normalize_pairs(df):
    """Return a cleaned copy of *df* whose ``input``/``target`` columns are non-missing strings.

    Rows with a missing value (NaN/None) in either column are dropped rather than
    being stringified to ``"nan"``; the remaining values are cast to ``str`` and the
    index is reset. Any other columns are kept unchanged.

    Args:
        df: DataFrame containing ``input`` and ``target`` columns.

    Returns:
        A new DataFrame; *df* itself is not modified.
    """
    cleaned = df.dropna(subset=["input", "target"]).copy()
    cleaned["input"] = cleaned["input"].astype(str)
    cleaned["target"] = cleaned["target"].astype(str)
    return cleaned.reset_index(drop=True)


def _render_csv_upload():
    """Let the user upload a CSV of input/target pairs and adopt it as the fine-tuning dataset."""
    uploaded_file = st.file_uploader(
        "Upload fine-tuning dataset (CSV)",
        type=["csv"],
        help="CSV file with 'input' and 'target' columns"
    )
    if uploaded_file is None:
        return

    try:
        df = pd.read_csv(uploaded_file)
    except Exception as e:
        # pandas raises several unrelated exception types for malformed uploads
        # (parser errors, empty files, decode errors); report all of them to the user.
        st.error(f"Error loading CSV: {str(e)}")
        return

    # Check if required columns exist
    if "input" not in df.columns or "target" not in df.columns:
        st.error("CSV must contain 'input' and 'target' columns.")
        return

    total_rows = len(df)
    df = _normalize_pairs(df)
    dropped = total_rows - len(df)
    if dropped:
        st.info(f"Dropped {dropped} row(s) with a missing 'input' or 'target' value.")
    if df.empty:
        st.error("The CSV contains no usable input-target pairs.")
        return

    # Preview dataset
    st.markdown("### Dataset Preview")
    st.dataframe(df.head(), use_container_width=True)

    # Dataset statistics
    st.markdown("### Dataset Statistics")
    col1, col2 = st.columns(2)
    with col1:
        st.metric("Number of examples", len(df))
    with col2:
        st.metric("Average input length", df["input"].str.len().mean().round(1))

    # Save dataset
    if st.button("Use this dataset"):
        st.session_state.fine_tuning_dataset = df
        st.success(f"Dataset with {len(df)} examples loaded successfully!")


def _render_manual_input():
    """Collect input/target pairs typed into text areas and build a dataset from them."""
    st.markdown("""
    Enter pairs of inputs and targets for fine-tuning. For code-to-comment tasks, the input is code and 
    the target is a comment. For comment-to-code tasks, the input is a comment and the target is code.
    """)

    # Created before the "Add another example" button so the example fields are
    # drawn above it even though they are filled in afterwards.
    examples_container = st.container()

    # Default number of example fields
    if "num_examples" not in st.session_state:
        st.session_state.num_examples = 3

    # Handled before the fields are drawn so the new field appears in this same run.
    if st.button("Add another example"):
        st.session_state.num_examples += 1

    inputs = []
    targets = []
    with examples_container:
        for i in range(st.session_state.num_examples):
            st.markdown(f"### Example {i+1}")
            col1, col2 = st.columns(2)
            with col1:
                inputs.append(st.text_area(f"Input {i+1}", key=f"input_{i}", height=150))
            with col2:
                targets.append(st.text_area(f"Target {i+1}", key=f"target_{i}", height=150))

    if st.button("Create Dataset from Examples"):
        # Keep only pairs where both sides contain non-whitespace text.
        valid_examples = [(inp, tgt) for inp, tgt in zip(inputs, targets) if inp.strip() and tgt.strip()]
        if not valid_examples:
            st.warning("No valid examples found. Please enter at least one input-target pair.")
            return

        df = pd.DataFrame(valid_examples, columns=["input", "target"])
        st.session_state.fine_tuning_dataset = df

        # Preview dataset
        st.markdown("### Dataset Preview")
        st.dataframe(df, use_container_width=True)
        st.success(f"Dataset with {len(df)} examples created successfully!")


def _render_current_dataset():
    """Build the fine-tuning dataset from two columns of ``st.session_state["dataset"]``."""
    # `.get` because nothing in this module guarantees the "dataset" key exists;
    # attribute access would raise instead of showing the warning below.
    source_df = st.session_state.get("dataset")
    if source_df is None:
        st.warning("No dataset is currently loaded. Please upload or select a dataset first.")
        return

    st.markdown("### Current Dataset")
    st.dataframe(source_df.head(), use_container_width=True)

    # Column selection
    col1, col2 = st.columns(2)
    with col1:
        input_col = st.selectbox("Select column for inputs", source_df.columns)
    with col2:
        target_col = st.selectbox("Select column for targets", source_df.columns)

    if input_col == target_col:
        st.warning("Input and target columns are the same; the model would only learn to copy its input.")

    if st.button("Create Fine-Tuning Dataset"):
        df = source_df[[input_col, target_col]].copy()
        df.columns = ["input", "target"]
        # Drop missing pairs and cast to str (plain astype(str) would turn NaN into "nan").
        df = _normalize_pairs(df)
        if df.empty:
            st.error("The selected columns contain no usable input-target pairs.")
            return

        # Preview
        st.dataframe(df.head(), use_container_width=True)

        # Store dataset
        st.session_state.fine_tuning_dataset = df
        st.success(f"Fine-tuning dataset with {len(df)} examples created successfully!")


def render_dataset_preparation():
    """
    Render the dataset preparation interface.

    The user picks one of three sources (CSV upload, manual entry, or two columns of
    the app-wide ``dataset`` in session state). Whichever source is confirmed is
    stored in ``st.session_state.fine_tuning_dataset`` as a DataFrame with
    ``input`` and ``target`` columns.
    """
    st.markdown("### Dataset Preparation")

    # Dataset input options
    dataset_source = st.radio(
        "Choose dataset source",
        ["Upload CSV", "Manual Input", "Use Current Dataset"],
        help="Select how you want to provide your fine-tuning dataset"
    )

    if dataset_source == "Upload CSV":
        _render_csv_upload()
    elif dataset_source == "Manual Input":
        _render_manual_input()
    elif dataset_source == "Use Current Dataset":
        _render_current_dataset()

def _rerun():
    """Ask Streamlit to rerun the script immediately.

    ``st.experimental_rerun`` is deprecated in favour of ``st.rerun`` and is absent
    from recent Streamlit releases, so prefer the new API and fall back to the old
    name only on installations that predate it.
    """
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _is_valid_model_dir_name(name):
    """Return True if *name* is non-empty, not '.'/'..', and contains no path separators.

    The output name presumably becomes a directory under MODELS_DIR (the testing
    tab lists those directories as saved models), so reject anything that could
    point outside it.
    """
    return bool(name) and name not in {".", ".."} and Path(name).name == name


def _render_training_in_progress():
    """Show progress and recent logs for a run whose status is 'running' or 'stopping'."""
    st.progress(st.session_state.training_progress)

    # Show logs
    if st.session_state.training_logs:
        st.markdown("### Training Logs")
        log_text = "\n".join(st.session_state.training_logs[-10:])  # Show last 10 logs
        st.text_area("Latest logs", log_text, height=200, disabled=True)

    if st.session_state.training_status == "running":
        if st.button("Stop Training"):
            # Flag intended to be polled by a background worker (not implemented in this file yet).
            st.session_state.training_status = "stopping"
            st.warning("Stopping training after current epoch completes...")
        return

    # Status is "stopping": the Start button stays hidden so a second run cannot be
    # launched while the first winds down. Nothing in this file moves the status
    # out of "stopping" yet, so give the user a manual way back to idle.
    st.warning("Stopping training after current epoch completes...")
    if st.button("Reset Training State"):
        st.session_state.training_status = "idle"
        st.session_state.training_progress = 0.0
        _rerun()


def _render_training_completed(fallback_model_name):
    """Report a finished run and offer to reset for a new one.

    The reported name comes from the configuration recorded when the run started.
    *fallback_model_name* (the current text-input value) is used only when no
    configuration was recorded.
    """
    config = st.session_state.get("training_config") or {}
    saved_name = config.get("output_model_name") or fallback_model_name
    st.success(f"Training completed! Model saved as: {saved_name}")

    # Show metrics if available
    if "training_metrics" in st.session_state:
        st.markdown("### Training Metrics")
        metrics_df = pd.DataFrame(st.session_state.training_metrics)
        st.line_chart(metrics_df)

    # Reset button
    if st.button("Start New Training"):
        st.session_state.training_status = "idle"
        st.session_state.training_progress = 0.0
        st.session_state.training_logs = []
        _rerun()


def _start_training(training_config):
    """Validate *training_config*, record it in session state, and mark a run as started.

    Real training is not implemented yet (see TODO). This only seeds simulated
    progress and logs, then reruns the script so the in-progress view is shown.
    """
    output_name = training_config["output_model_name"]
    if not _is_valid_model_dir_name(output_name):
        st.error("Please enter a model name that is not empty and contains no path separators.")
        return

    logs = ["Training initialized..."]
    if training_config["dataset_size"] < 5:
        # The rerun below replaces everything drawn in this run, so keep the warning in the logs too.
        small_dataset_msg = "Dataset is very small. Consider adding more examples for better results."
        st.warning(small_dataset_msg)
        logs.append(f"Warning: {small_dataset_msg}")

    st.session_state.training_run_id = str(uuid.uuid4())
    # Persist the exact configuration so later reruns report what was actually trained,
    # even if the widgets are changed afterwards.
    st.session_state.training_config = training_config
    # The testing tab infers the task direction from this name.
    st.session_state.model_name = output_name

    # TODO: Start actual training process using transformers
    st.info("Training would start here with the Hugging Face transformers library")

    # For now, just simulate training progress
    logs.append("Loaded model and tokenizer")
    logs.append("Preprocessing dataset...")
    st.session_state.training_logs = logs
    st.session_state.training_progress = 0.1
    st.session_state.training_status = "running"

    # Rerun to update UI with progress
    _rerun()


def render_model_training():
    """
    Render the model training interface.

    Draws the model, task, and hyperparameter widgets, then one of three control
    panels depending on ``st.session_state.training_status``:
    - "running"/"stopping": progress and logs
    - "completed": the result
    - anything else (idle/failed): a Start button
    """
    st.markdown("### Model Training")

    # Check if dataset is available
    dataset = st.session_state.get("fine_tuning_dataset")
    if dataset is None:
        st.warning("Please prepare a dataset in the 'Dataset Preparation' tab first.")
        return

    # Model selection
    model_options = {
        "Salesforce/codet5-small": "CodeT5 Small (60M params)",
        "Salesforce/codet5-base": "CodeT5 Base (220M params)",
        "Salesforce/codet5-large": "CodeT5 Large (770M params)",
        "microsoft/codebert-base": "CodeBERT Base (125M params)",
        "facebook/bart-base": "BART Base (140M params)"
    }

    model_name = st.selectbox(
        "Select pre-trained model",
        list(model_options.keys()),
        format_func=lambda x: model_options[x],
        help="Select the base model for fine-tuning"
    )

    # Task type
    task_type = st.selectbox(
        "Select task type",
        ["Code to Comment", "Comment to Code"],
        help="Choose the direction of your task"
    )

    # Task prefix prepended to every input so the model knows the direction.
    if task_type == "Code to Comment":
        task_prefix = "translate code to comment: "
    else:
        task_prefix = "translate comment to code: "

    # Hyperparameters
    st.markdown("### Training Hyperparameters")

    col1, col2 = st.columns(2)
    with col1:
        learning_rate = st.select_slider(
            "Learning Rate",
            options=[1e-6, 2e-6, 5e-6, 1e-5, 2e-5, 5e-5, 1e-4],
            value=5e-5,
            help="Step size for optimizer updates"
        )
        epochs = st.slider(
            "Epochs",
            min_value=1,
            max_value=20,
            value=3,
            help="Number of complete passes through the dataset"
        )
    with col2:
        batch_size = st.select_slider(
            "Batch Size",
            options=[1, 2, 4, 8, 16, 32],
            value=8,
            help="Number of examples processed in each training step"
        )
        max_input_length = st.slider(
            "Max Input Length (tokens)",
            min_value=64,
            max_value=512,
            value=256,
            help="Maximum length of input sequences"
        )

    # Advanced options
    with st.expander("Advanced Options"):
        col1, col2 = st.columns(2)
        with col1:
            weight_decay = st.select_slider(
                "Weight Decay",
                options=[0.0, 0.01, 0.05, 0.1],
                value=0.01,
                help="L2 regularization"
            )
            warmup_steps = st.slider(
                "Warmup Steps",
                min_value=0,
                max_value=1000,
                value=100,
                help="Steps for learning rate warmup"
            )
        with col2:
            max_target_length = st.slider(
                "Max Target Length (tokens)",
                min_value=64,
                max_value=512,
                value=256,
                help="Maximum length of target sequences"
            )
            gradient_accumulation = st.slider(
                "Gradient Accumulation Steps",
                min_value=1,
                max_value=16,
                value=1,
                help="Number of steps to accumulate gradients"
            )

    # Model output configuration
    st.markdown("### Model Output Configuration")
    model_name_custom = st.text_input(
        "Custom model name",
        value=f"{model_name.split('/')[-1]}-finetuned-{task_type.lower().replace(' ', '-')}",
        help="Name for your fine-tuned model"
    )

    training_config = {
        "model_name": model_name,
        "task_type": task_type,
        "task_prefix": task_prefix,
        "learning_rate": learning_rate,
        "epochs": epochs,
        "batch_size": batch_size,
        "max_input_length": max_input_length,
        "max_target_length": max_target_length,
        "weight_decay": weight_decay,
        "warmup_steps": warmup_steps,
        "gradient_accumulation": gradient_accumulation,
        "output_model_name": model_name_custom.strip(),
        "dataset_size": len(dataset)
    }

    # Training controls
    st.markdown("### Training Controls")

    status = st.session_state.get("training_status", "idle")
    if status in ("running", "stopping"):
        _render_training_in_progress()
    elif status == "completed":
        _render_training_completed(model_name_custom)
    else:  # idle or failed
        # If previously failed, show error
        if status == "failed":
            st.error("Previous training failed. See logs for details.")
            logs = st.session_state.get("training_logs")
            if logs:
                st.text_area("Error logs", "\n".join(logs[-5:]), height=100, disabled=True)

        if st.button("Start Training"):
            _start_training(training_config)

def _list_saved_models():
    """Return the sub-directories of MODELS_DIR sorted by name (empty list if it is missing).

    ``Path.glob("*/")`` only restricts matches to directories from Python 3.11 on,
    so filter with ``is_dir()`` explicitly to keep stray files out of the list.
    """
    if not MODELS_DIR.is_dir():
        return []
    return sorted((p for p in MODELS_DIR.iterdir() if p.is_dir()), key=lambda p: p.name)


def _infer_task_type():
    """Guess the active model's task direction from ``st.session_state["model_name"]``.

    The training tab's default output names end in 'code-to-comment' or
    'comment-to-code'. Any name without 'code-to-comment' falls back to
    'Comment to Code'.
    """
    name = (st.session_state.get("model_name") or "").lower()
    return "Code to Comment" if "code-to-comment" in name else "Comment to Code"


def _render_batch_testing():
    """Offer CSV-driven batch evaluation (results are placeholders until inference exists)."""
    with st.expander("Batch Testing"):
        st.markdown("Upload a CSV file with test cases to evaluate your model.")

        test_file = st.file_uploader(
            "Upload test cases (CSV)",
            type=["csv"],
            help="CSV file with 'input' and 'expected' columns"
        )
        if test_file is None:
            return

        try:
            test_df = pd.read_csv(test_file)
        except Exception as e:
            # pandas raises several unrelated exception types for malformed uploads.
            st.error(f"Error loading test file: {str(e)}")
            return

        st.dataframe(test_df.head(), use_container_width=True)

        if "input" not in test_df.columns:
            st.error("Test CSV must contain an 'input' column.")
            return
        if test_df.empty:
            st.warning("The test CSV contains no rows.")
            return

        if st.button("Run Batch Test"):
            with st.spinner("Running tests..."):
                # TODO: Actual batch inference
                st.success("Batch testing completed!")

                # Dummy results; 'expected' is optional in the upload.
                results = pd.DataFrame({
                    "input": test_df["input"],
                    "expected": test_df.get("expected", [""] * len(test_df)),
                    "generated": ["Sample output " + str(i) for i in range(len(test_df))],
                    "match_score": np.random.uniform(0.5, 1.0, len(test_df))
                })

                st.dataframe(results, use_container_width=True)

                # Metrics
                st.markdown("### Evaluation Metrics")
                col1, col2 = st.columns(2)
                with col1:
                    st.metric("Average Match Score", f"{results['match_score'].mean():.2f}")
                with col2:
                    passed = int((results["match_score"] > 0.8).sum())
                    st.metric("Tests Passed", f"{passed}/{len(results)}")


def render_model_testing():
    """
    Render the model testing interface.

    When no model is active and training has not completed, the user can pick a
    saved model directory from MODELS_DIR to load (loading is still a placeholder).
    Otherwise it offers single-example generation for the selected task direction
    and a batch-testing expander. Generation outputs are placeholders until real
    inference is wired in.
    """
    st.markdown("### Test & Use Model")

    # Check if a model is trained/available
    has_model = st.session_state.get("trained_model") is not None
    if not has_model and st.session_state.get("training_status") != "completed":
        saved_models = _list_saved_models()
        if not saved_models:
            st.warning("No trained models available. Please train a model first.")
            return

        # Let user select a saved model
        model_options = [model.name for model in saved_models]
        selected_model = st.selectbox("Select a saved model", model_options)

        if st.button("Load Selected Model"):
            st.info(f"Loading model {selected_model}...")
            # TODO: Load model logic
            st.session_state.trained_model = "loaded"  # Placeholder
            st.session_state.trained_tokenizer = "loaded"  # Placeholder
            # Remember which model is active so its task direction can be inferred later.
            st.session_state.model_name = selected_model
            st.success("Model loaded successfully!")
        return

    # Inference from the name can fail for custom names, so let the user override it.
    task_options = ["Code to Comment", "Comment to Code"]
    model_type = st.selectbox(
        "Model task",
        task_options,
        index=task_options.index(_infer_task_type()),
        help="Direction the model was fine-tuned for (inferred from the model name when possible)"
    )

    st.markdown(f"### Testing {model_type} Generation")

    if model_type == "Code to Comment":
        input_text = st.text_area(
            "Enter code snippet",
            height=200,
            help="Enter a code snippet to generate a comment"
        )

        if st.button("Generate Comment"):
            if input_text.strip():
                with st.spinner("Generating comment..."):
                    # TODO: Replace with actual model inference
                    preview = " ".join(input_text.split()[:3])
                    result = f"/* This code {preview} ... */"
                    st.markdown("### Generated Comment")
                    st.code(result)
            else:
                st.warning("Please enter a code snippet.")

    else:  # Comment to Code
        input_text = st.text_area(
            "Enter comment/description",
            height=150,
            help="Enter a description to generate code"
        )

        language = st.selectbox(
            "Programming language",
            ["Python", "JavaScript", "Java", "C++", "Go"]
        )

        if st.button("Generate Code"):
            if input_text.strip():
                with st.spinner("Generating code..."):
                    # TODO: Replace with actual model inference
                    result = f"def example_function():\n    # {input_text}\n    pass"
                    st.markdown("### Generated Code")
                    st.code(result, language=language.lower())
            else:
                st.warning("Please enter a comment or description.")

    _render_batch_testing()

def _ensure_session_state_defaults():
    """Seed any missing fine-tuning keys in ``st.session_state`` for the current session.

    Existing values are never overwritten. The defaults table is rebuilt on each
    call, so every session gets its own ``training_logs`` list rather than a shared
    mutable default.
    """
    defaults = {
        "training_run_id": None,
        "training_status": "idle",  # idle, running, stopping, completed, failed
        "training_progress": 0.0,
        "trained_model": None,
        "trained_tokenizer": None,
        "training_logs": [],
        "fine_tuning_dataset": None,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def render_finetune_ui():
    """
    Render the fine-tuning UI for code generation models.

    Shows an "About" expander and three tabs: dataset preparation, model training,
    and model testing.
    """
    # If this module is imported by another Streamlit script (rather than run as the
    # main script), its top-level session-state setup runs only on first import,
    # i.e. once per process. Re-seed per render so every session has the keys the
    # tabs read via attribute access.
    _ensure_session_state_defaults()

    st.markdown("<h2>Fine-Tune Code Generation Model</h2>", unsafe_allow_html=True)

    # Overview and instructions
    with st.expander("About Fine-Tuning", expanded=False):
        st.markdown("""
        ## Fine-Tuning a Code Generation Model

        This interface allows you to fine-tune pre-trained code generation models from Hugging Face
        on your custom dataset to adapt them to your specific coding style or task.

        ### How to use:
        1. **Prepare your dataset** - Upload a CSV file with 'input' and 'target' columns:
           - For code-to-comment: 'input' = code snippets, 'target' = corresponding comments
           - For comment-to-code: 'input' = comments, 'target' = corresponding code snippets

        2. **Configure training** - Set hyperparameters like learning rate, batch size, and epochs

        3. **Start fine-tuning** - Launch the training process and monitor progress

        4. **Test your model** - Once training is complete, test your model on new inputs

        ### Tips for better results:
        - Use a consistent format for your code snippets and comments
        - Start with a small dataset (50-100 examples) to verify the process
        - Try different hyperparameters to find the best configuration
        """)

    # Main UI with tabs
    tab1, tab2, tab3 = st.tabs(["Dataset Preparation", "Model Training", "Test & Use Model"])

    with tab1:
        render_dataset_preparation()

    with tab2:
        render_model_training()

    with tab3:
        render_model_testing()