|
import gradio as gr |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
# Hugging Face checkpoint identifier shared by the tokenizer and the model.
model_name = "gpt2"

# Both objects are created once, at import time, and are read as module
# globals by respond().
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
|
|
|
def _build_prompt(message, history, system_message):
    """Join the system message, earlier turns and the new message with newlines."""
    parts = []
    if system_message:
        parts.append(str(system_message))
    for turn in history:
        # Each turn is a (user_message, bot_response) pair; keep both sides so
        # the model sees the whole conversation, not only its own replies.
        parts.extend(str(text) for text in turn[:2] if text)
    parts.append(message)
    return "\n".join(parts)


def respond(message, history=None, system_message=None, max_tokens=512, temperature=0.7, top_p=0.95):
    """Generate a formatted reply to ``message`` and record it in the history.

    Args:
        message: The user's new message.
        history: List of ``(user_message, bot_response)`` pairs from earlier
            turns. ``None`` starts a new conversation. The list is appended
            to in place and also returned.
        system_message: Optional instruction text placed at the start of the
            prompt.
        max_tokens: Maximum number of *new* tokens to generate.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.

    Returns:
        A ``(response, history)`` tuple where ``response`` is the output of
        ``format_response`` applied to the newly generated text. A blank
        ``message`` yields ``("", history)`` without running the model.
    """
    if history is None:
        history = []

    message = message or ""
    if not message.strip():
        # Nothing to answer; avoid calling generate() on an empty prompt.
        return "", history

    prompt = _build_prompt(message, history, system_message)

    # A single sequence needs no padding, and GPT-2 has no pad token, so
    # padding=True would raise. Truncation is done manually below so that the
    # most recent tokens are the ones that survive.
    encoded = tokenizer(prompt, return_tensors="pt")

    context_limit = getattr(model.config, "max_position_embeddings", None) or tokenizer.model_max_length
    requested = max(1, int(max_tokens))
    # Always leave room for generation, but never let the reply budget take
    # more than half of the context window away from the prompt.
    reserved = min(requested, max(1, context_limit // 2))
    prompt_budget = context_limit - reserved
    input_ids = encoded["input_ids"][:, -prompt_budget:]
    attention_mask = encoded["attention_mask"][:, -prompt_budget:]
    prompt_length = input_ids.shape[1]

    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        # max_new_tokens counts only generated tokens; max_length would also
        # count the prompt and could leave no room for a reply.
        max_new_tokens=min(requested, context_limit - prompt_length),
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        # GPT-2 defines no pad token; reuse EOS so generate() can pad/stop.
        pad_token_id=tokenizer.eos_token_id,
    )

    # The returned sequence starts with the prompt; decode only what follows
    # so the prompt is not echoed back into the parsed response.
    generated = outputs[0][prompt_length:]
    response = format_response(tokenizer.decode(generated, skip_special_tokens=True))

    history.append((message, response))
    return response, history
|
|
|
def format_response(response):
    """Render ``response`` as a three-line diagnosis / operation / treatment report.

    Each field is produced by the matching ``extract_*`` helper and the
    fields are joined with newlines under fixed Russian labels.
    """
    return (
        f"Предварительный диагноз: {extract_diagnosis(response)}\n"
        f"Операция: {extract_operation(response)}\n"
        f"Лечение: {extract_treatment(response)}"
    )
|
|
|
def extract_diagnosis(response):
    """Return the text before the first period in ``response``, stripped.

    When ``response`` contains no period the whole string is returned
    (stripped).
    """
    first_sentence, _, _ = response.partition(".")
    return first_sentence.strip()
|
|
|
def extract_operation(response):
    """Return the operation recommendation for ``response``.

    The argument is accepted but not inspected: the function always reports
    the fixed value "Не требуется".
    """
    return "Не требуется"
|
|
|
def extract_treatment(response):
    """Return the last non-empty sentence of ``response``, stripped.

    Sentences are delimited by periods. Text that ends with a period (the
    usual case) leaves an empty trailing fragment after splitting, so empty
    fragments are skipped rather than returned. Returns ``""`` when
    ``response`` contains no non-blank text.
    """
    sentences = [part.strip() for part in response.split(".")]
    # Walk backwards to find the final fragment that actually has content.
    return next((sentence for sentence in reversed(sentences) if sentence), "")
|
|
|
|
|
# Gradio passes input values to respond() positionally, so the components
# below are listed in the same order as respond()'s parameters:
# message, history, system_message, max_tokens, temperature, top_p.
demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="Message"),
        # Conversation history; Interface requires a "state" input whenever a
        # "state" output is declared.
        "state",
        gr.Textbox(value="Здравствуйте. Отвечай кратко...", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max Tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-p"),
    ],
    outputs=["text", "state"],
    # Live mode would re-run generation on every input change and append a
    # history entry each time, so generation runs only on explicit submit.
    live=False,
)


if __name__ == "__main__":
    demo.launch()
|
|