hse-python-assistant/app/models/qwen.py

import os
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from app.models.base import BaseModel

class Qwen(BaseModel):
2024-10-16 22:57:05 +00:00
"""A class to handle inference with fine-tuned Qwen2.5 model."""
    def __init__(
        self,
        model_path: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.6,
        max_tokens: int = 2048,
    ) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        model_path = os.path.expanduser(model_path)
        # Load the weights in 8-bit to reduce GPU memory usage.
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            quantization_config=quantization_config,
            device_map="auto",
            # low_cpu_mem_usage=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.system_prompt = system_prompt
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.messages = []

    def ask(self, user_message: str, clear_history: bool = True, debug: bool = False) -> Optional[str]:
        if clear_history:
            self.messages = []
            # Re-add the system prompt after clearing so every fresh
            # conversation starts from the same instructions.
            if self.system_prompt:
                self.messages.append({"role": "system", "text": self.system_prompt})
        self.messages.append({"role": "user", "text": user_message})
        prompt_text = self._build_prompt()
        if debug:
            print(prompt_text)
        inputs = self.tokenizer(prompt_text, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_tokens,
                temperature=self.temperature,
                do_sample=True,  # temperature has no effect under greedy decoding
            )
        # Decode only the newly generated tokens; slicing off the prompt avoids
        # fragile string matching on the word "assistant" in the decoded text.
        generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
        self.messages.append({"role": "assistant", "text": response})
        return response
    def _build_prompt(self) -> str:
        # Serialize the message history into Qwen's ChatML format, leaving the
        # prompt open at an assistant turn for the model to complete.
        prompt = ""
        for message in self.messages:
            if message["role"] == "system":
                prompt += f"<|im_start|>system\n{message['text']}<|im_end|>\n\n"
            elif message["role"] == "user":
                prompt += f"<|im_start|>user\n{message['text']}<|im_end|>\n\n"
            elif message["role"] == "assistant":
                # Render prior assistant turns so multi-turn history survives.
                prompt += f"<|im_start|>assistant\n{message['text']}<|im_end|>\n\n"
        prompt += "<|im_start|>assistant\n"
        return prompt

if __name__ == "__main__":
    model_path = "/home/ozakharov/hse_hackathon/Qwen2.5-32B-Instruct-hse_fine_tuned_v2"
    system_prompt = (
        "You are a hackathon guru. You must explain users' requests clearly, "
        "thoroughly, and understandably. You must answer in a fun format, "
        "like a zoomer, with emoji, and give moral support."
    )
    qwen = Qwen(model_path=model_path, system_prompt=system_prompt)
    user_message = "How do you win a hackathon?"
    response = qwen.ask(user_message)
    print(response)
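
    # A minimal multi-turn sketch: passing clear_history=False keeps the
    # previous exchange in self.messages, so the follow-up is answered with
    # the earlier turn as context. (The follow-up question below is
    # illustrative, not from the original.)
    follow_up = "What should the team focus on in the final hour?"
    follow_up_response = qwen.ask(follow_up, clear_history=False)
    print(follow_up_response)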