"""GRAB DOC: a Gradio app that generates text with Mistral-7B-Instruct and saves it as a PDF, DOCX, or TXT file."""

from huggingface_hub import InferenceClient
import gradio as gr
from fpdf import FPDF
import docx

# Custom stylesheet handed to the Gradio interface below (via ``css=css``):
# caps the container width at 1000px, centers <h1> headings, and hides the
# page footer.
css = '''
.gradio-container{max-width: 1000px !important}
h1{text-align:center}
footer {
    visibility: hidden
}
'''

# Module-level inference client shared by every call to ``generate``; it is
# bound to the "mistralai/Mistral-7B-Instruct-v0.3" model id.
# NOTE(review): presumably this talks to the hosted Hugging Face inference
# endpoint and picks up credentials from the environment — confirm.
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")

def format_prompt(message, history, system_prompt=None):
    """Assemble the raw instruction-tagged prompt string for the model.

    Args:
        message: The new user message to answer.
        history: Iterable of ``(user_prompt, bot_response)`` pairs from
            earlier turns, oldest first.
        system_prompt: Optional system text; when truthy it is inserted,
            wrapped in ``[SYS] ... [/SYS]``, right before the new message.

    Returns:
        A single string starting with ``<s>``, followed by every past turn
        rendered as ``[INST] user [/INST] bot</s> ``, the optional system
        segment, and finally the new message wrapped in ``[INST] ... [/INST]``.
    """
    segments = ["<s>"]
    for past_user, past_bot in history:
        segments.append(f"[INST] {past_user} [/INST] {past_bot}</s> ")
    if system_prompt:
        segments.append(f"[SYS] {system_prompt} [/SYS]")
    segments.append(f"[INST] {message} [/INST]")
    return "".join(segments)

def generate(
    prompt,
    history,
    system_prompt=None,
    temperature=0.2,
    max_new_tokens=1024,
    top_p=0.95,
    repetition_penalty=1.0,
):
    """Stream a model completion for *prompt*, yielding the text so far.

    Args:
        prompt: The new user message.
        history: Past ``(user, bot)`` turns, forwarded to ``format_prompt``.
        system_prompt: Optional system text, forwarded to ``format_prompt``.
        temperature: Sampling temperature; coerced to float and raised to
            at least 0.01.
        max_new_tokens: Passed through to the client unchanged.
        top_p: Nucleus-sampling value; coerced to float.
        repetition_penalty: Passed through to the client unchanged.

    Yields:
        The cumulative generated text after each streamed token, with any
        literal ``</s>`` removed.

    Returns:
        The final cumulative text (available as ``StopIteration.value``).
    """
    sampling_options = {
        # Values below 0.01 are clamped up to 0.01.
        "temperature": max(float(temperature), 1e-2),
        "max_new_tokens": max_new_tokens,
        "top_p": float(top_p),
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        # Fixed seed is sent with every request.
        "seed": 42,
    }

    full_prompt = format_prompt(prompt, history, system_prompt)
    token_stream = client.text_generation(
        full_prompt,
        **sampling_options,
        stream=True,
        details=True,
        return_full_text=False,
    )

    text_so_far = ""
    for chunk in token_stream:
        # Append the new token, then strip end-of-sequence markers from the
        # whole accumulated text before handing it to the consumer.
        text_so_far = (text_so_far + chunk.token.text).replace("</s>", "")
        yield text_so_far
    return text_so_far

# Save the generated content to a file
def save_file(content, filename, file_format):
    """Write *content* to ``<filename>.<file_format>`` and return that path.

    Args:
        content: The text to store.
        filename: Target path without extension.
        file_format: One of ``"pdf"``, ``"docx"`` or ``"txt"``.

    Returns:
        The path of the written file, i.e. ``f"{filename}.{file_format}"``.

    Raises:
        ValueError: If *file_format* is not one of the supported formats.
    """
    if file_format not in ("pdf", "docx", "txt"):
        raise ValueError("Unsupported file format")

    path = f"{filename}.{file_format}"

    if file_format == "pdf":
        pdf = FPDF()
        pdf.add_page()
        pdf.set_auto_page_break(auto=True, margin=15)
        pdf.set_font("Arial", size=12)
        for line in content.split("\n"):
            # The built-in (core) PDF fonts can only encode Latin-1. Model
            # output routinely contains characters outside that range
            # (typographic quotes, emoji, CJK), which would otherwise abort
            # the export with an encoding error; substitute "?" instead.
            printable = line.encode("latin-1", "replace").decode("latin-1")
            # Start every paragraph at the left margin. Older FPDF releases
            # already do this after multi_cell; newer fpdf2 releases appear
            # to leave the cursor at the right edge by default, which makes
            # the next full-width multi_cell fail.
            pdf.set_x(pdf.l_margin)
            pdf.multi_cell(0, 10, printable)
        pdf.output(path)
    elif file_format == "docx":
        doc = docx.Document()
        doc.add_paragraph(content)
        doc.save(path)
    else:
        # Explicit UTF-8 so non-ASCII output does not depend on (or crash
        # under) the platform's default locale encoding.
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)

    return path

def generate_and_save(prompt, history, filename="output", file_format="pdf", system_prompt=None, temperature=0.2, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0):
    """Generate a reply for *prompt*, save it to disk, and update the history.

    Args:
        prompt: The new user message.
        history: Past ``(user, bot)`` turns; ``None`` is treated as empty.
        filename: User-supplied base name for the output file (no
            extension). Directory components are discarded and an empty
            name falls back to ``"output"``.
        file_format: ``"pdf"``, ``"docx"`` or ``"txt"``.
        system_prompt, temperature, max_new_tokens, top_p,
        repetition_penalty: Forwarded unchanged to ``generate``.

    Returns:
        A 3-tuple ``(generated_text, new_history, saved_file_path)`` where
        ``new_history`` is *history* with ``(prompt, generated_text)``
        appended (the input list is not mutated).

    Raises:
        ValueError: Propagated from ``save_file`` for an unsupported format.
    """
    import os  # local import keeps this function self-contained

    history = history or []

    # The filename comes straight from a web form. Keep only the final path
    # component so values like "../../etc/x" cannot escape the working
    # directory, and honour the advertised default when the box is empty.
    raw_name = (filename or "").strip().replace("\\", "/")
    safe_name = os.path.basename(raw_name) or "output"

    # Drain the stream; each item is the cumulative text, so the last one is
    # the complete reply.
    generated_text = ""
    for generated_text in generate(prompt, history, system_prompt, temperature, max_new_tokens, top_p, repetition_penalty):
        pass
    generated_text = generated_text.replace("</s>", "")

    saved_file = save_file(generated_text, safe_name, file_format)
    return generated_text, history + [(prompt, generated_text)], saved_file
    
# Single-page UI: prompt + filename + format in, generated text + download out.
# The two State components carry the chat history into and out of
# ``generate_and_save``.
demo = gr.Interface(
    fn=generate_and_save,
    inputs=[
        # ``lines`` (plural) is the Textbox keyword for the visible row count;
        # the previous ``line=4`` is not a recognised argument.
        gr.Textbox(placeholder="Type your message here...", label="Chatbot", lines=4),
        gr.State(value=[]),
        gr.Textbox(placeholder="Filename (default: output)", label="Filename", value="output"),
        gr.Radio(["pdf", "docx", "txt"], label="File Format", value="pdf"),
    ],
    outputs=[
        gr.Textbox(label="Generated Text", lines=4),
        gr.State(value=[]),
        gr.File(label="Download File"),
    ],
    css=css,
    title="GRAB DOC",
    theme="bethecloud/storj_theme",
)

# Enable request queuing, then start the server with the API docs page hidden.
demo.queue().launch(show_api=False)