stash/vendor/github.com/asticode/go-astits/muxer.go
cj c1a096a1a6
Caption support (#2462)
Co-authored-by: WithoutPants <53250216+WithoutPants@users.noreply.github.com>
2022-05-06 11:59:28 +10:00

422 lines
10 KiB
Go

package astits
import (
"bytes"
"context"
"errors"
"github.com/asticode/go-astikit"
"io"
)
const (
startPID uint16 = 0x0100
pmtStartPID uint16 = 0x1000
programNumberStart uint16 = 1
)
var (
ErrPIDNotFound = errors.New("astits: PID not found")
ErrPIDAlreadyExists = errors.New("astits: PID already exists")
ErrPCRPIDInvalid = errors.New("astits: PCR PID invalid")
)
type Muxer struct {
ctx context.Context
w io.Writer
bitsWriter *astikit.BitsWriter
packetSize int
tablesRetransmitPeriod int // period in PES packets
pm programMap // pid -> programNumber
pmt PMTData
nextPID uint16
patVersion wrappingCounter
pmtVersion wrappingCounter
patBytes bytes.Buffer
pmtBytes bytes.Buffer
buf bytes.Buffer
bufWriter *astikit.BitsWriter
esContexts map[uint16]*esContext
tablesRetransmitCounter int
}
type esContext struct {
es *PMTElementaryStream
cc wrappingCounter
}
func newEsContext(es *PMTElementaryStream) *esContext {
return &esContext{
es: es,
cc: newWrappingCounter(0b1111), // CC is 4 bits
}
}
func MuxerOptTablesRetransmitPeriod(newPeriod int) func(*Muxer) {
return func(m *Muxer) {
m.tablesRetransmitPeriod = newPeriod
}
}
// TODO MuxerOptAutodetectPCRPID selecting first video PID for each PMT, falling back to first audio, falling back to any other
func NewMuxer(ctx context.Context, w io.Writer, opts ...func(*Muxer)) *Muxer {
m := &Muxer{
ctx: ctx,
w: w,
packetSize: MpegTsPacketSize, // no 192-byte packet support yet
tablesRetransmitPeriod: 40,
pm: newProgramMap(),
pmt: PMTData{
ElementaryStreams: []*PMTElementaryStream{},
ProgramNumber: programNumberStart,
},
// table version is 5-bit field
patVersion: newWrappingCounter(0b11111),
pmtVersion: newWrappingCounter(0b11111),
esContexts: map[uint16]*esContext{},
}
m.bufWriter = astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &m.buf})
m.bitsWriter = astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: m.w})
// TODO multiple programs support
m.pm.set(pmtStartPID, programNumberStart)
for _, opt := range opts {
opt(m)
}
// to output tables at the very start
m.tablesRetransmitCounter = m.tablesRetransmitPeriod
return m
}
// if es.ElementaryPID is zero, it will be generated automatically
func (m *Muxer) AddElementaryStream(es PMTElementaryStream) error {
if es.ElementaryPID != 0 {
for _, oes := range m.pmt.ElementaryStreams {
if oes.ElementaryPID == es.ElementaryPID {
return ErrPIDAlreadyExists
}
}
} else {
es.ElementaryPID = m.nextPID
m.nextPID++
}
m.pmt.ElementaryStreams = append(m.pmt.ElementaryStreams, &es)
m.esContexts[es.ElementaryPID] = newEsContext(&es)
// invalidate pmt cache
m.pmtBytes.Reset()
return nil
}
func (m *Muxer) RemoveElementaryStream(pid uint16) error {
foundIdx := -1
for i, oes := range m.pmt.ElementaryStreams {
if oes.ElementaryPID == pid {
foundIdx = i
break
}
}
if foundIdx == -1 {
return ErrPIDNotFound
}
m.pmt.ElementaryStreams = append(m.pmt.ElementaryStreams[:foundIdx], m.pmt.ElementaryStreams[foundIdx+1:]...)
delete(m.esContexts, pid)
m.pmtBytes.Reset()
return nil
}
// SetPCRPID marks pid as one to look PCRs in
func (m *Muxer) SetPCRPID(pid uint16) {
m.pmt.PCRPID = pid
}
// WriteData writes MuxerData to TS stream
// Currently only PES packets are supported
// Be aware that after successful call WriteData will set d.AdaptationField.StuffingLength value to zero
func (m *Muxer) WriteData(d *MuxerData) (int, error) {
ctx, ok := m.esContexts[d.PID]
if !ok {
return 0, ErrPIDNotFound
}
bytesWritten := 0
forceTables := d.AdaptationField != nil &&
d.AdaptationField.RandomAccessIndicator &&
d.PID == m.pmt.PCRPID
n, err := m.retransmitTables(forceTables)
if err != nil {
return n, err
}
bytesWritten += n
payloadStart := true
writeAf := d.AdaptationField != nil
payloadBytesWritten := 0
for payloadBytesWritten < len(d.PES.Data) {
pktLen := 1 + mpegTsPacketHeaderSize // sync byte + header
pkt := Packet{
Header: &PacketHeader{
ContinuityCounter: uint8(ctx.cc.get()),
HasAdaptationField: writeAf,
HasPayload: false,
PayloadUnitStartIndicator: false,
PID: d.PID,
},
}
if writeAf {
pkt.AdaptationField = d.AdaptationField
// one byte for adaptation field length field
pktLen += 1 + int(calcPacketAdaptationFieldLength(d.AdaptationField))
writeAf = false
}
bytesAvailable := m.packetSize - pktLen
if payloadStart {
pesHeaderLengthCurrent := pesHeaderLength + int(calcPESOptionalHeaderLength(d.PES.Header.OptionalHeader))
// d.AdaptationField with pes header are too big, we don't have space to write pes header
if bytesAvailable < pesHeaderLengthCurrent {
pkt.Header.HasAdaptationField = true
if pkt.AdaptationField == nil {
pkt.AdaptationField = newStuffingAdaptationField(bytesAvailable)
} else {
pkt.AdaptationField.StuffingLength = bytesAvailable
}
} else {
pkt.Header.HasPayload = true
pkt.Header.PayloadUnitStartIndicator = true
}
} else {
pkt.Header.HasPayload = true
}
if pkt.Header.HasPayload {
m.buf.Reset()
if d.PES.Header.StreamID == 0 {
d.PES.Header.StreamID = ctx.es.StreamType.ToPESStreamID()
}
ntot, npayload, err := writePESData(
m.bufWriter,
d.PES.Header,
d.PES.Data[payloadBytesWritten:],
payloadStart,
bytesAvailable,
)
if err != nil {
return bytesWritten, err
}
payloadBytesWritten += npayload
pkt.Payload = m.buf.Bytes()
bytesAvailable -= ntot
// if we still have some space in packet, we should stuff it with adaptation field stuffing
// we can't stuff packets with 0xff at the end of a packet since it's not uncommon for PES payloads to have length unspecified
if bytesAvailable > 0 {
pkt.Header.HasAdaptationField = true
if pkt.AdaptationField == nil {
pkt.AdaptationField = newStuffingAdaptationField(bytesAvailable)
} else {
pkt.AdaptationField.StuffingLength = bytesAvailable
}
}
n, err = writePacket(m.bitsWriter, &pkt, m.packetSize)
if err != nil {
return bytesWritten, err
}
bytesWritten += n
payloadStart = false
}
}
if d.AdaptationField != nil {
d.AdaptationField.StuffingLength = 0
}
return bytesWritten, nil
}
// Writes given packet to MPEG-TS stream
// Stuffs with 0xffs if packet turns out to be shorter than target packet length
func (m *Muxer) WritePacket(p *Packet) (int, error) {
return writePacket(m.bitsWriter, p, m.packetSize)
}
func (m *Muxer) retransmitTables(force bool) (int, error) {
m.tablesRetransmitCounter++
if !force && m.tablesRetransmitCounter < m.tablesRetransmitPeriod {
return 0, nil
}
n, err := m.WriteTables()
if err != nil {
return n, err
}
m.tablesRetransmitCounter = 0
return n, nil
}
func (m *Muxer) WriteTables() (int, error) {
bytesWritten := 0
if m.patBytes.Len() != m.packetSize {
if err := m.generatePAT(); err != nil {
return bytesWritten, err
}
}
if m.pmtBytes.Len() != m.packetSize {
if err := m.generatePMT(); err != nil {
return bytesWritten, err
}
}
n, err := m.w.Write(m.patBytes.Bytes())
if err != nil {
return bytesWritten, err
}
bytesWritten += n
n, err = m.w.Write(m.pmtBytes.Bytes())
if err != nil {
return bytesWritten, err
}
bytesWritten += n
return bytesWritten, nil
}
func (m *Muxer) generatePAT() error {
d := m.pm.toPATData()
syntax := &PSISectionSyntax{
Data: &PSISectionSyntaxData{PAT: d},
Header: &PSISectionSyntaxHeader{
CurrentNextIndicator: true,
// TODO support for PAT tables longer than 1 TS packet
//LastSectionNumber: 0,
//SectionNumber: 0,
TableIDExtension: d.TransportStreamID,
VersionNumber: uint8(m.patVersion.get()),
},
}
section := PSISection{
Header: &PSISectionHeader{
SectionLength: calcPATSectionLength(d),
SectionSyntaxIndicator: true,
TableID: PSITableID(d.TransportStreamID),
},
Syntax: syntax,
}
psiData := PSIData{
Sections: []*PSISection{&section},
}
m.buf.Reset()
w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &m.buf})
if _, err := writePSIData(w, &psiData); err != nil {
return err
}
m.patBytes.Reset()
wPacket := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &m.patBytes})
pkt := Packet{
Header: &PacketHeader{
HasPayload: true,
PayloadUnitStartIndicator: true,
PID: PIDPAT,
},
Payload: m.buf.Bytes(),
}
if _, err := writePacket(wPacket, &pkt, m.packetSize); err != nil {
// FIXME save old PAT and rollback to it here maybe?
return err
}
return nil
}
func (m *Muxer) generatePMT() error {
hasPCRPID := false
for _, es := range m.pmt.ElementaryStreams {
if es.ElementaryPID == m.pmt.PCRPID {
hasPCRPID = true
break
}
}
if !hasPCRPID {
return ErrPCRPIDInvalid
}
syntax := &PSISectionSyntax{
Data: &PSISectionSyntaxData{PMT: &m.pmt},
Header: &PSISectionSyntaxHeader{
CurrentNextIndicator: true,
// TODO support for PMT tables longer than 1 TS packet
//LastSectionNumber: 0,
//SectionNumber: 0,
TableIDExtension: m.pmt.ProgramNumber,
VersionNumber: uint8(m.pmtVersion.get()),
},
}
section := PSISection{
Header: &PSISectionHeader{
SectionLength: calcPMTSectionLength(&m.pmt),
SectionSyntaxIndicator: true,
TableID: PSITableIDPMT,
},
Syntax: syntax,
}
psiData := PSIData{
Sections: []*PSISection{&section},
}
m.buf.Reset()
w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &m.buf})
if _, err := writePSIData(w, &psiData); err != nil {
return err
}
m.pmtBytes.Reset()
wPacket := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &m.pmtBytes})
pkt := Packet{
Header: &PacketHeader{
HasPayload: true,
PayloadUnitStartIndicator: true,
PID: pmtStartPID, // FIXME multiple programs support
},
Payload: m.buf.Bytes(),
}
if _, err := writePacket(wPacket, &pkt, m.packetSize); err != nil {
// FIXME save old PMT and rollback to it here maybe?
return err
}
return nil
}