import html
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

# Initialize model and tokenizer
MODEL_ID = "erikbeltran/pydiff"
GGUF_FILE = "unsloth.Q4_K_M.gguf"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, gguf_file=GGUF_FILE)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, gguf_file=GGUF_FILE)
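# Note: transformers loads GGUF checkpoints by dequantizing the weights to
# full precision in memory (the `gguf` package must be installed), so this
# does not run the quantized kernels that llama.cpp would use.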

# Move model to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

def format_diff_response(response):
    """Format the model output as colorized diff-style HTML."""
    lines = response.split('\n')
    formatted = []
    for line in lines:
        # Escape the raw text first so code containing '<' or '&' does
        # not break the rendered HTML
        escaped = html.escape(line)
        if line.startswith('+'):
            formatted.append(f'<span style="color: green">{escaped}</span>')
        elif line.startswith('-'):
            formatted.append(f'<span style="color: red">{escaped}</span>')
        else:
            formatted.append(escaped)
    return '<br>'.join(formatted)
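
# For example, "+x = 1\n-x = 2" renders the first line in green and the
# second in red, joined with <br> tags.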

def create_prompt(request, file_content, system_message):
    return f"""<system>{system_message}</system>
<request>{request}</request>
<file>
{file_content}
</file>"""
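
# For example, create_prompt("fix the function", "def f(x): pass", msg)
# produces:
#   <system>...msg...</system>
#   <request>fix the function</request>
#   <file>
#   def f(x): pass
#   </file>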

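# On ZeroGPU Spaces, @spaces.GPU requests a GPU for the duration of each
# call; the decorator also supports generator functions, which is what
# lets this handler stream partial output to the UI.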
@spaces.GPU
def respond(request, file_content, system_message, max_tokens, temperature, top_p):
    prompt = create_prompt(request, file_content, system_message)
    
    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    
    # Stream tokens as they are generated; skip_prompt keeps the echoed
    # prompt out of the displayed response
    response = ""
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        inputs=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        streamer=streamer,
    )
    
    # Start generation in a separate thread
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
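    # (model.generate blocks until completion, so it must run off the
    # main thread while we consume the streamer below)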
    
    # Yield formatted responses as they're generated
    for new_text in streamer:
        response += new_text
        yield format_diff_response(response)

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Code Review Assistant")
    
    with gr.Row():
        with gr.Column():
            request_input = gr.Textbox(
                label="Request",
                placeholder="Enter your request (e.g., 'fix the function', 'add error handling')",
                lines=3
            )
            file_input = gr.Code(
                label="File Content",
                language="python",
                lines=10
            )
        with gr.Column():
            output = gr.HTML(label="Diff Output")
    
    with gr.Accordion("Advanced Settings", open=False):
        system_msg = gr.Textbox(
            value="You are a code review assistant. Analyze the code and provide suggestions in diff format. Use '+' for additions and '-' for deletions.",
            label="System Message"
        )
        max_tokens = gr.Slider(
            minimum=1,
            maximum=2048,
            value=512,
            step=1,
            label="Max Tokens"
        )
        temperature = gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.7,
            step=0.1,
            label="Temperature"
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p"
        )

    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=respond,
        inputs=[
            request_input,
            file_input,
            system_msg,
            max_tokens,
            temperature,
            top_p
        ],
        outputs=output
    )
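
    # Because respond is a generator, Gradio streams each yielded chunk
    # to the HTML output as it arrives rather than waiting for the full
    # generation to finish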

if __name__ == "__main__":
    demo.launch()