
Fast enough! The popular ChatGPT equivalent open source project is here, netizen: I'm worried that I won't be able to run it

WBOY
Release: 2023-04-12 15:19:53


Recently, ChatGPT, the AI chatbot developed by OpenAI, has swept across the major AI communities. Enthusiasm for it keeps growing, and people continue to probe its potential.

Some researchers couldn't sit still and began to explore how to build an open source equivalent of ChatGPT. For those who have not yet taken action, here is a reference example: the project introduced below, PaLM-rlhf-pytorch, implements exactly that.


Project address: https://github.com/lucidrains/PaLM-rlhf-pytorch

This project implements RLHF (Reinforcement Learning from Human Feedback) on top of PaLM. It follows essentially the same recipe as ChatGPT; the difference is that PaLM is used as the base model. PaLM is a 540-billion-parameter large language model trained on Google's general-purpose AI architecture, Pathways. RLHF is the technique ChatGPT added on top of the GPT-3.5 series of models: reinforcement learning driven by human-labeled feedback data is used to continuously fine-tune the pretrained language model, so that the large language model (LLM) learns to understand human instructions and give the best possible answer for a given prompt.

If you want to know more about RLHF, you can refer to: https://huggingface.co/blog/rlhf
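As a rough roadmap (a summary in stub form written for this article, not code from the project), the recipe breaks down into three stages; the usage examples later in the article walk through them in order.

# Stage 1: pretrain an autoregressive language model (here, PaLM) on raw text.
def pretrain_language_model():
    ...

# Stage 2: collect human preference labels on model outputs and fit a reward model to them.
def train_reward_model():
    ...

# Stage 3: fine-tune the language model against the reward model with reinforcement learning.
def finetune_with_rlhf():
    ...

pretrain_language_model()
train_reward_model()
finetune_with_rlhf()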

As one netizen put it: "In the field of AI, every time there is a major breakthrough, developers soon reproduce an open source version."


However, the project currently contains only the training architecture and code; there are no pre-trained weights. The usage instructions also make clear that PaLM must be trained first.

Some netizens expressed concern about this, saying: this is not an out-of-the-box project, it is only a skeleton, like an empty shell that still requires expensive compute to train, and hardly any organization other than Google can train a model like PaLM.


Another netizen said: "Not having pre-trained weights is very bad. The best choice would be for the authors to release at least 50% of the weights in sparse form and let developers train the rest themselves."


However, some netizens said they would give it a try.


Let’s take a look at how this project works.

Installation

$ pip install palm-rlhf-pytorch
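Before going further, a quick sanity check can save time. The snippet below is not from the README; it just confirms that the package imports and that a GPU is visible, since all of the examples below call .cuda():

import torch
from palm_rlhf_pytorch import PaLM

print(torch.__version__)
print(torch.cuda.is_available())  # the examples below assume this prints True
print(PaLM)                       # confirms the package exposes the PaLM class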

Usage

First train PaLM, just like any other autoregressive transformer.

import torch
from palm_rlhf_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12
).cuda()

seq = torch.randint(0, 20000, (1, 2048)).cuda()

loss = palm(seq, return_loss = True)
loss.backward()

# after much training, you can now generate sequences
generated = palm.generate(2048) # (1, 2048)
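The comment "after much training" glosses over the actual loop. Below is a minimal sketch of what that training might look like (an illustration written for this article, not code from the README), assuming your corpus has already been tokenized into sequences of ids in the range [0, 20000); the random next_batch stub stands in for a real data loader:

import torch
from palm_rlhf_pytorch import PaLM

palm = PaLM(num_tokens = 20000, dim = 512, depth = 12).cuda()
optimizer = torch.optim.Adam(palm.parameters(), lr = 3e-4)

def next_batch():
    # stand-in for a real tokenized corpus: batch of 4 sequences of length 2048
    return torch.randint(0, 20000, (4, 2048)).cuda()

for step in range(1000):                    # "much training" means far more steps than this
    seq = next_batch()
    loss = palm(seq, return_loss = True)    # autoregressive LM loss, as in the snippet above
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if step % 100 == 0:
        print(step, loss.item())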

Next, the reward model is trained on curated human feedback. In the original paper, a reward model could not be fine-tuned from a pretrained transformer without overfitting; the project author nevertheless provides the option of fine-tuning with LoRA.

import torch
from palm_rlhf_pytorch import PaLM, RewardModel

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    causal = False
)

reward_model = RewardModel(
    palm,
    num_binned_output = 5 # say rating from 1 to 5
).cuda()

# mock data
seq = torch.randint(0, 20000, (1, 1024)).cuda()
prompt_mask = torch.zeros(1, 1024).bool().cuda() # which part of the sequence is prompt, which part is response
labels = torch.randint(0, 5, (1,)).cuda()

# train
loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
loss.backward()

# after much training
reward = reward_model(seq, prompt_mask = prompt_mask)
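For readers unfamiliar with LoRA, here is a minimal self-contained sketch of the idea in plain PyTorch (an illustration written for this article, not the API of palm_rlhf_pytorch): the pretrained weight is frozen and only a low-rank residual is trained, which drastically cuts the number of trainable parameters.

import torch
from torch import nn

class LoRALinear(nn.Module):
    # wraps a frozen linear layer and adds a trainable low-rank update: W x + scale * B A x
    def __init__(self, linear: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.linear = linear
        for p in self.linear.parameters():
            p.requires_grad = False  # freeze the pretrained weight

        self.lora_a = nn.Parameter(torch.randn(rank, linear.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(linear.out_features, rank))  # zero init: no change at start
        self.scale = alpha / rank

    def forward(self, x):
        return self.linear(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scale

layer = LoRALinear(nn.Linear(512, 512))
layer(torch.randn(1, 10, 512)).sum().backward()

trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)  # 8192 adapter parameters, versus 262656 frozen ones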

Finally pass the transformer and reward model to RLHFTrainer.

import torch
from palm_rlhf_pytorch import PaLM, RewardModel, RLHFTrainer

# load your pretrained palm
palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12
).cuda()

palm.load('./path/to/pretrained/palm.pt')

# load your pretrained reward model
reward_model = RewardModel(
    palm,
    num_binned_output = 5
).cuda()

reward_model.load('./path/to/pretrained/reward_model.pt')

# ready your list of prompts for reinforcement learning
prompts = torch.randint(0, 256, (50000, 512)).cuda() # 50k prompts

# pass it all to the trainer and train
trainer = RLHFTrainer(
    palm = palm,
    reward_model = reward_model,
    prompt_token_ids = prompts
)

trainer.train(num_episodes = 50000)

# then, if it succeeded...
# generate say 10 samples and use the reward model to return the best one
answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10) # (<= 2048,)
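The last line hides a useful idea: best-of-n sampling against a reward model. The sketch below shows roughly what that re-ranking amounts to (an illustration built only from the calls already shown above, not the trainer's internal code), assuming palm and reward_model are the trained objects from the previous snippet:

import torch

num_samples = 10
candidates = [palm.generate(2048) for _ in range(num_samples)]  # sample several candidate sequences

scores = []
for seq in candidates:
    prompt_mask = torch.zeros_like(seq).bool()   # simplification: treat the whole sequence as response
    reward = reward_model(seq, prompt_mask = prompt_mask)
    scores.append(float(reward.mean()))          # reduce to a plain float for comparison

best = candidates[max(range(num_samples), key = lambda i: scores[i])]  # keep the highest-scoring one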


Source: 51cto.com