import gradio as gr
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

# Read the Hugging Face access token from the environment (if one is set)
HF_TOKEN = os.environ.get("HF_TOKEN", None)

DESCRIPTION = '''
<div>
<h1 style="text-align: center;">AI Lawyer</h1>
<p>This Space demonstrates the instruction-tuned model <a href="https://huggingface.co/StevenChen16/llama3-8b-Lawyer"><b>StevenChen16/llama3-8b-Lawyer</b></a>. This model is fine-tuned to provide legal information and assist with a wide range of legal questions. Feel free to ask anything!</p>
</div>
'''

LICENSE = """
<p/>
---
Built with the model "StevenChen16/llama3-8b-Lawyer", fine-tuned from "meta-llama/Meta-Llama-3-8B"
"""

PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
   <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">AI Lawyer</h1>
   <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything about US and Canada law...</p>
</div>
"""

css = """
h1 {
  text-align: center;
  display: block;
}
#duplicate-button {
  margin: auto;
  color: white;
  background: #1565c0;
  border-radius: 100vh;
}
"""

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("StevenChen16/llama3-8b-Lawyer")
model = AutoModelForCausalLM.from_pretrained("StevenChen16/llama3-8b-Lawyer", device_map="auto")
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")  # Llama 3 end-of-turn token
]

def chat_llama3_8b(message: str, history: list, temperature: float, max_new_tokens: int):
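    """Generate a streamed reply from the model for the given user message and chat history."""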
    conversation = []
    for user, assistant in history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
    
    # Stream generated tokens back to the caller as they are produced
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        eos_token_id=terminators,
    )
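    # A temperature of 0 is treated as greedy decoding, since sampling with temperature 0 is undefined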
    if temperature == 0:
        generate_kwargs['do_sample'] = False
        
    # Run generation in a background thread so partial output can be yielded while the model generates
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Accumulate streamed chunks and yield the partial response so the chat window updates live
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

# Gradio block
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')

with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    gr.ChatInterface(
        fn=chat_llama3_8b,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(minimum=0,
                      maximum=1, 
                      step=0.1,
                      value=0.95, 
                      label="Temperature", 
                      render=False),
            gr.Slider(minimum=128, 
                      maximum=4096,
                      step=1,
                      value=512, 
                      label="Max new tokens", 
                      render=False),
            ],
        examples=[
            ['What are the key differences between a sole proprietorship and a partnership?'],
            ['What legal steps should I take if I want to start a business in the US?'],
            ['Can you explain the concept of "duty of care" in negligence law?'],
            ['What are the legal requirements for obtaining a patent in Canada?'],
            ['How can I protect my intellectual property when sharing my idea with potential investors?']
        ],
        cache_examples=False,
    )
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.launch(share=True)