import spaces
import torch
import gradio as gr


# cpu: module-level code runs before a GPU is attached under ZeroGPU

zero = torch.Tensor([0]).cuda()
print(zero.device)  # <-- 'cpu' 🤔


# gpu: @spaces.GPU attaches a GPU to the process for the duration of the call

@spaces.GPU
def greet(prompts, separator):
    # print(zero.device) # <-- 'cuda:0' 🤗
    # Import vLLM lazily, inside the GPU context, so CUDA is only
    # initialised once a GPU is actually attached.
    from vllm import SamplingParams, LLM
    from transformers.utils import move_cache
    from huggingface_hub import snapshot_download

    LLM_MODEL_ID = "mistral-community/Mistral-7B-v0.2"
    fp = snapshot_download(LLM_MODEL_ID)
    move_cache()

    # The model is re-instantiated on every request; after the first call,
    # snapshot_download hits the local cache, so this is mostly load time.
    model = LLM(fp)
    sampling_params = SamplingParams(
        temperature=0.3,
        ignore_eos=False,
        max_tokens=1024,
    )

    multi_prompt = False
    if separator in prompts:
        multi_prompt = True
        prompts = prompts.split(separator)  # split on the variable, not the literal string 'separator'
    model_outputs = model.generate(prompts, sampling_params)
    generations = []
    for output in model_outputs:
        for completion in output.outputs:
            generations.append(completion.text)
    if multi_prompt:
        return generations
    return generations[0]


## make predictions via the API ##
# https://www.gradio.app/guides/getting-started-with-the-python-client#connecting-a-general-gradio-app
# (see the client sketch after demo.launch below)

demo = gr.Interface(
    fn=greet,
    inputs=[
        gr.Text(
            value='hello sir!<SEP>bonjour madame...',
            placeholder='hello sir!<SEP>bonjour madame...',
            label='List of prompts, joined by the separator'
        ),
        gr.Text(
            value='<SEP>',
            placeholder='<SEP>',
            label='Separator between prompts'
        )],
    outputs=gr.Text()
)
demo.launch(share=True)
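

# A minimal sketch of querying this app from the Gradio Python client described
# in the guide linked above. The URL is a placeholder for the share link that
# demo.launch(share=True) prints, and api_name="/predict" is the default
# endpoint name a gr.Interface exposes. Run it from a separate process with
# `gradio_client` installed:
#
#   from gradio_client import Client
#
#   client = Client("https://xxxxxxxx.gradio.live")  # placeholder share URL
#   result = client.predict(
#       "hello sir!<SEP>bonjour madame...",  # prompts
#       "<SEP>",                             # separator
#       api_name="/predict",
#   )
#   print(result)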