File size: 3,569 Bytes
310fc8c
ad773e5
b0f6b9b
6ac13b6
ad773e5
 
b0f6b9b
ad773e5
 
 
 
 
 
310fc8c
ad773e5
310fc8c
 
 
 
 
 
 
 
 
ad773e5
6ac13b6
310fc8c
 
 
 
 
 
 
68c5b4c
310fc8c
 
 
 
 
 
 
 
68c5b4c
310fc8c
 
 
 
 
 
 
d93899c
 
310fc8c
6ac13b6
0ba7801
6ac13b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d93899c
 
6ac13b6
25caed5
6ac13b6
 
 
 
 
 
 
 
 
 
 
 
25caed5
6ac13b6
 
 
 
 
 
0ba7801
310fc8c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from huggingface_hub import InferenceClient
import gradio as gr
from fpdf import FPDF
import docx

# Custom CSS handed to the Gradio interface below: caps the container width
# at 1000px, centers <h1> headings and hides the page footer.
css = '''
.gradio-container{max-width: 1000px !important}
h1{text-align:center}
footer {
    visibility: hidden
}
'''

# Shared inference client for the Mistral-7B-Instruct-v0.3 model; generate()
# uses it to request streamed text completions.
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")

def format_prompt(message, history, system_prompt=None):
    """Build the instruction-tagged prompt string sent to the model.

    Args:
        message: The new user message to answer.
        history: Iterable of ``(user_prompt, bot_response)`` pairs from
            earlier turns, oldest first.
        system_prompt: Optional system text; skipped when falsy.

    Returns:
        A single string starting with ``<s>``, followed by every past turn,
        the optional ``[SYS]`` section and finally the new ``[INST]`` message.
    """
    pieces = ["<s>"]
    for past_user, past_bot in history:
        pieces.append(f"[INST] {past_user} [/INST]")
        pieces.append(f" {past_bot}</s> ")
    if system_prompt:
        pieces.append(f"[SYS] {system_prompt} [/SYS]")
    pieces.append(f"[INST] {message} [/INST]")
    return "".join(pieces)

# Generate text
def generate(
    prompt, history, system_prompt=None, temperature=0.2, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0,
):
    """Stream a completion for *prompt*, yielding the text accumulated so far.

    Args:
        prompt: The new user message.
        history: Earlier ``(user_prompt, bot_response)`` pairs.
        system_prompt: Optional system text forwarded to ``format_prompt``.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Passed through to the text-generation call.
        top_p: Nucleus-sampling value, coerced to float.
        repetition_penalty: Passed through to the text-generation call.

    Yields:
        The growing output string after each streamed token, with any
        ``</s>`` markers removed.
    """
    # Clamp the temperature to a small positive floor before sending it on.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    formatted_prompt = format_prompt(prompt, history, system_prompt)

    stream = client.text_generation(
        formatted_prompt,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
        stream=True,
        details=True,
        return_full_text=False,
    )

    output = ""
    for chunk in stream:
        # Append the new token, then strip end-of-sequence tags from the text.
        output = (output + chunk.token.text).replace("</s>", "")
        yield output
    return output

# Save the generated content to a file
def save_file(content, filename, file_format):
    """Write *content* to ``<filename>.<file_format>`` and return that path.

    Args:
        content: The text to store.
        filename: Output path without the extension.
        file_format: One of ``"pdf"``, ``"docx"`` or ``"txt"``.

    Returns:
        The path of the file that was written (``filename`` plus extension).

    Raises:
        ValueError: If *file_format* is not one of the supported formats.
    """
    if file_format == "pdf":
        pdf = FPDF()
        pdf.add_page()
        pdf.set_auto_page_break(auto=True, margin=15)
        pdf.set_font("Arial", size=12)
        for line in content.split("\n"):
            # The built-in (core) PDF fonts only cover Latin-1. Model output
            # often contains curly quotes, dashes or emoji, which would make
            # the PDF export fail, so unsupported characters become "?".
            printable = line.encode("latin-1", "replace").decode("latin-1")
            pdf.multi_cell(0, 10, printable)
        pdf.output(f"{filename}.pdf")
        return f"{filename}.pdf"
    elif file_format == "docx":
        doc = docx.Document()
        doc.add_paragraph(content)
        doc.save(f"{filename}.docx")
        return f"{filename}.docx"
    elif file_format == "txt":
        # Explicit UTF-8 so non-ASCII text is written identically on every
        # platform instead of depending on the locale's default encoding.
        with open(f"{filename}.txt", "w", encoding="utf-8") as f:
            f.write(content)
        return f"{filename}.txt"
    else:
        raise ValueError("Unsupported file format")

# Combine generate and save file functions
def generate_and_save(
    prompt,
    history,
    filename="output",
    file_format="pdf",
    system_prompt=None,
    temperature=0.2,
    max_new_tokens=1024,
    top_p=0.95,
    repetition_penalty=1.0,
):
    """Generate a reply for *prompt*, save it to disk and return the results.

    Args:
        prompt: The new user message.
        history: Earlier ``(user_prompt, bot_response)`` pairs.
        filename: Desired base name of the output file (no extension). Only
            the final path component is used; blank values fall back to
            ``"output"``.
        file_format: ``"pdf"``, ``"docx"`` or ``"txt"``.
        system_prompt, temperature, max_new_tokens, top_p, repetition_penalty:
            Forwarded unchanged to ``generate``.

    Returns:
        A tuple ``(generated_text, updated_history, saved_file_path)`` where
        ``updated_history`` is a new list with this turn appended.
    """
    import os  # local import: the block must not rely on file-level changes

    # The filename is free text typed into the UI. Keep only the last path
    # component so it cannot point outside the working directory, and fall
    # back to the advertised default when it is empty or whitespace.
    safe_name = os.path.basename((filename or "").strip()) or "output"

    generated_text = ""
    # generate() yields the cumulative text; the last value is the full reply.
    for output in generate(prompt, history, system_prompt, temperature, max_new_tokens, top_p, repetition_penalty):
        generated_text = output
    # Ensure </s> tags are removed from the final output
    generated_text = generated_text.replace("</s>", "")
    saved_file = save_file(generated_text, safe_name, file_format)
    return generated_text, history + [(prompt, generated_text)], saved_file

# Create Gradio interface
# Input components, in the positional order generate_and_save expects:
# prompt, history, filename, file format.
input_components = [
    gr.Textbox(placeholder="Type your message here...", label="Prompt"),
    gr.State(value=[]),  # history
    gr.Textbox(placeholder="Filename (default: output)", label="Filename", value="output"),
    gr.Radio(["pdf", "docx", "txt"], label="File Format", value="pdf"),
]

# Output components matching the returned tuple:
# generated text, updated history, path of the saved file.
output_components = [
    gr.Textbox(label="Generated Text"),
    gr.State(value=[]),  # history
    gr.File(label="Download File"),
]

demo = gr.Interface(
    fn=generate_and_save,
    inputs=input_components,
    outputs=output_components,
    css=css,
    title="",
    theme="bethecloud/storj_theme",
)

# Enable request queuing, then start the app with the API page hidden.
queued_demo = demo.queue()
queued_demo.launch(show_api=False)