|
import gradio as gr |
|
from transformers import pipeline |
|
import torch |
|
import subprocess |
|
import spaces |
|
|
|
|
|
@spaces.GPU
def _build_flash_attn():
    """Install the ``flash-attn`` package into the running interpreter.

    The function is wrapped in ``spaces.GPU``; presumably this is so the
    install runs while a GPU is attached -- confirm against the Spaces
    runtime documentation.

    Raises:
        subprocess.CalledProcessError: if ``pip`` exits with a non-zero status.
    """
    import sys

    # Invoke pip through the current interpreter, with an argument list and no
    # shell, so the package lands in this environment and no shell parsing is
    # involved.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "flash-attn"])
|
# flash-attn must be installed before the model is loaded with it enabled.
_build_flash_attn()

# Shared text-generation pipeline, loaded once at import time.
# Loader options such as the attention backend have to go through
# ``model_kwargs``: loose keyword arguments to ``pipeline()`` are handed to the
# pipeline object itself, not to ``from_pretrained``.
# ``attn_implementation="flash_attention_2"`` is the documented replacement
# for the deprecated ``use_flash_attention_2=True`` flag.
generator = pipeline(
    'text-generation',
    model='mistralai/Mistral-7B-v0.1',
    torch_dtype=torch.bfloat16,
    model_kwargs={"attn_implementation": "flash_attention_2"},
)
|
@spaces.GPU
def generate_text(prompt, temperature, top_p, top_k, repetition_penalty, max_length):
    """Generate a continuation of ``prompt`` with the module-level pipeline.

    Args:
        prompt: Text to continue.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept when sampling.
        repetition_penalty: Penalty applied to already-generated tokens.
        max_length: Maximum number of *new* tokens to generate (forwarded as
            ``max_new_tokens``).

    Returns:
        The generated text only; the prompt is not echoed back because
        ``return_full_text=False`` is passed.
    """
    # Move the weights onto the GPU from inside the ``spaces.GPU``-decorated call.
    generator.model.cuda()

    outputs = generator(
        prompt,
        # Without do_sample=True generation is greedy and the temperature /
        # top_p / top_k settings below would be ignored.
        do_sample=True,
        # UI sliders may deliver floats; these two options must be integers.
        max_new_tokens=int(max_length),
        temperature=temperature,
        top_p=top_p,
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
        return_full_text=False,
    )

    # The pipeline returns a list with one dict per generated sequence.
    return outputs[0]['generated_text']
|
|
|
# Gradio UI: one prompt box plus sliders for the generation settings, in the
# same order as the parameters of ``generate_text``.
# Components are taken from the top-level ``gr`` namespace with ``value=`` for
# the initial setting; the older ``gr.inputs`` / ``gr.outputs`` modules and
# their ``default=`` keyword are deprecated and absent from current Gradio.
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", lines=2, placeholder="Type a prompt..."),
        gr.Slider(minimum=0.1, maximum=2.0, step=0.01, value=0.8, label="Temperature"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.95, label="Top p"),
        gr.Slider(minimum=0, maximum=100, step=1, value=40, label="Top k"),
        gr.Slider(minimum=1.0, maximum=2.0, step=0.01, value=1.10, label="Repetition Penalty"),
        gr.Slider(minimum=5, maximum=4096, step=5, value=1024, label="Max Length"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Text Completion Model",
    description="Try out the Mistral-7B model for free! Note this is the pretrained model and is not fine-tuned for instruction.",
)

# Start the web server for the interface.
iface.launch()