|
import gradio as gr |
|
|
|
|
|
from transformers import pipeline |
|
import torch |
|
|
|
# Upper bound on the number of tokens generated for a single reply.
MAX_NEW_TOKENS = 250

# Hugging Face Hub id of the checkpoint the pipeline below loads. This constant
# previously read "HuggingFaceTB/SmolLM2-135M-Instruct" but was never used: the
# pipeline hard-coded the 1.7B checkpoint. It now names the model that is
# actually served, and the pipeline reads it, so the two cannot disagree.
# Change it here to try a different checkpoint (e.g. the 135M one above).
MODEL = "HuggingFaceTB/SmolLM2-1.7B-Instruct"

# Generation settings forwarded to every pipeline call made by message_fx.
TEMPERATURE = 0.6
TOP_P = 0.95
REPETITION_PENALTY = 1.2

# Created once at import time so every chat request reuses the same loaded model.
pipe = pipeline("text-generation", model=MODEL)
|
|
|
|
|
def _build_messages(message, history):
    """Return a new message list: the prior turns of ``history`` plus ``message``.

    Only the ``'role'`` and ``'content'`` keys of each prior turn are copied
    into fresh dicts. The caller's ``history`` is never mutated.
    """
    # NOTE(review): UI-provided message dicts may carry extra fields beyond
    # role/content (e.g. Gradio display metadata) -- confirm for the installed
    # Gradio version. The chat template only needs role/content.
    prior_turns = [
        {'role': turn['role'], 'content': turn['content']}
        for turn in (history or [])  # tolerate an empty or None history
    ]
    return prior_turns + [{'role': 'user', 'content': message}]


def message_fx(message, history):
    """Generate the assistant's reply to ``message`` given the chat ``history``.

    Used as the ``fn`` of a ``gr.ChatInterface`` configured with
    ``type="messages"``.

    Args:
        message: The text the user just submitted.
        history: Prior turns as a list of dicts with ``'role'`` and
            ``'content'`` keys. It may be empty (or ``None``) on the first turn.

    Returns:
        The ``'content'`` of the last message in the pipeline's
        ``generated_text`` conversation, i.e. the newly generated reply.
    """
    send_to_api = _build_messages(message, history)
    print(send_to_api)  # debug trace of the exact conversation sent to the model

    with torch.no_grad():
        outputs = pipe(
            send_to_api,
            do_sample=True,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            repetition_penalty=REPETITION_PENALTY,
        )

    # For chat-style input, generated_text is the input conversation followed
    # by the new assistant message, so the reply is always the last entry.
    # Using [-1] on every turn replaces the old first-turn-only [1].
    return outputs[0]['generated_text'][-1]['content']
|
|
|
|
|
# Chat UI wrapping message_fx. type="messages" selects the role/content
# message-dict history format that message_fx extends with the new user turn.
demo = gr.ChatInterface(fn=message_fx, type="messages")
demo.launch()
|
|
|
|