File size: 2,177 Bytes
c3dfc09 4d87bab 2edf1f0 e143294 c3dfc09 23be90a e143294 23be90a 85dd489 2edf1f0 e143294 4d87bab 23be90a 4d87bab e143294 23be90a 4d87bab e143294 2edf1f0 e143294 2edf1f0 c3dfc09 85dd489 2f225f8 e143294 4d87bab 85dd489 4d87bab 92bf9aa 85dd489 e143294 2edf1f0 e143294 85dd489 c3dfc09 133324c 2edf1f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import gradio as gr
from transformers import pipeline
import torch
import logging
# Настройка логирования
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Загружаем модель
model_name = "gpt2" # Смена на базовую модель GPT-2
try:
logger.info(f"Попытка загрузки модели {model_name}...")
generator = pipeline(
"text-generation",
model=model_name,
device=-1,
framework="pt",
max_length=512,
truncation=True,
model_kwargs={"torch_dtype": torch.float32}
)
logger.info("Модель успешно загружена.")
except Exception as e:
logger.error(f"Ошибка загрузки модели: {e}")
exit(1)
def respond(message, max_tokens=256, temperature=0.7, top_p=0.9):
try:
logger.info(f"Генерация ответа для: {message}")
outputs = generator(
message,
max_length=max_tokens,
temperature=temperature,
top_p=top_p,
do_sample=True,
no_repeat_ngram_size=2,
num_return_sequences=1
)
response = outputs[0]["generated_text"].strip()
logger.info(f"Ответ сгенерирован: {response}")
except Exception as e:
logger.error(f"Ошибка генерации ответа: {e}")
return f"Ошибка генерации: {e}"
return response
demo = gr.Interface(
fn=respond,
inputs=[
gr.Textbox(label="Ваше сообщение", placeholder="Опишите симптомы (например, 'Болит горло')..."),
gr.Slider(minimum=50, maximum=512, value=256, step=10, label="Макс. токенов"),
gr.Slider(minimum=0.1, maximum=1.5, value=0.7, label="Температура"),
gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top-p")
],
outputs="text",
title="Медицинский чат-бот на базе GPT-2",
theme=gr.themes.Soft()
)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860) |