from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
import gradio as gr

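# Long-context (512k-token) Mistral-7B variant from the aws-prototyping org.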
model_name = "aws-prototyping/MegaBeam-Mistral-7B-512k"

# The tokenizer is only needed for its chat template; vLLM loads the model
# weights itself, so the separate AutoModelForCausalLM load would just hold
# a second copy of the 7B weights in memory.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Build the vLLM engine once at startup. Constructing it inside the request
# handler would reload the model on every call.
llm = LLM(model=model_name)
sampling = SamplingParams(temperature=0.7, max_tokens=512)

def chat(prompt: str) -> str:
    # Wrap the raw prompt in the model's chat template so the instruct-tuned
    # model sees the format it was trained on.
    messages = [{"role": "user", "content": prompt}]
    formatted = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    outputs = llm.generate([formatted], sampling)
    return outputs[0].outputs[0].text

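# Expose the chat function as a simple text-in/text-out Gradio demo.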
iface = gr.Interface(fn=chat, inputs="text", outputs="text")
iface.launch()
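# launch() binds to localhost by default; pass server_name="0.0.0.0" to
# iface.launch() to expose the demo on the network if needed.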